import base64
import json
import logging
import os
import re
import time
from io import BytesIO
from typing import Dict, List, Tuple
import backoff
import openai
import requests
from PIL import Image
from google.api_core.exceptions import (
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
)
from requests.exceptions import SSLError
logger = logging.getLogger("desktopenv.jedi_7b_agent")

# The string placeholders below are fallbacks; set the corresponding
# environment variables (or replace the strings) before running.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "Your OpenAI API Key")
JEDI_API_KEY = os.getenv("JEDI_API_KEY", "Your Jedi API Key")
JEDI_SERVICE_URL = os.getenv("JEDI_SERVICE_URL", "Your Jedi Service URL")
from mm_agents.prompts import JEDI_PLANNER_SYS_PROMPT, JEDI_GROUNDER_SYS_PROMPT
from mm_agents.img_utils import smart_resize
def encode_image(image_content):
    """Base64-encode raw image bytes for embedding in a data: URL."""
    return base64.b64encode(image_content).decode("utf-8")
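# For reference, PNG screenshot bytes always encode to a base64 string starting
# with "iVBORw0KGgo", since every PNG file shares the same magic-byte header.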
class JediAgent7B:
def __init__(
self,
platform="ubuntu",
planner_model="gpt-4o",
executor_model="jedi-7b",
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="pyautogui",
observation_type="screenshot",
max_steps=15
):
self.platform = platform
self.planner_model = planner_model
self.executor_model = executor_model
assert self.executor_model is not None, "Executor model cannot be None"
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
self.thoughts = []
self.actions = []
self.observations = []
        self.observation_captions = []
        # Only the most recent screenshots are sent as images; older steps fall
        # back to their text captions to keep the prompt size bounded.
        self.max_image_history_length = 5
self.current_step = 1
self.max_steps = max_steps
    def predict(self, instruction: str, obs: Dict) -> Tuple[str, List[str], Dict]:
        """
        Predict the next action(s) based on the current observation.

        Returns the raw planner response, the grounded pyautogui actions,
        and an (empty) info dict.
        """
        # Get the width and height of the screenshot.
        image = Image.open(BytesIO(obs["screenshot"]))
        width, height = image.convert("RGB").size
        user_prompt = (
            "Please generate the next move according to the UI screenshot and instruction. "
            "You can refer to the previous actions and observations for reflection.\n\n"
            f"Instruction: {instruction}\n\n"
        )
messages = [{
"role": "system",
"content": [{
"type": "text",
"text": JEDI_PLANNER_SYS_PROMPT.replace("{current_step}", str(self.current_step)).replace("{max_steps}", str(self.max_steps))
}]
}]
# Determine which observations to include images for (only most recent ones)
obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)
# Add all thought and action history
for i in range(len(self.thoughts)):
# For recent steps, include the actual screenshot
if i >= obs_start_idx:
messages.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
"detail": "high"
},
}]
})
# For older steps, use the observation caption instead of the image
else:
messages.append({
"role": "user",
"content": [{
"type": "text",
"text": f"Observation: {self.observation_captions[i]}"
}]
})
thought_messages = f"Thought:\n{self.thoughts[i]}"
            action_messages = "Action:"
for action in self.actions[i]:
action_messages += f"\n{action}"
messages.append({
"role": "assistant",
"content": [{
"type": "text",
"text": thought_messages + "\n" + action_messages
}]
})
        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
                        "detail": "high"
                    },
                },
{
"type": "text",
"text": user_prompt
},
],
})
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response)
# Add retry logic if no codes were parsed
retry_count = 0
max_retries = 5
while not codes and retry_count < max_retries:
logger.info(f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})...")
messages.append({
"role": "user",
"content": [
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
]
})
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Retry Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response)
retry_count += 1
thought = self.parse_thought_from_planner_response(planner_response)
observation_caption = self.parse_observation_caption_from_planner_response(planner_response)
        # Compute the resolution the grounding model operates at; its predicted
        # coordinates are mapped back to the native resolution later.
        resized_height, resized_width = smart_resize(height, width, max_pixels=2700 * 28 * 28)
pyautogui_actions = []
for line in codes:
code = self.convert_action_to_grounding_model_instruction(
line,
obs,
instruction,
height,
width,
resized_height,
resized_width
)
pyautogui_actions.append(code)
        # Store the flat list of grounded actions for this step; the history
        # renderer above iterates over each step's actions individually.
        self.actions.append(pyautogui_actions)
self.observations.append(obs)
self.thoughts.append(thought)
self.observation_captions.append(observation_caption)
self.current_step += 1
return planner_response, pyautogui_actions, {}
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n"
matches = re.findall(pattern, input_string, re.DOTALL)
if matches:
return matches[0].strip()
return ""
def parse_thought_from_planner_response(self, input_string: str) -> str:
pattern = r"Thought:\n(.*?)\n"
matches = re.findall(pattern, input_string, re.DOTALL)
if matches:
return matches[0].strip()
return ""
    def parse_code_from_planner_response(self, input_string: str) -> List[str]:
        """Extract executable code blocks (or WAIT/DONE/FAIL commands) from the planner response."""
        input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
return [input_string.strip()]
pattern = r"```(?:\w+\s+)?(.*?)```"
matches = re.findall(pattern, input_string, re.DOTALL)
codes = []
for match in matches:
match = match.strip()
commands = ['WAIT', 'DONE', 'FAIL']
if match in commands:
codes.append(match.strip())
elif match.split('\n')[-1] in commands:
if len(match.split('\n')) > 1:
codes.append("\n".join(match.split('\n')[:-1]))
codes.append(match.split('\n')[-1])
else:
codes.append(match)
return codes
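    # Illustrative planner output consumed by the three parse_* helpers above
    # (an assumed shape inferred from the regexes, not a verbatim transcript):
    #
    #   Observation:
    #   The Files application is open with an empty home directory.
    #   Thought:
    #   I should click the search icon in the toolbar.
    #   ```python
    #   # Click the search icon in the toolbar
    #   pyautogui.click(x=1180, y=64)
    #   ```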
    def convert_action_to_grounding_model_instruction(self, line: str, obs: Dict, instruction: str, height: int, width: int, resized_height: int, resized_width: int) -> str:
pattern = r'(#.*?)\n(pyautogui\.(moveTo|click|rightClick|doubleClick|middleClick|dragTo)\((?:x=)?(\d+)(?:,\s*|\s*,\s*y=)(\d+)(?:,\s*duration=[\d.]+)?\))'
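        # The pattern pairs a "# description" comment with the coordinate-based
        # pyautogui call that follows it, e.g. (illustrative values):
        #   # Click the "File" menu
        #   pyautogui.click(x=512, y=384)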
matches = re.findall(pattern, line, re.DOTALL)
if not matches:
return line
new_instruction = line
for match in matches:
comment = match[0].split("#")[1].strip()
original_action = match[1]
func_name = match[2].strip()
if "click()" in original_action.lower():
continue
messages = []
messages.append({
"role": "system",
"content": [{"type": "text", "text": JEDI_GROUNDER_SYS_PROMPT.replace("{height}", str(resized_height)).replace("{width}", str(resized_width))}]
})
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
"detail": "high",
},
},
{
"type": "text",
"text": '\n' + comment,
},
],
}
)
grounding_response = self.call_llm({
"model": self.executor_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.executor_model)
coordinates = self.parse_jedi_response(grounding_response, width, height, resized_width, resized_height)
logger.info(coordinates)
if coordinates == [-1, -1]:
continue
action_parts = original_action.split('(')
new_action = f"{action_parts[0]}({coordinates[0]}, {coordinates[1]}"
if len(action_parts) > 1 and 'duration' in action_parts[1]:
duration_part = action_parts[1].split(',')[-1]
new_action += f", {duration_part}"
elif len(action_parts) > 1 and 'button' in action_parts[1]:
button_part = action_parts[1].split(',')[-1]
new_action += f", {button_part}"
else:
new_action += ")"
logger.info(new_action)
new_instruction = new_instruction.replace(original_action, new_action)
return new_instruction
    def parse_jedi_response(self, response, width: int, height: int, resized_width: int, resized_height: int) -> List[int]:
        """
        Parse the grounding model response and map the predicted coordinates
        back to the original screen resolution. Returns [-1, -1] on failure.
        """
        try:
            # Possible tag pairs that delimit the tool call.
            start_tags = ["<tool_call>", "⚗"]
            end_tags = ["</tool_call>", "⚗"]
            # Find the first valid start and end tag present in the response.
            start_tag = next((tag for tag in start_tags if tag in response), None)
            end_tag = next((tag for tag in end_tags if tag in response), None)
            if not start_tag or not end_tag:
                logger.error("The response is missing valid start or end tags")
                return [-1, -1]
            # Split the response to extract the tool-call payload.
            parts = response.split(start_tag)
            if len(parts) < 2:
                logger.error("The response is missing the start tag")
                return [-1, -1]
            tool_call_str = parts[1].split(end_tag)[0].strip()
            try:
                tool_call = json.loads(tool_call_str)
                args = tool_call.get("arguments", {})
            except json.JSONDecodeError as e:
                logger.error(f"JSON parsing error: {e}")
                return [-1, -1]
            # Convert the coordinate back to the original resolution.
            x = int(args.get("coordinate", [-1, -1])[0] * width / resized_width)
            y = int(args.get("coordinate", [-1, -1])[1] * height / resized_height)
            return [x, y]
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            return [-1, -1]
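    # Illustrative grounder response handled by parse_jedi_response (the exact
    # "name" field is an assumption; only "arguments.coordinate" is consumed):
    #
    #   <tool_call>
    #   {"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [268, 201]}}
    #   </tool_call>
    #
    # The coordinate is expressed at the resized resolution and scaled back to
    # the native screen resolution.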
@backoff.on_exception(
backoff.constant,
        # Add more model-specific exceptions here as needed, but never add the
        # bare "Exception" type: generic exceptions must propagate so the outer
        # loop can enforce the per-example time limit.
(
# General exceptions
SSLError,
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
# Groq exceptions
# todo: check
),
interval=30,
max_tries=10,
)
    def call_llm(self, payload, model):
        """Dispatch a chat-completions request to the planner (GPT) or grounder (Jedi) endpoint."""
        if model.startswith("gpt"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}"
}
logger.info("Generating content with GPT model: %s", model)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()["choices"][0]["message"]["content"]
elif model.startswith("jedi"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {JEDI_API_KEY}"
}
response = requests.post(
f"{JEDI_SERVICE_URL}/v1/chat/completions",
headers=headers,
json=payload,
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()["choices"][0]["message"]["content"]
    def reset(self, _logger=None):
        global logger
        logger = (_logger if _logger is not None else
                  logging.getLogger("desktopenv.jedi_7b_agent"))
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.observation_captions = []
        self.current_step = 1
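
# A minimal usage sketch, assuming an OSWorld-style environment whose
# observations are {"screenshot": <PNG bytes>}. The `env`, `task_config`, and
# `task_instruction` names below are hypothetical, not part of this module.
#
#     agent = JediAgent7B(planner_model="gpt-4o", executor_model="jedi-7b")
#     agent.reset()
#     obs = env.reset(task_config)
#     for _ in range(agent.max_steps):
#         response, actions, _ = agent.predict(task_instruction, obs)
#         if "DONE" in actions or "FAIL" in actions:
#             break
#         for action in actions:
#             obs = env.step(action)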