# mirrored 12 minutes ago
# 0
# Howie: add support for mobile agent v3 (#328) * add support for mobile agent v3 * add mobile_agent * add support for mobile agent v3 (756e006)
from typing import Dict, List, Optional, Tuple

import time
from pathlib import Path
import uuid

import copy
import json
from mm_agents.mobileagent_v3.mobile_agent_modules import (
    InfoPool,
    Manager,
    Executor,
    Grounding,
    Reflector
)

import dataclasses
from dataclasses import dataclass, field, asdict

@dataclass
class JSONAction:
    """Plain record describing one converted agent action.

    Every field defaults to ``None``; the converter functions in this module
    fill in only the fields relevant to the action type they handled.
    """

    # Name of the action as given by the model output (e.g. 'click', 'type').
    action_type: Optional[str] = None
    # Executable snippet (or a sentinel such as "DONE") produced for the action.
    action_code: Optional[str] = None
    # Pointer coordinates associated with the action, when it has any.
    x: Optional[int] = None
    y: Optional[int] = None
    # Text payload of a 'type' action.
    text: Optional[str] = None
    # 'clear' flag of a 'type' action.
    clear: Optional[int] = None
    # Duration of a 'wait' action.
    time: Optional[int] = None
    # Amount of a 'scroll' action.
    value: Optional[float] = None
    # Keys of a 'hotkey' action.
    key_list: Optional[list] = None


def convert_xy(x, y, src_size=(1932, 1092), dst_size=(1920, 1080)):
    """Rescale a point from one image resolution to another.

    Args:
        x: Horizontal coordinate in the source resolution.
        y: Vertical coordinate in the source resolution.
        src_size: ``(width, height)`` the incoming coordinates refer to.
            Defaults to 1932x1092 (presumably the resolution the model sees;
            confirm against the grounding/operator pipeline).
        dst_size: ``(width, height)`` to map into. Defaults to 1920x1080.

    Returns:
        Tuple ``(x_, y_)`` of the rescaled coordinates (floats, because true
        division is used).
    """
    src_width, src_height = src_size
    dst_width, dst_height = dst_size
    # Multiply before dividing, as the original did, so results are unchanged
    # for the default resolutions.
    return x * dst_width / src_width, y * dst_height / src_height


