# mirrored a minute ago
# 0
# MillanK: Add Jedi agent implementation to mm_agents (#192) * feat: implement Jedi agent * chore: code clean (51f5dde)
import base64
import json
import logging
import os
import re
import time
from io import BytesIO
from typing import Dict, List

import backoff
import openai
import requests
from PIL import Image
from google.api_core.exceptions import (
    InvalidArgument,
    ResourceExhausted,
    InternalServerError,
    BadRequest,
)
from requests.exceptions import SSLError

# Module-level logger shared by every method in this file. reset() replaces it
# with the caller's logger; starting from a real logger (instead of None) keeps
# logging calls from raising AttributeError when reset() has not run yet.
logger = logging.getLogger("desktopenv.jedi_7b_agent")

# Credentials and endpoint for the two chat-completions services. Each value
# can be supplied through an environment variable of the same name; the
# original placeholder strings remain the defaults.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "Your OpenAI API Key")
JEDI_API_KEY = os.environ.get("JEDI_API_KEY", "Your Jedi API Key")
JEDI_SERVICE_URL = os.environ.get("JEDI_SERVICE_URL", "Your Jedi Service URL")

from mm_agents.prompts import JEDI_PLANNER_SYS_PROMPT, JEDI_GROUNDER_SYS_PROMPT
from mm_agents.img_utils import smart_resize

def encode_image(image_content):
    """Return the bytes in *image_content* as a base64-encoded text string."""
    encoded_bytes = base64.b64encode(image_content)
    return str(encoded_bytes, "utf-8")

