mirrored 7 minutes ago
0
Howieadd support for mobile agent v3 (#328) * add support for mobile agent v3 * add mobile_agent * add support for mobile agent v3756e006
import datetime
import json
import logging
import os
import time
from wrapt_timeout_decorator import *

logger = logging.getLogger("desktopenv.experiment")


def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    runtime_logger = setup_logger(example, example_result_dir)
    try:
        agent.reset(runtime_logger)
    except:
        agent.reset()
    env.reset(task_config=example)
    time.sleep(60) # Wait for the environment to be ready
    obs = env._get_obs() # Get the initial observation

    done = False
    step_idx = 0
    
    # save the first step
    action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
    with open(os.path.join(example_result_dir, f"step_{step_idx}_{action_timestamp}.png"), "wb") as _f:
        _f.write(obs['screenshot'])
    
    eval_flag = True
    env.controller.start_recording()
    while not done and step_idx < max_steps:
        global_state, action_code, step_status, reward, done = agent.step(instruction, env, args)
        action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

        if step_status is False:
            eval_flag = False
            done = True
            reward = None
        else:
            obs = env._get_obs()
            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
                        "wb") as _f:
                _f.write(obs['screenshot'])
        with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
            f.write(json.dumps({
                "step_num": step_idx + 1,
                "step_status": step_status,
                "action_timestamp": action_timestamp,
                "action": action_code,
                "reward": reward,
                "done": done,
                "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
            }))
            f.write("\n")
        if done:    
            logger.info("The episode is done.")
            break
        step_idx += 1

    if eval_flag:
        result = env.evaluate()
        logger.info("Result: %.2f", result)
        scores.append(result)
        with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
            f.write(f"{result}\n")
    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))


def setup_logger(example, example_result_dir):
    runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
    runtime_logger.setLevel(logging.DEBUG)
    runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
    return runtime_logger