/
OS-World3a4b673
"""
Tools module for GUI agents.
This module provides various tools for GUI agents to perform tasks such as web search,
context fusion, subtask planning, trajectory reflection, memory retrieval, grounding,
evaluation, and action generation.
"""
import os
import json
import base64
import requests
import time
from typing import Dict, Any, Optional, List, Union, Tuple
from abc import ABC, abstractmethod
import logging
from ..core.mllm import LLMAgent, WebSearchAgent, EmbeddingAgent
import threading
from ..prompts import get_prompt, module
logger = logging.getLogger("desktopenv.tools")
class BaseTool(ABC):
    """Base class for all tools.

    A tool owns an ``LLMAgent`` whose system prompt is resolved from the
    ``module`` object imported from the prompts package, using ``tool_name``
    as the lookup key. Subclasses implement :meth:`execute`.
    """

    # Legacy cache; kept only so external code reading ``_prompts_dict`` still works.
    _prompts_dict = None
    _prompts_dict_lock = threading.Lock()
    # Directory retained for backward compatibility; no longer scanned directly
    _prompts_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "prompts")

    # tool name -> (category attribute on ``module``, prompt attribute on that category)
    _PROMPT_CATEGORY_MAP = {
        # manager prompts
        "query_formulator": ("manager", "query_formulator"),
        "narrative_summarization": ("manager", "narrative_summarization"),
        "context_fusion": ("manager", "context_fusion"),
        "planner_role": ("manager", "planner_role"),
        "supplement_role": ("manager", "supplement_role"),
        "dag_translator": ("manager", "dag_translator"),
        "objective_alignment": ("manager", "objective_alignment"),
        # worker prompts
        "operator_role": ("worker", "operator_role"),
        "technician_role": ("worker", "technician_role"),
        "analyst_role": ("worker", "analyst_role"),
        "grounding": ("worker", "grounding"),
        "text_span": ("worker", "text_span"),
        "episode_summarization": ("worker", "episode_summarization"),
        # evaluator prompts
        "worker_success_role": ("evaluator", "worker_success_role"),
        "worker_stale_role": ("evaluator", "worker_stale_role"),
        "periodic_role": ("evaluator", "periodic_role"),
        "final_check_role": ("evaluator", "final_check_role"),
    }

    # Tools whose prompt is prefixed with ``module.system_architecture``
    _TOOLS_REQUIRING_SYSTEM_PREFIX = frozenset({
        "planner_role",
        "supplement_role",
        "dag_translator",
        "operator_role",
        "technician_role",
        "analyst_role",
        "worker_success_role",
        "worker_stale_role",
        "periodic_role",
        "final_check_role",
        "objective_alignment",
    })

    @classmethod
    def _load_prompts_dict(cls):
        """Initialise the legacy ``_prompts_dict`` cache to an empty dict.

        Deprecated: prompts are now read from ``module`` in
        :meth:`_get_prompt_template`; this only guarantees the attribute is a
        dict for code that still accesses it.
        """
        if cls._prompts_dict is None:
            with cls._prompts_dict_lock:
                # Re-check under the lock so only one caller performs the assignment.
                if cls._prompts_dict is None:
                    cls._prompts_dict = {}

    def __init__(self, provider: str, model_name: str, tool_name: str):
        """
        Initialize the base tool.

        Args:
            provider: API provider name (e.g., "gemini", "openai")
            model_name: Model name to use (e.g., "gemini-2.5-pro")
            tool_name: Name of the tool (used as the prompt lookup key)
        """
        self.provider = provider
        self.model_name = model_name
        self.tool_name = tool_name
        self._load_prompts_dict()
        self._prompt_template = self._get_prompt_template()
        # Create LLMAgent instance for tool usage
        self.engine_params = {
            "engine_type": provider,
            "model": model_name
        }
        self.llm_agent = LLMAgent(engine_params=self.engine_params, system_prompt=self._prompt_template)

    def _get_prompt_template(self) -> str:
        """Resolve the system prompt for ``self.tool_name``.

        Known tool names are looked up as ``module.<category>.<key>``; any
        other name is looked up as a root-level attribute of ``module``.
        Selected tools get ``module.system_architecture`` prepended.

        Returns:
            The prompt text, or "" when ``tool_name`` is None, the prompt
            cannot be found, or the lookup raises.
        """
        if self.tool_name is None:
            return ""
        try:
            category_tuple = self._PROMPT_CATEGORY_MAP.get(self.tool_name)
            if category_tuple is None:
                # Try root-level attribute on module (e.g., system_architecture)
                if not hasattr(module, self.tool_name):
                    return ""
                prompt_text = getattr(module, self.tool_name)
            else:
                category_name, key_name = category_tuple
                category_obj = getattr(module, category_name, None)
                if category_obj is None:
                    return ""
                value = getattr(category_obj, key_name, None)
                if not (isinstance(value, str) and value):
                    return ""
                prompt_text = value

            # Optionally prefix with system architecture information for selected tools
            if (
                isinstance(prompt_text, str)
                and prompt_text
                and self.tool_name in self._TOOLS_REQUIRING_SYSTEM_PREFIX
            ):
                system_info = getattr(module, "system_architecture", "")
                if isinstance(system_info, str) and system_info:
                    return f"{system_info}\n\n{prompt_text}"
            return prompt_text
        except Exception:
            # Best-effort: a tool without a prompt is still usable, but record
            # the failure instead of hiding it so a broken prompt is diagnosable.
            logger.exception("Failed to resolve prompt template for tool %r", self.tool_name)
            return ""

    def _call_lmm(self, input_data: Dict[str, Any], temperature: float = 0.0) -> Tuple[str, List[int], str]:
        """
        Call the LMM model for inference, retrying on failure.

        The agent's history is reset and a single user message is added before
        the first attempt; retries re-request a response for that same message.

        Args:
            input_data: Dictionary with optional 'str_input' (text) and
                'img_input' (image content) entries
            temperature: Temperature parameter to control randomness of output

        Returns:
            Tuple of (content, total_tokens, cost_string) as produced by the
            agent. After three failed attempts the content is an "Error: ..."
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        # Extract text and image inputs
        text_input = input_data.get('str_input', '')
        image_input = input_data.get('img_input', None)

        self.llm_agent.reset()
        self.llm_agent.add_message(text_input, image_content=image_input, role="user")

        max_retries = 3
        attempt = 0
        content, total_tokens, cost_string = "", [0, 0, 0], ""
        while attempt < max_retries:
            try:
                content, total_tokens, cost_string = self.llm_agent.get_response(temperature=temperature)
                break  # If successful, break out of the loop
            except Exception as e:
                attempt += 1
                logger.error(f"LLM call attempt {attempt} failed: {str(e)}")
                if attempt == max_retries:
                    logger.error("Max retries reached. Returning error message.")
                    return f"Error: LLM call failed after {max_retries} attempts: {str(e)}", [0, 0, 0], ""
                # Brief pause before the next attempt
                time.sleep(1.0)
        return content, total_tokens, cost_string

    @abstractmethod
    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Execute the tool with the given input.

        Args:
            tool_input: Dictionary containing the input for the tool
                Expected to have 'str_input' and/or 'img_input' keys

        Returns:
            Tuple of (output, token counts, cost string)
        """
        pass
