import base64
import json
import logging
import os
import re
import time
from io import BytesIO
from typing import Dict, List

import backoff
import openai
import requests
from PIL import Image
from google.api_core.exceptions import (
    InvalidArgument,
    ResourceExhausted,
    InternalServerError,
    BadRequest,
)
from requests.exceptions import SSLError

from mm_agents.prompts import JEDI_PLANNER_SYS_PROMPT, JEDI_GROUNDER_SYS_PROMPT
from mm_agents.img_utils import smart_resize

logger = None

OPENAI_API_KEY = "Your OpenAI API Key"
JEDI_API_KEY = "Your Jedi API Key"
JEDI_SERVICE_URL = "Your Jedi Service URL"


def encode_image(image_content):
    return base64.b64encode(image_content).decode("utf-8")


class JediAgent7B:
    def __init__(
        self,
        platform="ubuntu",
        planner_model="gpt-4o",
        executor_model="jedi-7b",
        max_tokens=1500,
        top_p=0.9,
        temperature=0.5,
        action_space="pyautogui",
        observation_type="screenshot",
        max_steps=15,
    ):
        self.platform = platform
        self.planner_model = planner_model
        self.executor_model = executor_model
        assert self.executor_model is not None, "Executor model cannot be None"
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"

        self.thoughts = []
        self.actions = []
        self.observations = []
        self.observation_captions = []
        self.max_image_history_length = 5
        self.current_step = 1
        self.max_steps = max_steps

    def predict(self, instruction: str, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.
        """
        # Get the width and height of the screenshot
        image = Image.open(BytesIO(obs["screenshot"]))
        width, height = image.convert("RGB").size

        previous_actions = ("\n".join([
            f"Step {i+1}: {action}" for i, action in enumerate(self.actions)
        ]) if self.actions else "None")

        user_prompt = (
            f"""Please generate the next move according to the UI screenshot and instruction.
And you can refer to the previous actions and observations for reflection.\n\nInstruction: {instruction}\n\n""")

        messages = [{
            "role": "system",
            "content": [{
                "type": "text",
                "text": JEDI_PLANNER_SYS_PROMPT.replace(
                    "{current_step}", str(self.current_step)
                ).replace("{max_steps}", str(self.max_steps)),
            }]
        }]

        # Determine which observations to include images for (only the most recent ones)
        obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)

        # Add all thought and action history
        for i in range(len(self.thoughts)):
            # For recent steps, include the actual screenshot
            if i >= obs_start_idx:
                messages.append({
                    "role": "user",
                    "content": [{
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
                            "detail": "high"
                        },
                    }]
                })
            # For older steps, use the observation caption instead of the image
            else:
                messages.append({
                    "role": "user",
                    "content": [{
                        "type": "text",
                        "text": f"Observation: {self.observation_captions[i]}"
                    }]
                })

            thought_messages = f"Thought:\n{self.thoughts[i]}"
            action_messages = "Action:"
            for action in self.actions[i]:
                action_messages += f"\n{action}"
            messages.append({
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": thought_messages + "\n" + action_messages
                }]
            })

        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
                        "detail": "high"
                    },
                },
                {
                    "type": "text",
                    "text": user_prompt
                },
            ],
        })

        planner_response = self.call_llm(
            {
                "model": self.planner_model,
                "messages": messages,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature,
            },
            self.planner_model,
        )
        logger.info(f"Planner Output: {planner_response}")

        codes = self.parse_code_from_planner_response(planner_response)

        # Retry if no actions could be parsed from the planner response
        retry_count = 0
        max_retries = 5
        while not codes and retry_count < max_retries:
            logger.info(
                f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})..."
            )
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text",
Please try again."} ] }) planner_response = self.call_llm( { "model": self.planner_model, "messages": messages, "max_tokens": self.max_tokens, "top_p": self.top_p, "temperature": self.temperature, }, self.planner_model, ) logger.info(f"Retry Planner Output: {planner_response}") codes = self.parse_code_from_planner_response(planner_response) retry_count += 1 thought = self.parse_thought_from_planner_response(planner_response) observation_caption = self.parse_observation_caption_from_planner_response(planner_response) resized_height, resized_width = smart_resize(height, width, max_pixels= 2700 * 28 * 28) pyautogui_actions = [] for line in codes: code = self.convert_action_to_grounding_model_instruction( line, obs, instruction, height, width, resized_height, resized_width ) pyautogui_actions.append(code) self.actions.append([pyautogui_actions]) self.observations.append(obs) self.thoughts.append(thought) self.observation_captions.append(observation_caption) self.current_step += 1 return planner_response, pyautogui_actions, {} def parse_observation_caption_from_planner_response(self, input_string: str) -> str: pattern = r"Observation:\n(.*?)\n" matches = re.findall(pattern, input_string, re.DOTALL) if matches: return matches[0].strip() return "" def parse_thought_from_planner_response(self, input_string: str) -> str: pattern = r"Thought:\n(.*?)\n" matches = re.findall(pattern, input_string, re.DOTALL) if matches: return matches[0].strip() return "" def parse_code_from_planner_response(self, input_string: str) -> List[str]: input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()]) if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: return [input_string.strip()] pattern = r"```(?:\w+\s+)?(.*?)```" matches = re.findall(pattern, input_string, re.DOTALL) codes = [] for match in matches: match = match.strip() commands = ['WAIT', 'DONE', 'FAIL'] if match in commands: codes.append(match.strip()) elif match.split('\n')[-1] in commands: if len(match.split('\n')) > 1: codes.append("\n".join(match.split('\n')[:-1])) codes.append(match.split('\n')[-1]) else: codes.append(match) return codes def convert_action_to_grounding_model_instruction(self, line: str, obs: Dict, instruction: str, height: int, width: int, resized_height: int, resized_width: int ) -> str: pattern = r'(#.*?)\n(pyautogui\.(moveTo|click|rightClick|doubleClick|middleClick|dragTo)\((?:x=)?(\d+)(?:,\s*|\s*,\s*y=)(\d+)(?:,\s*duration=[\d.]+)?\))' matches = re.findall(pattern, line, re.DOTALL) if not matches: return line new_instruction = line for match in matches: comment = match[0].split("#")[1].strip() original_action = match[1] func_name = match[2].strip() if "click()" in original_action.lower(): continue messages = [] messages.append({ "role": "system", "content": [{"type": "text", "text": JEDI_GROUNDER_SYS_PROMPT.replace("{height}", str(resized_height)).replace("{width}", str(resized_width))}] }) messages.append( { "role": "user", "content": [ { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}", "detail": "high", }, }, { "type": "text", "text": '\n' + comment, }, ], } ) grounding_response = self.call_llm({ "model": self.executor_model, "messages": messages, "max_tokens": self.max_tokens, "top_p": self.top_p, "temperature": self.temperature }, self.executor_model) coordinates = self.parse_jedi_response(grounding_response, width, height, resized_width, resized_height) logger.info(coordinates) if coordinates == [-1, -1]: continue action_parts = 
            # Rewrite the original pyautogui call with the grounded coordinates
            action_parts = original_action.split('(')
            new_action = f"{action_parts[0]}({coordinates[0]}, {coordinates[1]}"
            if len(action_parts) > 1 and 'duration' in action_parts[1]:
                duration_part = action_parts[1].split(',')[-1]
                new_action += f", {duration_part}"
            elif len(action_parts) > 1 and 'button' in action_parts[1]:
                button_part = action_parts[1].split(',')[-1]
                new_action += f", {button_part}"
            else:
                new_action += ")"
            logger.info(new_action)
            new_instruction = new_instruction.replace(original_action, new_action)

        return new_instruction

    def parse_jedi_response(self, response, width: int, height: int,
                            resized_width: int, resized_height: int) -> List[int]:
        """
        Parse the grounding model response and convert the predicted coordinate
        back to the original screen resolution. Returns [-1, -1] on failure.
        """
        try:
            # Possible tag pairs that wrap the tool call in the grounder output
            start_tags = ["<tool_call>", "⚗"]
            end_tags = ["</tool_call>", "⚗"]

            # Find the valid start and end tags
            start_tag = next((tag for tag in start_tags if tag in response), None)
            end_tag = next((tag for tag in end_tags if tag in response), None)
            if not start_tag or not end_tag:
                logger.error("The response is missing valid start or end tags")
                return [-1, -1]

            # Split the response to extract the low-level instruction and the tool call
            parts = response.split(start_tag)
            if len(parts) < 2:
                logger.error("The response is missing the start tag")
                return [-1, -1]

            low_level_instruction = parts[0].strip().replace("Action: ", "")
            tool_call_str = parts[1].split(end_tag)[0].strip()
            try:
                tool_call = json.loads(tool_call_str)
                action = tool_call.get("arguments", {}).get("action", "")
                args = tool_call.get("arguments", {})
            except json.JSONDecodeError as e:
                logger.error(f"JSON parsing error: {e}")
                # On parsing errors, fall back to empty values
                action = ""
                args = {}

            if "coordinate" not in args:
                return [-1, -1]

            # Convert the coordinate back to the original resolution
            x = int(args["coordinate"][0] * width / resized_width)
            y = int(args["coordinate"][1] * height / resized_height)
            return [x, y]
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            return [-1, -1]

    @backoff.on_exception(
        backoff.constant,
        # Add more model-specific exceptions here as needed, but do not add the
        # generic `Exception`: unexpected exceptions should propagate so the caller
        # can catch them and keep each example within its time limit.
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
            # Groq exceptions
            # todo: check
        ),
        interval=30,
        max_tries=10,
    )
    def call_llm(self, payload, model):
        if model.startswith("gpt"):
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {OPENAI_API_KEY}"
            }
            logger.info("Generating content with GPT model: %s", model)
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload,
            )
            if response.status_code != 200:
                logger.error("Failed to call LLM: " + response.text)
                time.sleep(5)
                return ""
            else:
                return response.json()["choices"][0]["message"]["content"]
        elif model.startswith("jedi"):
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {JEDI_API_KEY}"
            }
            response = requests.post(
                f"{JEDI_SERVICE_URL}/v1/chat/completions",
                headers=headers,
                json=payload,
            )
            if response.status_code != 200:
                logger.error("Failed to call LLM: " + response.text)
                time.sleep(5)
                return ""
            else:
                return response.json()["choices"][0]["message"]["content"]

    def reset(self, _logger=None):
        global logger
        logger = (_logger if _logger is not None else
logging.getLogger("desktopenv.jedi_7b_agent")) self.thoughts = [] self.action_descriptions = [] self.actions = [] self.observations = [] self.observation_captions = []