def convert_fc_action_to_json_action_grounding(
        dummy_action, grounding_model, image_list, grounding_info=""
): # -> json_action.JSONAction:
    """Convert a JSON action string into a JSONAction, grounding element descriptions.

    Element descriptions found in the action are resolved to coordinates by
    calling ``grounding_model.predict(description, image_list)``, which is
    expected to return ``([x, y], messages)``.

    Args:
        dummy_action: JSON string with at least an ``'action'`` key.
        grounding_model: Object exposing ``predict(prompt, image_list)``.
        image_list: Images forwarded unchanged to the grounding model.
        grounding_info: Prefix prepended to every element description before
            it is sent to the grounding model.

    Returns:
        ``(JSONAction, grounding_messages)``. ``grounding_messages`` is the
        messages returned by the grounding model, a two-item list of them for
        an ``element1``/``element2`` pair, or ``None`` when nothing was
        grounded. Unrecognised action types yield an empty ``action_code``.

    Raises:
        ValueError: For malformed JSON (``json.JSONDecodeError``) or a 'drag'
            action without a grounded element pair.
        KeyError / TypeError: When a field required by the action type is
            missing or the action needs coordinates that were not grounded.
    """

    def _escape_single_quotes(raw):
        # Make `raw` safe to splice between single quotes in generated code:
        # escape quotes that are not already escaped and close a dangling
        # trailing backslash. Existing escape sequences are left untouched.
        pieces = []
        pending_backslash = False
        for ch in raw:
            if pending_backslash:
                pending_backslash = False
            elif ch == "\\":
                pending_backslash = True
            elif ch == "'":
                pieces.append("\\")
            pieces.append(ch)
        if pending_backslash:
            pieces.append("\\")
        return "".join(pieces)

    action_json = json.loads(dummy_action)
    action_type = action_json['action']

    x = None
    y = None
    x1 = y1 = x2 = y2 = None
    text = None
    clear = None
    value = None
    wait_seconds = None  # kept out of the name `time` to avoid shadowing the module
    key_list = None
    action_code = ""

    if 'element_description' in action_json:
        [x, y], grounding_messages = grounding_model.predict(
            grounding_info + action_json['element_description'], image_list)
    elif "element1_description" in action_json:
        [x1, y1], grounding_messages1 = grounding_model.predict(
            grounding_info + action_json['element1_description'], image_list)
        [x2, y2], grounding_messages2 = grounding_model.predict(
            grounding_info + action_json['element2_description'], image_list)
        grounding_messages = [grounding_messages1, grounding_messages2]
    else:
        grounding_messages = None

    pointer_calls = {
        'click': 'click',
        'double_click': 'doubleClick',
        'right_click': 'rightClick',
    }

    if action_type in ('click', 'double_click', 'right_click'):
        x, y = convert_xy(x, y)
        action_code = f"import pyautogui; pyautogui.{pointer_calls[action_type]}(x={x}, y={y})"
    elif action_type == 'type':
        x, y = convert_xy(x, y)
        # Raw newlines would break the one-line string literal below, so they
        # are always written as the two-character escape sequence.
        text = action_json['text'].replace("\n", "\\n")
        clear = action_json['clear']
        enter = action_json['enter']
        action_code = f"import pyautogui; pyautogui.click(x={x}, y={y}); "
        if clear > 0:
            action_code += "pyautogui.hotkey('ctrl', 'a'); pyautogui.press('delete'); "
        action_code += f"pyautogui.typewrite('{_escape_single_quotes(text)}', interval=1.0)"
        if enter > 0:
            action_code += "; pyautogui.press('enter')"
    elif action_type == 'hotkey':
        key_list = action_json['keys']
        key_list_str = "'" + "', '".join(key_list) + "'"
        action_code = f"import pyautogui; pyautogui.hotkey({key_list_str})"
    elif action_type == 'scroll':
        x, y = convert_xy(x, y)
        value = action_json['value']
        action_code = f"import pyautogui; pyautogui.moveTo({x}, {y}); pyautogui.scroll({value})"
    elif action_type == 'wait':
        wait_seconds = action_json['time']
        action_code = f"import time; time.sleep({wait_seconds})"
    elif action_type == 'done':
        action_code = "DONE"
    elif action_type == 'drag':
        if None in (x1, y1, x2, y2):
            raise ValueError(
                "'drag' requires grounded 'element1_description' and 'element2_description'")
        x1, y1 = convert_xy(x1, y1)
        x2, y2 = convert_xy(x2, y2)
        action_code = "import pyautogui; "
        action_code += f"pyautogui.moveTo({x1}, {y1}); "
        action_code += f"pyautogui.dragTo({x2}, {y2}, duration=1.); pyautogui.mouseUp(); "

    return JSONAction(
        action_type=action_type,
        x=x,
        y=y,
        text=text,
        clear=clear,
        value=value,
        time=wait_seconds,
        key_list=key_list,
        action_code=action_code
    ), grounding_messages