class ToolFactory:
    """Factory that turns a tool name into a configured tool instance."""

    @staticmethod
    def create_tool(tool_name: str, provider: str, model_name: str, **kwargs) -> 'BaseTool':
        """
        Build the tool registered under ``tool_name``.

        Args:
            tool_name: Name of the tool to create
            provider: API provider name
            model_name: Model name to use
            **kwargs: Extra keyword arguments forwarded to the tool constructor

        Returns:
            A new instance of the tool class mapped to ``tool_name``

        Raises:
            ValueError: If the tool name is not recognized
        """
        # Prompt-less tools: constructed with None as their prompt key.
        registry = {
            "embedding": (EmbeddingTool, None),  # all
            "websearch": (WebSearchTool, None),  # manager
        }
        # Prompt-backed tools: the prompt key is always the tool name itself.
        names_by_class = {
            QueryFormulatorTool: ("query_formulator",),  # manager
            NarrativeSummarizationTool: ("narrative_summarization",),  # manager
            ContextFusionTool: ("context_fusion",),  # manager
            SubtaskPlannerTool: ("planner_role", "supplement_role"),  # manager
            DAGTranslatorTool: ("dag_translator",),  # manager
            ObjectiveAlignmentTool: ("objective_alignment",),  # manager
            ActionGeneratorTool: ("operator_role", "technician_role", "analyst_role"),  # worker
            GroundingTool: ("grounding",),  # worker
            TextSpanTool: ("text_span",),  # worker
            EpisodeSummarizationTool: ("episode_summarization",),  # worker
            EvaluatorTool: (
                "worker_success_role",
                "worker_stale_role",
                "periodic_role",
                "final_check_role",
            ),  # evaluator
        }
        for cls, names in names_by_class.items():
            for name in names:
                registry[name] = (cls, name)

        entry = registry.get(tool_name)
        if entry is None:
            raise ValueError(f"Unknown tool name: {tool_name}")

        tool_class, prompt_key = entry
        return tool_class(provider, model_name, prompt_key, **kwargs)