class JediAgent7B:
    def __init__(
        self,
        platform="ubuntu",
        planner_model="gpt-4o",
        executor_model="jedi-7b",
        max_tokens=1500,
        top_p=0.9,
        temperature=0.5,
        action_space="pyautogui",
        observation_type="screenshot",
        max_steps=15
    ):
        self.platform = platform
        self.planner_model = planner_model
        self.executor_model = executor_model
        assert self.executor_model is not None, "Executor model cannot be None"
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.observation_captions = []
        self.max_image_history_length = 5
        self.current_step = 1
        self.max_steps = max_steps

    def predict(self, instruction: str, obs: Dict) -> tuple:
        """
        Predict the next action(s) based on the current observation.

        The planner model is shown the recorded history plus the current
        screenshot and asked for the next move. Each code snippet parsed from
        its reply goes through the grounding step, which may rewrite the
        coordinates of pointer actions, and the step is then appended to the
        agent's history.

        Args:
            instruction: Natural-language task description.
            obs: Observation dict; ``obs["screenshot"]`` holds the encoded
                image bytes of the current screen.

        Returns:
            A 3-tuple ``(planner_response, pyautogui_actions, {})``: the raw
            planner reply, the list of grounded action strings (empty when no
            code could be parsed after all retries), and an empty dict.
        """
        # Only the dimensions are needed, so read them straight from the
        # opened image and release it again.
        with Image.open(BytesIO(obs["screenshot"])) as image:
            width, height = image.size

        messages = self._build_planner_messages(instruction, obs)

        planner_response = self._query_planner(messages)
        logger.info("Planner Output: %s", planner_response)
        codes = self.parse_code_from_planner_response(planner_response)

        # Ask again (up to max_retries times) while the reply has no code.
        max_retries = 5
        for attempt in range(1, max_retries + 1):
            if codes:
                break
            logger.info(
                "No codes parsed from planner response. Retrying (%d/%d)...",
                attempt,
                max_retries,
            )
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": "You didn't generate valid actions. Please try again."}
                ]
            })
            planner_response = self._query_planner(messages)
            logger.info("Retry Planner Output: %s", planner_response)
            codes = self.parse_code_from_planner_response(planner_response)

        thought = self.parse_thought_from_planner_response(planner_response)
        observation_caption = self.parse_observation_caption_from_planner_response(planner_response)

        # Resolution used when rescaling grounded coordinates back to the
        # screenshot (presumably what the grounding model sees — confirm
        # against smart_resize).
        resized_height, resized_width = smart_resize(height, width, max_pixels= 2700 * 28 * 28)
        pyautogui_actions = [
            self.convert_action_to_grounding_model_instruction(
                line,
                obs,
                instruction,
                height,
                width,
                resized_height,
                resized_width
            )
            for line in codes
        ]

        # Store the flat list of action strings: the history replay iterates
        # over it and writes one action per line. Wrapping it in another list
        # made the replay print the list's repr instead.
        self.actions.append(pyautogui_actions)
        self.observations.append(obs)
        self.thoughts.append(thought)
        self.observation_captions.append(observation_caption)
        self.current_step += 1
        return planner_response, pyautogui_actions, {}

    def _build_planner_messages(self, instruction: str, obs: Dict) -> List[Dict]:
        """Assemble the planner conversation: system prompt, history, current screenshot."""
        system_prompt = (
            JEDI_PLANNER_SYS_PROMPT
            .replace("{current_step}", str(self.current_step))
            .replace("{max_steps}", str(self.max_steps))
        )
        messages = [{
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}]
        }]

        # Steps at or after this index are replayed with their screenshot;
        # earlier ones are replayed with their text caption only.
        obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)

        history = zip(self.observations, self.observation_captions, self.thoughts, self.actions)
        for i, (past_obs, caption, thought, actions) in enumerate(history):
            if i >= obs_start_idx:
                observation_content = {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encode_image(past_obs['screenshot'])}",
                        "detail": "high"
                    },
                }
            else:
                observation_content = {
                    "type": "text",
                    "text": f"Observation: {caption}"
                }
            messages.append({"role": "user", "content": [observation_content]})

            action_text = "\n".join(["Action:", *(str(action) for action in actions)])
            messages.append({
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": f"Thought:\n{thought}" + "\n" + action_text
                }]
            })

        user_prompt = (
            f"""Please generate the next move according to the UI screenshot and instruction. And you can refer to the previous actions and observations for reflection.\n\nInstruction: {instruction}\n\n""")
        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
                        "detail": "high"
                    },
                },
                {
                    "type": "text",
                    "text": user_prompt
                },
            ],
        })
        return messages

    def _query_planner(self, messages: List[Dict]) -> str:
        """Send *messages* to the planner model and return its raw text reply."""
        return self.call_llm(
            {
                "model": self.planner_model,
                "messages": messages,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature,
            },
            self.planner_model,
        )
        
    def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
        pattern = r"Observation:\n(.*?)\n"
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            return matches[0].strip()
        return ""

    def parse_thought_from_planner_response(self, input_string: str) -> str:
        pattern = r"Thought:\n(.*?)\n"
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            return matches[0].strip()
        return ""

    def parse_code_from_planner_response(self, input_string: str) -> List[str]:

        input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
        if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
            return [input_string.strip()]

        pattern = r"```(?:\w+\s+)?(.*?)```"
        matches = re.findall(pattern, input_string, re.DOTALL)
        codes = []

        for match in matches:
            match = match.strip()
            commands = ['WAIT', 'DONE', 'FAIL']

            if match in commands:
                codes.append(match.strip())
            elif match.split('\n')[-1] in commands:
                if len(match.split('\n')) > 1:
                    codes.append("\n".join(match.split('\n')[:-1]))
                codes.append(match.split('\n')[-1])
            else:
                codes.append(match)

        return codes

    def convert_action_to_grounding_model_instruction(self, line: str, obs: Dict, instruction: str, height: int, width: int, resized_height: int, resized_width: int ) -> str:
        pattern = r'(#.*?)\n(pyautogui\.(moveTo|click|rightClick|doubleClick|middleClick|dragTo)\((?:x=)?(\d+)(?:,\s*|\s*,\s*y=)(\d+)(?:,\s*duration=[\d.]+)?\))'
        matches = re.findall(pattern, line, re.DOTALL)
        if not matches:
            return line
        new_instruction = line
        for match in matches:
            comment = match[0].split("#")[1].strip()
            original_action = match[1]
            func_name = match[2].strip()

            if "click()" in original_action.lower():
                continue
            
            messages = []
            messages.append({
                "role": "system",
                "content": [{"type": "text", "text": JEDI_GROUNDER_SYS_PROMPT.replace("{height}", str(resized_height)).replace("{width}", str(resized_width))}]
            })
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
                                "detail": "high",
                            },
                        },
                        {
                            "type": "text",
                            "text": '\n' + comment,
                        },
                    ],
                }
            )
            grounding_response = self.call_llm({
                "model": self.executor_model,
                "messages": messages,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature
            }, self.executor_model)
            coordinates = self.parse_jedi_response(grounding_response, width, height, resized_width, resized_height)
            logger.info(coordinates)
            if coordinates == [-1, -1]:
                continue
            action_parts = original_action.split('(')
            new_action = f"{action_parts[0]}({coordinates[0]}, {coordinates[1]}"
            if len(action_parts) > 1 and 'duration' in action_parts[1]:
                duration_part = action_parts[1].split(',')[-1]
                new_action += f", {duration_part}"
            elif len(action_parts) > 1 and 'button' in action_parts[1]:
                button_part = action_parts[1].split(',')[-1]
                new_action += f", {button_part}"
            else:
                new_action += ")"
            logger.info(new_action)
            new_instruction = new_instruction.replace(original_action, new_action)
        return new_instruction
        
    def parse_jedi_response(self, response, width: int, height: int, resized_width: int, resized_height: int) -> List[str]:
        """
        Parse the LLM response and convert it to low level action and pyautogui code.
        """ 

        low_level_instruction = ""
        pyautogui_code = []
        try:
            # 定义可能的标签组合
            start_tags = ["<tool_call>", ""]
            end_tags = ["</tool_call>", ""]

            # 找到有效的开始和结束标签
            start_tag = next((tag for tag in start_tags if tag in response), None)
            end_tag = next((tag for tag in end_tags if tag in response), None)

            if not start_tag or not end_tag:
                print("The response is missing valid start or end tags")
                return low_level_instruction, pyautogui_code

            # 分割响应以提取low_level_instruction和tool_call
            parts = response.split(start_tag)
            if len(parts) < 2:
                print("The response is missing the start tag")
                return low_level_instruction, pyautogui_code

            low_level_instruction = parts[0].strip().replace("Action: ", "")
            tool_call_str = parts[1].split(end_tag)[0].strip()

            try:
                tool_call = json.loads(tool_call_str)
                action = tool_call.get("arguments", {}).get("action", "")
                args = tool_call.get("arguments", {})
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {e}")
                # 处理解析错误,返回默认值或空值
                action = ""
                args = {}
            
            # convert the coordinate to the original resolution
            x = int(args.get("coordinate", [-1, -1])[0] * width / resized_width)
            y = int(args.get("coordinate", [-1, -1])[1] * height / resized_height)

            return [x, y]
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            return [-1, -1]

    @backoff.on_exception(
        backoff.constant,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure
        # each example won't exceed the time limit
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
            # Groq exceptions
            # todo: check
        ),
        interval=30,
        max_tries=10,
    )
    def call_llm(self, payload, model):
        """
        POST a chat-completions request and return the reply text.

        Models whose name starts with ``gpt`` go to the OpenAI endpoint with
        ``OPENAI_API_KEY``; models whose name starts with ``jedi`` go to
        ``JEDI_SERVICE_URL`` with ``JEDI_API_KEY``.

        Args:
            payload: JSON-serialisable request body.
            model: Model name; its prefix selects the endpoint.

        Returns:
            The first choice's message content, or ``""`` (after a 5 second
            pause) when the service answers with a non-200 status.

        Raises:
            ValueError: If ``model`` matches neither supported prefix.
        """
        if model.startswith("gpt"):
            base_url, api_key = "https://api.openai.com", OPENAI_API_KEY
            logger.info("Generating content with GPT model: %s", model)
        elif model.startswith("jedi"):
            # Tolerate a configured service URL that ends with "/".
            base_url, api_key = JEDI_SERVICE_URL.rstrip("/"), JEDI_API_KEY
        else:
            # Falling through used to return None, which only surfaced later
            # as an unrelated AttributeError in the response parsers.
            raise ValueError(f"Unsupported model: {model!r}")

        # NOTE(review): no timeout is set, so a stalled connection blocks until
        # the outer per-example time limit; consider passing timeout=.
        response = requests.post(
            f"{base_url}/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            },
            json=payload,
        )

        if response.status_code != 200:
            logger.error("Failed to call LLM: %s", response.text)
            time.sleep(5)
            return ""
        return response.json()["choices"][0]["message"]["content"]

    def reset(self, _logger=None):
        global logger
        logger = (_logger if _logger is not None else
                  logging.getLogger("desktopenv.jedi_7b_agent"))

        self.thoughts = []
        self.action_descriptions = []
        self.actions = []
        self.observations = []
        self.observation_captions = []