# OS-World 0f00788
import base64
import logging
import os
import re
from io import BytesIO
from typing import Dict, List
import backoff
import openai
import requests
from PIL import Image
from requests.exceptions import SSLError
from mm_agents.prompts import O3_SYSTEM_PROMPT
logger = None
MAX_RETRY_TIMES = 10
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY",None) #"Your OpenAI API Key"
def encode_image(image_content):
return base64.b64encode(image_content).decode("utf-8")
class O3Agent:
    """Screenshot-driven GUI agent backed by an OpenAI chat model (default "o3").

    Each :meth:`predict` call sends the task instruction, the current
    screenshot, and the accumulated step history to the model, then parses
    the reply into executable ``pyautogui`` snippets (or the control tokens
    ``WAIT`` / ``DONE`` / ``FAIL``).  Only the ``max_image_history_length``
    most recent screenshots are re-sent as images; older steps are
    represented by their text captions to keep the request payload bounded.
    Call :meth:`reset` before starting each new episode.
    """

    def __init__(
        self,
        platform="ubuntu",
        model="o3",
        max_tokens=1500,
        client_password="password",
        action_space="pyautogui",
        observation_type="screenshot",
        max_steps=15
    ):
        """Configure the agent.

        Args:
            platform: Target OS name (informational only).
            model: OpenAI model identifier used for planning.
            max_tokens: Sent as ``max_completion_tokens`` on each API call.
            client_password: Substituted into the system prompt template.
            action_space: Only ``"pyautogui"`` is supported.
            observation_type: Only ``"screenshot"`` is supported.
            max_steps: Step budget advertised to the model in the prompt.
        """
        self.platform = platform
        self.model = model
        self.max_tokens = max_tokens
        self.client_password = client_password
        self.action_space = action_space
        self.observation_type = observation_type
        # NOTE: assert is stripped under `python -O`; kept for parity with
        # the original contract.
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        # Parallel per-step history lists; one entry appended per predict().
        self.thoughts = []
        self.actions = []
        self.action_descriptions = []  # consistency: reset() clears this too
        self.observations = []
        self.observation_captions = []
        # Only this many most-recent screenshots are re-sent as images.
        self.max_image_history_length = 5
        self.current_step = 1
        self.max_steps = max_steps

    def predict(self, instruction: str, obs: Dict) -> List:
        """Predict the next action(s) based on the current observation.

        Args:
            instruction: Natural-language task description.
            obs: Observation dict; must hold raw PNG bytes under
                ``obs["screenshot"]``.

        Returns:
            ``(response, codes)`` — the raw model output text and the list
            of parsed action strings.
        """
        user_prompt = (
            f"""Please generate the next move according to the UI screenshot and instruction. And you can refer to the previous actions and observations for reflection.\n\nInstruction: {instruction}\n\n""")
        messages = [{
            "role": "system",
            "content": [{
                "type": "text",
                "text": O3_SYSTEM_PROMPT.format(
                    current_step=self.current_step,
                    max_steps=self.max_steps,
                    CLIENT_PASSWORD=self.client_password
                )
            }]
        }]
        # Determine which observations to include images for (only most recent ones)
        obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)
        # Add all thought and action history
        for i in range(len(self.thoughts)):
            # For recent steps, include the actual screenshot
            if i >= obs_start_idx:
                messages.append({
                    "role": "user",
                    "content": [{
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
                            "detail": "high"
                        },
                    }]
                })
            # For older steps, use the observation caption instead of the image
            else:
                messages.append({
                    "role": "user",
                    "content": [{
                        "type": "text",
                        "text": f"Observation: {self.observation_captions[i]}"
                    }]
                })
            thought_messages = f"Thought:\n{self.thoughts[i]}"
            action_messages = "Action:"  # plain string; no placeholders needed
            for action in self.actions[i]:
                action_messages += f"\n{action}"
            messages.append({
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": thought_messages + "\n" + action_messages
                }]
            })
        # Current observation plus the instruction prompt.
        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
                        "detail": "high"
                    },
                },
                {
                    "type": "text",
                    "text": user_prompt
                },
            ],
        })
        response = self.call_llm(
            {
                "model": self.model,
                "messages": messages,
                "max_completion_tokens": self.max_tokens,
            },
            self.model,
        )
        logger.info(f"Output: {response}")
        codes = self.parse_code_from_planner_response(response)
        # Re-prompt the model (up to MAX_RETRY_TIMES) if no actions parsed.
        retry_count = 0
        max_retries = MAX_RETRY_TIMES
        while not codes and retry_count < max_retries:
            logger.info(f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})...")
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": "You didn't generate valid actions. Please try again."}
                ]
            })
            response = self.call_llm(
                {
                    "model": self.model,
                    "messages": messages,
                    "max_completion_tokens": self.max_tokens,
                },
                self.model,
            )
            logger.info(f"Retry Planner Output: {response}")
            codes = self.parse_code_from_planner_response(response)
            retry_count += 1
        thought = self.parse_thought_from_planner_response(response)
        observation_caption = self.parse_observation_caption_from_planner_response(response)
        logger.info(f"Thought: {thought}")
        logger.info(f"Observation Caption: {observation_caption}")
        logger.info(f"Codes: {codes}")
        # BUG FIX: append the parsed list itself, not a singleton-wrapped copy
        # ([codes]); the history loop above iterates one action string per
        # line, and the old wrapping made it render the list's repr instead.
        self.actions.append(codes)
        self.observations.append(obs)
        self.thoughts.append(thought)
        self.observation_captions.append(observation_caption)
        self.current_step += 1
        return response, codes

    def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
        """Extract the first line following "Observation:" from *input_string*.

        Returns "" when no match is found.
        """
        pattern = r"Observation:\n(.*?)\n"
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            return matches[0].strip()
        return ""

    def parse_thought_from_planner_response(self, input_string: str) -> str:
        """Extract the first line following "Thought:" from *input_string*.

        Returns "" when no match is found.
        """
        pattern = r"Thought:\n(.*?)\n"
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            return matches[0].strip()
        return ""

    def parse_code_from_planner_response(self, input_string: str) -> List[str]:
        """Parse executable code snippets from the planner's raw response.

        Semicolon-separated segments are first re-joined on newlines, then
        fenced ``` code blocks are extracted.  A bare WAIT / DONE / FAIL
        (standalone, or as the last line of a block) is kept as its own
        entry so the controller can act on it.

        Returns:
            List of code strings / control tokens; empty if nothing parsed.
        """
        input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
        if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
            return [input_string.strip()]
        pattern = r"```(?:\w+\s+)?(.*?)```"
        matches = re.findall(pattern, input_string, re.DOTALL)
        codes = []
        for match in matches:
            match = match.strip()
            commands = ['WAIT', 'DONE', 'FAIL']
            if match in commands:
                codes.append(match.strip())
            elif match.split('\n')[-1] in commands:
                # Split a trailing control token off the code that precedes it.
                if len(match.split('\n')) > 1:
                    codes.append("\n".join(match.split('\n')[:-1]))
                codes.append(match.split('\n')[-1])
            else:
                codes.append(match)
        return codes

    @backoff.on_exception(
        backoff.constant,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure
        # each example won't exceed the time limit
        (
            # General exceptions
            SSLError,
            requests.HTTPError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            openai.APIConnectionError,
            openai.APIError
        ),
        interval=30,
        max_tries=10,
    )
    def call_llm(self, payload, model):
        """POST *payload* to the OpenAI chat-completions endpoint.

        Retries (constant 30s backoff, up to 10 tries) on transient
        network/API errors via the decorator above.

        Args:
            payload: JSON-serializable request body.
            model: Model name, used for logging only.

        Returns:
            The assistant message content string on HTTP 200.

        Raises:
            requests.HTTPError: on any non-200 response (triggers backoff).
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {OPENAI_API_KEY}"
        }
        logger.info("Generating content with GPT model: %s", model)
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
        )
        if response.status_code != 200:
            logger.error("Failed to call LLM: " + response.text)
            # Raise HTTPError to trigger backoff retry mechanism
            response.raise_for_status()
        else:
            return response.json()["choices"][0]["message"]["content"]

    def reset(self, _logger=None):
        """Clear all per-episode state and (re)bind the module-level logger.

        Args:
            _logger: Optional logger to use; defaults to the
                "desktopenv.o3_agent" logger.
        """
        global logger
        logger = (_logger if _logger is not None else
                  logging.getLogger("desktopenv.o3_agent"))
        self.thoughts = []
        self.action_descriptions = []
        self.actions = []
        self.observations = []
        self.observation_captions = []
        # BUG FIX: restart the step counter for the new episode; previously
        # it leaked across tasks, so the system prompt reported a stale
        # current_step/max_steps position.
        self.current_step = 1