def convert_fc_action_to_json_action(
        dummy_action
): # -> json_action.JSONAction:
    """Convert a JSON action string with explicit coordinates into a JSONAction.

    Pointer actions read their position from ``action_json['coordinate']``
    (and ``'coordinate2'`` for the drag target); the coordinates are rescaled
    with ``convert_xy`` before being embedded in the generated pyautogui code.

    Args:
        dummy_action: JSON string with at least an ``'action'`` key.

    Returns:
        A ``JSONAction`` whose ``action_code`` holds the generated snippet
        ("DONE" for a 'done' action, empty for unrecognised action types).

    Raises:
        ValueError: For malformed JSON (``json.JSONDecodeError``).
        KeyError / TypeError / IndexError: When a field required by the action
            type is missing or malformed.
    """

    def _escape_single_quotes(raw):
        # Make `raw` safe to splice between single quotes in generated code:
        # escape quotes that are not already escaped and close a dangling
        # trailing backslash. Existing escape sequences are left untouched.
        pieces = []
        pending_backslash = False
        for ch in raw:
            if pending_backslash:
                pending_backslash = False
            elif ch == "\\":
                pending_backslash = True
            elif ch == "'":
                pieces.append("\\")
            pieces.append(ch)
        if pending_backslash:
            pieces.append("\\")
        return "".join(pieces)

    def _scaled_point(key):
        # Read an [x, y] pair from the action and rescale it.
        raw_x, raw_y = action_json[key][0], action_json[key][1]
        return convert_xy(raw_x, raw_y)

    action_json = json.loads(dummy_action)
    action_type = action_json['action']

    x = None
    y = None
    text = None
    clear = None
    value = None
    wait_seconds = None  # kept out of the name `time` to avoid shadowing the module
    key_list = None
    action_code = ""

    pointer_calls = {
        'click': 'click',
        'double_click': 'doubleClick',
        'right_click': 'rightClick',
    }

    if action_type in ('click', 'double_click', 'right_click'):
        x, y = _scaled_point('coordinate')
        action_code = f"import pyautogui; pyautogui.{pointer_calls[action_type]}(x={x}, y={y})"
    elif action_type == 'type':
        x, y = _scaled_point('coordinate')
        # Raw newlines would break the one-line string literal below, so they
        # are always written as the two-character escape sequence.
        text = action_json['text'].replace("\n", "\\n")
        clear = action_json['clear']
        enter = action_json['enter']
        action_code = f"import pyautogui; pyautogui.click(x={x}, y={y}); "
        if clear > 0:
            action_code += "pyautogui.hotkey('ctrl', 'a'); pyautogui.press('delete'); "
        action_code += f"pyautogui.typewrite('{_escape_single_quotes(text)}', interval=1.0)"
        if enter > 0:
            action_code += "; pyautogui.press('enter')"
    elif action_type == 'hotkey':
        key_list = action_json['keys']
        key_list_str = "'" + "', '".join(key_list) + "'"
        action_code = f"import pyautogui; pyautogui.hotkey({key_list_str})"
    elif action_type == 'scroll':
        x, y = _scaled_point('coordinate')
        value = action_json['value']
        action_code = f"import pyautogui; pyautogui.moveTo({x}, {y}); pyautogui.scroll({value})"
    elif action_type == 'wait':
        wait_seconds = action_json['time']
        action_code = f"import time; time.sleep({wait_seconds})"
    elif action_type == 'done':
        action_code = "DONE"
    elif action_type == 'drag':
        # x / y of the returned JSONAction stay None for drags, as before.
        x1, y1 = _scaled_point('coordinate')
        x2, y2 = _scaled_point('coordinate2')
        action_code = "import pyautogui; "
        action_code += f"pyautogui.moveTo({x1}, {y1}); "
        action_code += f"pyautogui.dragTo({x2}, {y2}, duration=1.); pyautogui.mouseUp(); "

    return JSONAction(
        action_type=action_type,
        x=x,
        y=y,
        text=text,
        clear=clear,
        value=value,
        time=wait_seconds,
        key_list=key_list,
        action_code=action_code
    )


# Default tips text; MobileAgentV3 passes a copy of it to InfoPool as
# `additional_knowledge` when the info pool is (re)created.
# NOTE(review): the text embeds a machine password in plain source; confirm
# this is the intended public evaluation credential and not a secret.
INIT_TIPS = '''
General:
- If you see the "can't update chrome" popup, click the nearby X to close the prompt. Be sure not to click the "reinstall chrome" button.
- If you want to perform scroll action, 5 or -5 is an appropriate choice for `value` parameter.
- My computer's password is 'osworld-public-evaluation', feel free to use it when you need sudo rights.

Chrome:
- If the Chrome browser page is not maximized, you can use the alt+f10 shortcut to maximize the window, thereby displaying more information.
- If you cannot find the element you want to click on the current webpage, you can use the search function provided by the webpage (usually a search box or a magnifying glass icon), or directly search with appropriate keywords in the Google search engine.
'''