class WebSearchTool(BaseTool):
    """Tool for performing web searches.

    Does not call ``BaseTool.__init__``: it uses a ``WebSearchAgent`` and
    therefore has no prompt template and no ``llm_agent``.
    """

    def __init__(self, provider: str, model_name: str, tool_name: Optional[str]):
        """
        Initialize the web search tool.

        Args:
            provider: API provider name (e.g., "bocha", "exa")
            model_name: Model name, forwarded to the agent as ``engine_params["model"]``
            tool_name: Name of the tool; stored only, no prompt is looked up
        """
        self.provider = provider
        # Stored so this tool exposes the same identifying attributes as the
        # other tools in this module, even though no LLMAgent is created.
        self.model_name = model_name
        self.tool_name = tool_name
        # Create WebSearchAgent instance for search
        self.engine_params = {
            "engine_type": provider,
            "model": model_name,
        }
        self.search_agent = WebSearchAgent(engine_params=self.engine_params)

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Execute a web search with the given query.

        Args:
            tool_input: Dictionary containing the search query
                Expected to have 'str_input' key with the search query

        Returns:
            Tuple of (answer, total_tokens, cost) from the search agent. On a
            missing query or a search failure the first element is an
            "Error: ..." message, the token list is [0, 0, 0] and the cost is "".
        """
        query = tool_input.get('str_input', '')
        if not query:
            return "Error: No search query provided", [0, 0, 0], ""
        try:
            answer, total_tokens, cost = self.search_agent.get_answer(query)
            return answer, total_tokens, cost  # type: ignore
        except Exception as e:
            # Boundary handler: report the failure as a result instead of raising.
            logger.error(f"Error during web search: {str(e)}")
            return f"Error: Web search failed: {str(e)}", [0, 0, 0], ""
class ContextFusionTool(BaseTool):
    """Tool for fusing multiple contexts together."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Fuse multiple contexts together.

        Args:
            tool_input: Dictionary containing the contexts to fuse
                Expected to have 'str_input' key with the contexts

        Returns:
            Tuple of (fused context, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        contexts = tool_input.get('str_input', '')
        if not contexts:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No contexts provided", [0, 0, 0], ""
        # Use the prompt template and LMM for context fusion
        return self._call_lmm(tool_input)
class SubtaskPlannerTool(BaseTool):
    """Tool for planning subtasks."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Plan subtasks for a given task.

        Args:
            tool_input: Dictionary containing the task description
                Expected to have 'str_input' key with the task description
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (subtask plan, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        task = tool_input.get('str_input', '')
        if not task:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No task description provided", [0, 0, 0], ""
        # Use the prompt template and LMM for subtask planning
        return self._call_lmm(tool_input)
class NarrativeSummarizationTool(BaseTool):
    """Tool for summarizing narrative memories."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Summarize narrative memories.

        Args:
            tool_input: Dictionary containing the narrative memory data
                Expected to have 'str_input' key with the narrative memory data
                May also have 'img_input' key with relevant images

        Returns:
            Tuple of (summary, token counts, cost string). When 'str_input'
            is missing or empty, the first element is an error message, the
            token list is [0, 0, 0] and the cost string is "".
        """
        narrative_data = tool_input.get('str_input', '')
        if not narrative_data:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No narrative memory data provided", [0, 0, 0], ""
        # Use the prompt template and LMM for narrative summarization
        return self._call_lmm(tool_input)
class EpisodeSummarizationTool(BaseTool):
    """Tool for summarizing episodic memories."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Summarize episodic memories.

        Args:
            tool_input: Dictionary containing the episodic memory data
                Expected to have 'str_input' key with the episodic memory data
                May also have 'img_input' key with relevant images

        Returns:
            Tuple of (summary, token counts, cost string). When 'str_input'
            is missing or empty, the first element is an error message, the
            token list is [0, 0, 0] and the cost string is "".
        """
        episode_data = tool_input.get('str_input', '')
        if not episode_data:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No episodic memory data provided", [0, 0, 0], ""
        # Use the prompt template and LMM for episode summarization
        return self._call_lmm(tool_input)
class TextSpanTool(BaseTool):
    """Tool for processing text spans."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Process text spans for a given input.

        Args:
            tool_input: Dictionary containing the text input
                Expected to have 'str_input' key with the text content
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (model output, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        text = tool_input.get('str_input', '')
        if not text:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No text content provided", [0, 0, 0], ""
        # Use the prompt template and LMM for text span processing
        return self._call_lmm(tool_input)
class DAGTranslatorTool(BaseTool):
    """Tool for translating task descriptions into a DAG (Directed Acyclic Graph) structure."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Translate task descriptions into a DAG structure.

        Args:
            tool_input: Dictionary containing the task description
                Expected to have 'str_input' key with the task description
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (model output, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        task = tool_input.get('str_input', '')
        if not task:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No task description provided", [0, 0, 0], ""
        # Use the prompt template and LMM for DAG translation
        return self._call_lmm(tool_input)
class ObjectiveAlignmentTool(BaseTool):
    """Tool that rewrites a user objective against the current screen context."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Send the objective text (and optional screenshot) to the LMM.

        Args:
            tool_input: Dict with keys:
                - 'str_input': the raw user objective or context text
                - 'img_input': optional screenshot image content

        Returns:
            A 3-tuple of model output text, token counts, and cost string.
            If 'str_input' is absent or empty the tuple is
            ("Error: No objective text provided", [0, 0, 0], "").
        """
        objective = tool_input.get('str_input', '')
        if objective:
            # The prompt template configured on the agent drives the rewrite.
            return self._call_lmm(tool_input)
        return "Error: No objective text provided", [0, 0, 0], ""
class TrajReflectorTool(BaseTool):
    """Tool for reflecting on execution trajectories."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Reflect on an execution trajectory.

        Args:
            tool_input: Dictionary containing the trajectory
                Expected to have 'str_input' key with the trajectory

        Returns:
            Tuple of (reflection, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        trajectory = tool_input.get('str_input', '')
        if not trajectory:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No trajectory provided", [0, 0, 0], ""
        # Use the prompt template and LMM for trajectory reflection
        return self._call_lmm(tool_input)
class GroundingTool(BaseTool):
    """Tool for grounding agent actions in the environment."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Ground agent actions in the environment.

        Args:
            tool_input: Dictionary containing the action and environment state
                Expected to have 'str_input' key with the action
                Expected to have 'img_input' key with a screenshot

        Returns:
            Tuple of (grounded action, token counts, cost string). When the
            action or the screenshot is missing, the first element is an
            error message, the token list is [0, 0, 0] and the cost string is "".
        """
        action = tool_input.get('str_input', '')
        screenshot = tool_input.get('img_input')
        # Error paths use the same 3-tuple shape as the success path so
        # callers can always unpack the result.
        if not action:
            return "Error: No action provided", [0, 0, 0], ""
        if not screenshot:
            return "Error: No screenshot provided", [0, 0, 0], ""
        # Use the prompt template and LMM for action grounding
        return self._call_lmm(tool_input)

    def get_grounding_wh(self):
        """
        Get grounding width and height based on provider and model name.

        Returns:
            (1000, 1000) when the provider is "doubao" and the model name
            contains "ui-tars" or "ep-"; otherwise (None, None).
        """
        # NOTE(review): 1000x1000 is presumably the coordinate space these
        # models emit - confirm against the model documentation.
        if self.provider == "doubao" and ("ui-tars" in self.model_name or "ep-" in self.model_name):
            grounding_width = 1000
            grounding_height = 1000
            return grounding_width, grounding_height
        return None, None
class EvaluatorTool(BaseTool):
    """Tool for evaluating agent performance."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Evaluate agent performance.

        Args:
            tool_input: Dictionary containing the evaluation data
                Expected to have 'str_input' key with the evaluation data

        Returns:
            Tuple of (evaluation result, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        eval_data = tool_input.get('str_input', '')
        if not eval_data:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No evaluation data provided", [0, 0, 0], ""
        # Use the prompt template and LMM for performance evaluation
        return self._call_lmm(tool_input)
class ActionGeneratorTool(BaseTool):
    """Tool for generating executable actions."""

    def __init__(self, provider: str, model_name: str, tool_name: str, **kwargs):
        """
        Initialize the action generator tool.

        Args:
            provider: API provider name
            model_name: Model name to use
            tool_name: Name of the tool (used as the prompt lookup key)
            **kwargs: Additional parameters, including:
                enable_search: Whether to enable web search functionality
                search_provider: Provider for web search (defaults to "bocha")
                search_model: Model for web search (defaults to "")
        """
        super().__init__(provider, model_name, tool_name)
        # Extract search-related parameters
        self.enable_search = kwargs.get("enable_search", False)
        search_provider = kwargs.get("search_provider", "bocha")
        search_model = kwargs.get("search_model", "")
        # Initialize search tool if enabled
        self.search_tool = None
        if self.enable_search:
            self.search_tool = WebSearchTool(search_provider, search_model, "")
            logger.info(f"Web search enabled for {tool_name} using provider: {search_provider}")

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Generate executable actions.

        When search is enabled, the request text is first used as a web
        search query and the results are appended to the text sent to the
        LMM. The caller's ``tool_input`` dict is left unmodified.

        Args:
            tool_input: Dictionary containing the action request
                Expected to have 'str_input' key with the action request
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (generated action, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        action_request = tool_input.get('str_input', '')
        if not action_request:
            return "Error: No action request provided", [0, 0, 0], ""

        llm_input = tool_input
        if self.enable_search and self.search_tool:
            try:
                # Use the input text directly as search query
                search_query = action_request
                logger.info(f"Performing web search for query: {search_query}")
                # Token usage and cost of the search itself are not reported upstream.
                search_results, _tokens, _cost = self.search_tool.execute({"str_input": search_query})
                # Enhance the action request with search results
                enhanced_request = f"[Action Request]\n{action_request}\n[End of Action Request]\n\n[Web Search Results for '{action_request}']\n{search_results}\n\n[End of Web Search Results]"
                # Build a new dict rather than writing into the caller's:
                # a reused input would otherwise accumulate search results.
                llm_input = {**tool_input, "str_input": enhanced_request}
                logger.info(f"Search completed. Found information: {len(search_results)} characters")
            except Exception as e:
                logger.error(f"Error during web search: {e}")
                # Continue with original request if search fails

        # Use the prompt template and LMM for action generation
        return self._call_lmm(llm_input)
class FastActionGeneratorTool(BaseTool):
    """Tool for directly generating executable actions without intermediate planning."""

    def __init__(self, provider: str, model_name: str, tool_name: str, **kwargs):
        """
        Initialize the fast action generator tool.

        Args:
            provider: API provider name
            model_name: Model name to use
            tool_name: Name of the tool (used as the prompt lookup key)
            **kwargs: Additional parameters, including:
                enable_search: Whether to enable web search functionality
                search_provider: Provider for web search (defaults to "bocha")
                search_model: Model for web search (defaults to "")
        """
        super().__init__(provider, model_name, tool_name)
        # Extract search-related parameters
        self.enable_search = kwargs.get("enable_search", False)
        search_provider = kwargs.get("search_provider", "bocha")
        search_model = kwargs.get("search_model", "")
        # Initialize search tool if enabled
        self.search_tool = None
        if self.enable_search:
            self.search_tool = WebSearchTool(search_provider, search_model, "")
            logger.info(f"Web search enabled for {tool_name} using provider: {search_provider}")

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Generate executable actions directly from the instruction and screenshot.

        When search is enabled, the instruction is first used as a web search
        query and the results are appended to the text sent to the LMM. The
        caller's ``tool_input`` dict is left unmodified.

        Args:
            tool_input: Dictionary containing the action request
                Expected to have 'str_input' key with the instruction
                Expected to have 'img_input' key with a screenshot

        Returns:
            Tuple of (generated action, token counts, cost string). When the
            instruction or screenshot is missing, the first element is an
            error message, the token list is [0, 0, 0] and the cost string is "".
        """
        action_request = tool_input.get('str_input', '')
        screenshot = tool_input.get('img_input')
        if not action_request:
            return "Error: No action request provided", [0, 0, 0], ""
        if not screenshot:
            return "Error: No screenshot provided", [0, 0, 0], ""

        llm_input = tool_input
        if self.enable_search and self.search_tool:
            try:
                # Use the input text directly as search query
                search_query = action_request
                logger.info(f"Performing web search for query: {search_query}")
                # Token usage and cost of the search itself are not reported upstream.
                search_results, _tokens, _cost = self.search_tool.execute({"str_input": search_query})
                # Enhance the action request with search results
                enhanced_request = f"[Action Request]\n{action_request}\n[End of Action Request]\n\n[Web Search Results for '{action_request}']\n{search_results}\n\n[End of Web Search Results]"
                # Build a new dict rather than writing into the caller's:
                # a reused input would otherwise accumulate search results.
                llm_input = {**tool_input, "str_input": enhanced_request}
                logger.info(f"Search completed. Found information: {len(search_results)} characters")
            except Exception as e:
                logger.error(f"Error during web search: {e}")
                # Continue with original request if search fails

        # Use the prompt template and LMM for action generation
        return self._call_lmm(llm_input)

    def get_grounding_wh(self):
        """
        Get grounding width and height based on provider and model name.

        Returns:
            (1000, 1000) when the provider is "doubao" and the model name
            contains "ui-tars"; otherwise (None, None).
        """
        # NOTE(review): GroundingTool.get_grounding_wh also matches model
        # names containing "ep-"; confirm whether that should apply here too.
        if self.provider == "doubao" and "ui-tars" in self.model_name:
            grounding_width = 1000
            grounding_height = 1000
            return grounding_width, grounding_height
        return None, None
class EmbeddingTool(BaseTool):
    """Tool that produces embeddings for text through an ``EmbeddingAgent``."""

    def __init__(self, provider: str, model_name: str, tool_name: str):
        """
        Set up the embedding agent; ``BaseTool.__init__`` is not invoked.

        Args:
            provider: API provider name (e.g., "openai", "gemini")
            model_name: Embedding model name, passed as ``embedding_model``
            tool_name: Name of the tool; stored on the instance
        """
        self.provider = provider
        self.model_name = model_name
        self.tool_name = tool_name
        # Parameters handed to the EmbeddingAgent
        params = {}
        params["engine_type"] = provider
        params["embedding_model"] = model_name
        self.engine_params = params
        self.embedding_agent = EmbeddingAgent(engine_params=params)

    def execute(self, tool_input: Dict[str, Any]):
        """
        Compute embeddings for the text under 'str_input'.

        Args:
            tool_input: Dictionary containing the text to embed
                Expected to have 'str_input' key with the text

        Returns:
            The (embeddings, total_tokens, cost_string) values produced by
            the agent's ``get_embeddings``. If the text is missing or the
            call raises, an ("Error: ...", [0, 0, 0], "") tuple instead.
        """
        payload = tool_input.get('str_input', '')
        if not payload:
            return "Error: No text provided for embedding", [0, 0, 0], ""

        try:
            outcome = self.embedding_agent.get_embeddings(payload)
            # Unpack inside the try so a malformed result is reported as a failure too.
            embeddings, total_tokens, cost_string = outcome
        except Exception as exc:
            logger.error(f"Error during embedding operation: {str(exc)}")
            return f"Error: Embedding operation failed: {str(exc)}", [0, 0, 0], ""
        return embeddings, total_tokens, cost_string
class QueryFormulatorTool(BaseTool):
    """Tool for formulating queries from tasks or contexts."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Formulate a query for a given task or context.

        Args:
            tool_input: Dictionary containing the task or context description
                Expected to have 'str_input' key with the description
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (formulated query, token counts, cost string). When
            'str_input' is missing or empty, the first element is an error
            message, the token list is [0, 0, 0] and the cost string is "".
        """
        task = tool_input.get('str_input', '')
        if not task:
            # Same 3-tuple shape as the success path so callers can always unpack.
            return "Error: No task or context description provided", [0, 0, 0], ""
        # Use the prompt template and LMM for query formulation
        return self._call_lmm(tool_input)
class NewTools:
    """Registry of tool instances keyed by tool name."""

    def __init__(self):
        """Start with no registered tools."""
        self.tools = {}

    def register_tool(self, tool_name: str, provider: str, model_name: str, **kwargs):
        """
        Create a tool through ``ToolFactory`` and store it under ``tool_name``.

        An existing entry with the same name is replaced.

        Args:
            tool_name: Name of the tool to register
            provider: API provider name
            model_name: Model name to use
            **kwargs: Additional parameters to pass to the tool
        """
        self.tools[tool_name] = ToolFactory.create_tool(tool_name, provider, model_name, **kwargs)

    def execute_tool(self, tool_name: str, tool_input: Dict[str, Any]):
        """
        Run a registered tool.

        Args:
            tool_name: Name of the tool to execute
            tool_input: Input for the tool

        Returns:
            Whatever the tool's ``execute`` returns

        Raises:
            ValueError: If the tool is not registered
        """
        if tool_name in self.tools:
            return self.tools[tool_name].execute(tool_input)
        raise ValueError(f"Tool {tool_name} is not registered")

    def reset(self, tool_name: Optional[str] = None):
        """
        Reset the ``llm_agent`` of one tool, or of every registered tool.

        Tools without an ``llm_agent`` attribute, or with it set to None,
        are skipped.

        Args:
            tool_name: Optional name of the specific tool to reset. If None, resets all tools.

        Raises:
            ValueError: If ``tool_name`` is given but not registered
        """
        if tool_name is None:
            targets = self.tools.values()
        elif tool_name in self.tools:
            targets = (self.tools[tool_name],)
        else:
            raise ValueError(f"Tool {tool_name} is not registered")

        for candidate in targets:
            agent = getattr(candidate, 'llm_agent', None)
            if agent is not None:
                agent.reset()