from datetime import datetime
import os


class MobileAgentV3:
    """Plan / act / reflect agent built from Manager, Executor, Grounding and Reflector.

    The agent keeps its running state (plan, action history, outcomes, ...)
    in ``self.info_pool`` and advances it by one action per ``step`` call.
    """

    def __init__(
            self,
            manager_engine_params: Dict,
            operator_engine_params: Dict,
            reflector_engine_params: Dict,
            grounding_enging_params: Dict,
            wait_after_action_seconds: float = 3.0
    ):
        """Store the engine parameters and create a fresh info pool.

        Args:
            manager_engine_params: Passed to ``Manager`` on every step.
            operator_engine_params: Passed to ``Executor`` on every step.
            reflector_engine_params: Passed to ``Reflector`` on every step.
            grounding_enging_params: Passed to ``Grounding`` when grounding is
                enabled (parameter name kept as-is for compatibility).
            wait_after_action_seconds: Second argument given to ``env.step``.
        """
        self.manager_engine_params = manager_engine_params
        self.operator_engine_params = operator_engine_params
        self.reflector_engine_params = reflector_engine_params
        self.grounding_enging_params = grounding_enging_params

        self.wait_after_action_seconds = wait_after_action_seconds

        # init info pool
        self.info_pool = self._new_info_pool()

    def reset(self):
        """Discard all accumulated state by replacing the info pool."""
        self.info_pool = self._new_info_pool()

    @staticmethod
    def _new_info_pool():
        # Build the initial InfoPool shared by __init__ and reset.
        return InfoPool(
            additional_knowledge=copy.deepcopy(INIT_TIPS),
            err_to_manager_thresh=2
        )

    @staticmethod
    def _lookup_instruction(path, instruction):
        # Return the entry stored for `instruction` in the JSON file at `path`,
        # or "" when the file has no such entry.
        with open(path, 'r') as handle:
            entries = json.load(handle)
        if instruction in entries:
            return entries[instruction]
        return ""

    def _recent_actions_all_failed(self):
        # True when the last `err_to_manager_thresh` recorded outcomes are all
        # failures ("B" or "C"), i.e. the error should be escalated to the manager.
        thresh = self.info_pool.err_to_manager_thresh
        outcomes = self.info_pool.action_outcomes
        if len(outcomes) < thresh:
            return False
        failures = sum(1 for outcome in outcomes[-thresh:] if outcome in ("B", "C"))
        return failures == thresh

    def _record_invalid_action(self, summary, error_description):
        # Log an 'invalid' action with outcome "C" (no change) in the info pool.
        self.info_pool.last_action = {"action": "invalid"}
        self.info_pool.action_history.append({"action": "invalid"})
        self.info_pool.summary_history.append(summary)
        self.info_pool.action_outcomes.append("C")  # no change
        self.info_pool.error_descriptions.append(error_description)

    def step(self, instruction: str, env, args):
        """Run one manager -> operator -> (grounding) -> env -> reflector iteration.

        Args:
            instruction: Task instruction, stored on the info pool.
            env: Environment exposing ``_get_obs()`` (returning a dict with a
                ``'screenshot'`` entry) and ``step(action_code, wait_seconds)``.
            args: Options object; this method reads ``enable_rag``,
                ``rag_path``, ``guide_path``, ``grounding_stage``,
                ``grounding_info_level`` and ``max_trajectory_length``.

        Returns:
            A 5-tuple ``(global_state, action, ok, env_reward, env_done)``:
            ``global_state`` collects the messages/responses of the agents
            that ran; ``action`` is the executed action code (``None`` or the
            raw action string when the action was invalid); ``ok`` is False
            when the action was invalid or failed to execute; ``env_reward``
            and ``env_done`` come from ``env.step`` (``None`` / ``True`` when
            the task is declared done, ``None`` / ``False`` on failure).

        Raises:
            ValueError: If the reflector outcome contains none of "A", "B", "C".
        """
        ## init agents ##
        manager = Manager(self.manager_engine_params)
        executor = Executor(self.operator_engine_params)
        reflector = Reflector(self.reflector_engine_params)

        global_state = {}
        message_operator = None

        self.info_pool.instruction = instruction
        step_idx = len(self.info_pool.action_history)

        print('----------step ' + str(step_idx + 1))

        observation = env._get_obs()
        before_screenshot = observation['screenshot']

        self.info_pool.width = 1920
        self.info_pool.height = 1080

        ## check error escalation
        self.info_pool.error_flag_plan = self._recent_actions_all_failed()

        ## if previous action is invalid, skip the manager and try again first ##
        skip_manager = (
            not self.info_pool.error_flag_plan
            and len(self.info_pool.action_history) > 0
            and self.info_pool.action_history[-1]['action'] == 'invalid'
        )

        planning_thought = None
        if not skip_manager:
            print("\n### Manager ... ###\n")

            # Both lookups fall back to "" when the instruction has no entry;
            # the RAG lookup previously raised KeyError in that case.
            rag_info = ""
            if args.enable_rag > 0:
                rag_info = self._lookup_instruction(args.rag_path, instruction)

            guide = ""
            if args.guide_path != "":
                guide = self._lookup_instruction(args.guide_path, instruction)

            prompt_planning = manager.get_prompt(self.info_pool, args, rag_info, guide)
            output_planning, message_manager = manager.predict(prompt_planning, [before_screenshot])

            global_state['manager'] = {
                'name': 'manager',
                'messages': message_manager,
                'response': output_planning
            }

            parsed_result_planning = manager.parse_response(output_planning)
            self.info_pool.plan = parsed_result_planning['plan']
            self.info_pool.current_subgoal = parsed_result_planning['current_subgoal']

            print('\n\nPlan: ' + self.info_pool.plan)
            print('Current subgoal: ' + self.info_pool.current_subgoal)
            planning_thought = parsed_result_planning['thought']
            print('Planning thought: ' + planning_thought, "\n")

        ## if stopping by planner ##
        if "Finished" in self.info_pool.current_subgoal.strip():
            # The planner's thought only exists when the manager ran this step.
            if not skip_manager:
                self.info_pool.finish_thought = planning_thought
            action_thought = "Finished by planner"
            action_object_str = "{\"action\": \"done\"}"
            action_description = "Finished by planner"

        else:
            print("\n### Operator ... ###\n")
            prompt_action = executor.get_prompt(self.info_pool, args.grounding_stage)
            output_action, message_operator = executor.predict(prompt_action, [before_screenshot])

            parsed_result_action = executor.parse_response(output_action)
            action_thought = parsed_result_action['thought']
            action_object_str = parsed_result_action['action']
            action_description = parsed_result_action['description']

            # Keep only the payload of a ```json fenced block, if one is present.
            action_object_str = action_object_str.split('```json')[-1].split('```')[0]
            self.info_pool.last_action_thought = action_thought
            self.info_pool.last_summary = action_description

            # If the output is not in the right format, add it to step summary which
            # will be passed to next step and return.
            if (not action_thought) or (not action_object_str):
                print('Action prompt output is not in the correct format.')
                self._record_invalid_action(action_description, "invalid action format, do nothing.")
                return global_state, None, False, None, False

        print('\n\nThought: ' + action_thought)
        print('Action: ' + action_object_str)
        print('Action description: ' + action_description, '\n\n')

        operator_response = (
            f"### Thought ###\n{action_thought}\n\n"
            f"### Action ###\n{action_object_str}\n\n"
            f"### Description ###\n{action_description}"
        )
        global_state['operator'] = {
            'name': 'operator',
            'messages': message_operator,
            'response': operator_response
        }

        try:
            if args.grounding_stage > 0:
                grounding_model = Grounding(self.grounding_enging_params)
                if args.grounding_info_level == 0:
                    grounding_info = ""
                elif args.grounding_info_level == 1:
                    grounding_info = "Thought: " + action_thought + "\nElement description: "
                else:
                    # Previously surfaced as an opaque UnboundLocalError.
                    raise ValueError(
                        f"Unsupported grounding_info_level: {args.grounding_info_level}")
                converted_action, grounding_messages = convert_fc_action_to_json_action_grounding(
                    action_object_str, grounding_model, [before_screenshot], grounding_info)
                if grounding_messages is not None:
                    global_state['grounding'] = {
                        'name': 'grounding',
                        'messages': grounding_messages,
                    }
            else:
                converted_action = convert_fc_action_to_json_action(action_object_str)

        except Exception as e:
            print('Failed to convert the output to a valid action.')
            print(str(e))
            self._record_invalid_action(action_description, "invalid action format, do nothing.")
            return global_state, action_object_str, False, None, False

        if converted_action.action_type == 'done':
            self.info_pool.last_action = json.loads(action_object_str)
            self.info_pool.action_history.append(json.loads(action_object_str))
            self.info_pool.summary_history.append(action_description)
            self.info_pool.action_outcomes.append("A")
            self.info_pool.error_descriptions.append("None")
            return global_state, converted_action.action_code, True, None, True

        try:
            # Out of step budget: send the 'FAIL' sentinel instead of the action.
            if len(self.info_pool.action_history) >= args.max_trajectory_length-1:
                converted_action.action_code = 'FAIL'
            obs, env_reward, env_done, _ = env.step(converted_action.action_code, self.wait_after_action_seconds)

        except Exception as e:
            print('Failed to execute action.')
            print(str(e))
            # The original called json.loads on a dict here, which raised
            # TypeError inside the handler and aborted the failure path.
            self._record_invalid_action(
                action_description, f"Failed to execute the action: {converted_action}")
            return global_state, converted_action.action_code, False, None, False

        print("Done action execution.\n")
        self.info_pool.last_action = json.loads(action_object_str)

        after_screenshot = obs['screenshot']

        print("\n### Reflector ... ###\n")
        if converted_action.action_type != 'answer':
            prompt_action_reflect = reflector.get_prompt(self.info_pool)
            output_action_reflect, message_reflector = reflector.predict(prompt_action_reflect, [before_screenshot, after_screenshot])

            global_state['reflector'] = {
                'name': 'reflector',
                'messages': message_reflector,
                'response': output_action_reflect
            }

            parsed_result_action_reflect = reflector.parse_response(output_action_reflect)
            outcome = parsed_result_action_reflect['outcome']
            error_description = parsed_result_action_reflect['error_description']
            progress_status = parsed_result_action_reflect['progress_status']

            if "A" in outcome: # Successful. The result of the last action meets the expectation.
                action_outcome = "A"
            elif "B" in outcome: # Failed. The last action results in a wrong page. I need to return to the previous state.
                action_outcome = "B"
            elif "C" in outcome: # Failed. The last action produces no changes.
                action_outcome = "C"
            else:
                raise ValueError("Invalid outcome:", outcome)
        else:
            # No reflection for 'answer' actions. These names used to be left
            # undefined here (NameError below); assume success and carry the
            # previous progress status forward -- TODO confirm intended semantics.
            action_outcome = "A"
            error_description = "None"
            progress_status = getattr(self.info_pool, 'progress_status', "")

        print('\n\nAction reflection outcome: ' + action_outcome)
        print('Action reflection error description: ' + error_description)
        print('Action reflection progress status: ' + progress_status, "\n")

        self.info_pool.action_history.append(json.loads(action_object_str))
        self.info_pool.summary_history.append(action_description)
        self.info_pool.action_outcomes.append(action_outcome)
        self.info_pool.error_descriptions.append(error_description)
        self.info_pool.progress_status = progress_status
        self.info_pool.progress_status_history.append(progress_status)

        return global_state, converted_action.action_code, True, env_reward, env_done