From 1e933b703f0a45af7aab04a4958336d3dd312d02 Mon Sep 17 00:00:00 2001
From: Kerem Yilmaz
Date: Tue, 19 Mar 2024 09:12:28 -0700
Subject: [PATCH] AsyncOperation: support for running asynchronous jobs while the agent is running (#111)

---
 skyvern/config.py                 |   2 +
 skyvern/forge/agent.py            |  19 ++++
 skyvern/forge/app.py              |   8 ++
 skyvern/forge/async_operations.py | 155 ++++++++++++++++++++++++++++++
 skyvern/webeye/browser_factory.py |   3 +-
 5 files changed, 186 insertions(+), 1 deletion(-)
 create mode 100644 skyvern/forge/async_operations.py

diff --git a/skyvern/config.py b/skyvern/config.py
index 0c346c68..8255e3c4 100644
--- a/skyvern/config.py
+++ b/skyvern/config.py
@@ -46,6 +46,8 @@ class Settings(BaseSettings):
     # browser settings
     BROWSER_LOCALE: str = "en-US"
     BROWSER_TIMEZONE: str = "America/New_York"
+    BROWSER_WIDTH: int = 1920
+    BROWSER_HEIGHT: int = 1080
 
     #####################
     # LLM Configuration #
diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py
index f699e667..c125b743 100644
--- a/skyvern/forge/agent.py
+++ b/skyvern/forge/agent.py
@@ -7,6 +7,7 @@ from typing import Any, Tuple
 import requests
 import structlog
 from playwright._impl._errors import TargetClosedError
+from playwright.async_api import Page
 
 from skyvern import analytics
 from skyvern.exceptions import (
@@ -17,6 +18,7 @@ from skyvern.exceptions import (
     TaskNotFound,
 )
 from skyvern.forge import app
+from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
 from skyvern.forge.prompts import prompt_engine
 from skyvern.forge.sdk.agent import Agent
 from skyvern.forge.sdk.artifact.models import ArtifactType
@@ -64,6 +66,7 @@ class ForgeAgent(Agent):
             long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
             debug_mode=SettingsManager.get_settings().DEBUG_MODE,
         )
+        self.async_operation_pool = AsyncOperationPool()
 
     async def validate_step_execution(
         self,
@@ -193,6 +196,12 @@ class ForgeAgent(Agent):
         )
         return task
 
+    def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None:
+        if not app.generate_async_operations:
+            return
+        operations = app.generate_async_operations(organization, task, page)
+        self.async_operation_pool.add_operations(task.task_id, operations)
+
     async def execute_step(
         self,
         organization: Organization,
@@ -208,6 +217,10 @@ class ForgeAgent(Agent):
             # Check some conditions before executing the step, throw an exception if the step can't be executed
             await self.validate_step_execution(task, step)
             step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)
+
+            if browser_state.page:
+                self.register_async_operations(organization, task, browser_state.page)
+
             step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
             task = await self.update_task_errors_from_detailed_output(task, detailed_output)
             retry = False
@@ -226,6 +239,7 @@ class ForgeAgent(Agent):
                     api_key=api_key,
                     close_browser_on_completion=close_browser_on_completion,
                 )
+                await self.async_operation_pool.remove_task(task.task_id)
                 return step, detailed_output, None
             elif step.status == StepStatus.completed:
                 # TODO (kerem): keep the task object uptodate at all times so that send_task_response can just use it
@@ -332,6 +346,7 @@ class ForgeAgent(Agent):
             json_response = None
             actions: list[Action]
             if task.navigation_goal:
+                self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
                 json_response = await app.LLM_API_HANDLER(
                     prompt=extract_action_prompt,
                     step=step,
@@ -403,6 +418,7 @@ class ForgeAgent(Agent):
                         break
                     web_action_element_ids.add(action.element_id)
 
+                self.async_operation_pool.run_operation(task.task_id, AgentPhase.action)
                 results = await ActionHandler.handle_action(scraped_page, task, step, browser_state, action)
                 detailed_agent_step_output.actions_and_results[action_idx] = (action, results)
                 # wait random time between actions to avoid detection
@@ -559,6 +575,9 @@ class ForgeAgent(Agent):
         step: Step,
         browser_state: BrowserState,
     ) -> tuple[ScrapedPage, str]:
+        # start the async tasks while running scrape_website
+        self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
+
         # Scrape the web page and get the screenshot and the elements
         scraped_page = await scrape_website(
             browser_state,
diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py
index 37bf084b..7cec463f 100644
--- a/skyvern/forge/app.py
+++ b/skyvern/forge/app.py
@@ -1,12 +1,18 @@
+from typing import Callable
+
 from ddtrace import tracer
 from ddtrace.filters import FilterRequestsOnUrl
+from playwright.async_api import Page
 
 from skyvern.forge.agent import ForgeAgent
+from skyvern.forge.async_operations import AsyncOperation
 from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
 from skyvern.forge.sdk.artifact.manager import ArtifactManager
 from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
 from skyvern.forge.sdk.db.client import AgentDB
 from skyvern.forge.sdk.forge_log import setup_logger
+from skyvern.forge.sdk.models import Organization
+from skyvern.forge.sdk.schemas.tasks import Task
 from skyvern.forge.sdk.settings_manager import SettingsManager
 from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
 from skyvern.forge.sdk.workflow.service import WorkflowService
@@ -31,6 +37,8 @@ BROWSER_MANAGER = BrowserManager()
 LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
 WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
 WORKFLOW_SERVICE = WorkflowService()
+generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
+
 agent = ForgeAgent()
 
 app = agent.get_agent_app()
diff --git a/skyvern/forge/async_operations.py b/skyvern/forge/async_operations.py
new file mode 100644
index 00000000..8f30098f
--- /dev/null
+++ b/skyvern/forge/async_operations.py
@@ -0,0 +1,155 @@
+import asyncio
+from enum import StrEnum
+
+import structlog
+from playwright.async_api import Page
+
+LOG = structlog.get_logger()
+
+
+class AgentPhase(StrEnum):
+    """
+    Phase of agent when async execution events are happening
+    """
+
+    action = "action"
+    scrape = "scrape"
+    llm = "llm"
+
+
+VALID_AGENT_PHASES = [phase.value for phase in AgentPhase]
+
+
+class AsyncOperation:
+    """
+    AsyncOperation can take async actions on the page while agent is performing the task.
+
+    Examples:
+        - collect info based on the html/DOM and send data to your server
+    """
+
+    def __init__(self, task_id: str, operation_type: str, agent_phase: AgentPhase, page: Page) -> None:
+        """
+        :param task_id: task_id of the task
+        :param operation_type: it's the custom type of the operation.
+            there will only be up to one aio task running per operation_type
+        :param agent_phase: AgentPhase type. phase of the agent when the operation is running
+        :param page: playwright page for the task
+        """
+        self.task_id = task_id
+        self.type = operation_type
+        self.agent_phase = agent_phase
+        self.aio_task: asyncio.Task | None = None
+
+        # playwright page could be used by the operation to take actions
+        self.page = page
+
+    async def execute(self) -> None:
+        """
+        Work performed by the operation. No-op here; presumably overridden by concrete operations.
+        """
+        return
+
+    def run(self) -> asyncio.Task | None:
+        """
+        Schedule execute() as an aio task on the running event loop.
+
+        :return: the new aio task, or None if the previous one has not finished yet
+        """
+        if self.aio_task is not None and not self.aio_task.done():
+            LOG.warning(
+                "Task already running",
+                task_id=self.task_id,
+                operation_type=self.type,
+                agent_phase=self.agent_phase,
+            )
+            return None
+        self.aio_task = asyncio.create_task(self.execute())
+        return self.aio_task
+
+
+class AsyncOperationPool:
+    """
+    Registry of the async operations of each task and of the aio tasks they spawned.
+    """
+
+    # NOTE(review): both dicts are class attributes, so every pool instance shares them - confirm this is intended
+    _operations: dict[str, dict[AgentPhase, AsyncOperation]] = {}  # task_id: {agent_phase: operation}
+
+    # use _aio_tasks to ensure we're only executing one aio task for the same operation_type
+    _aio_tasks: dict[str, dict[str, asyncio.Task]] = {}  # task_id: {operation_type: aio_task}
+
+    def _add_operation(self, task_id: str, operation: AsyncOperation) -> None:
+        if operation.agent_phase not in VALID_AGENT_PHASES:
+            raise ValueError(f"operation's agent phase {operation.agent_phase} is not valid")
+        # one operation per (task_id, agent_phase): a later operation for the same phase replaces the earlier one
+        self._operations.setdefault(task_id, {})[operation.agent_phase] = operation
+
+    def add_operations(self, task_id: str, operations: list[AsyncOperation]) -> None:
+        """
+        Register the operations of a task. Does nothing if the task already has operations registered.
+        """
+        if task_id in self._operations:
+            # already exists
+            return
+        for operation in operations:
+            self._add_operation(task_id, operation)
+
+    def _get_operation(self, task_id: str, agent_phase: AgentPhase) -> AsyncOperation | None:
+        return self._operations.get(task_id, {}).get(agent_phase)
+
+    def remove_operations(self, task_id: str) -> None:
+        self._operations.pop(task_id, None)
+
+    def get_aio_tasks(self, task_id: str) -> list[asyncio.Task]:
+        """
+        Get all the running/pending aio tasks for the given task_id
+        """
+        return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()]
+
+    def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
+        """
+        Start the operation registered for (task_id, agent_phase) unless its previous run is still in flight.
+        """
+        # get the operation from the pool
+        operation = self._get_operation(task_id, agent_phase)
+        if operation is None:
+            return
+
+        operation_type = operation.type
+        task_aio_tasks = self._aio_tasks.setdefault(task_id, {})
+
+        # if the aio task is already running, don't run it again
+        running_aio_task = task_aio_tasks.get(operation_type)
+        if running_aio_task is not None and not running_aio_task.done():
+            LOG.info(
+                "aio task already running",
+                task_id=task_id,
+                operation_type=operation_type,
+                agent_phase=agent_phase,
+            )
+            return
+
+        # run the operation if the aio task is not running
+        aio_task = operation.run()
+        if aio_task:
+            task_aio_tasks[operation_type] = aio_task
+
+    async def remove_task(self, task_id: str) -> None:
+        """
+        Wait up to 30s for the pending aio tasks of the task, then drop everything tracked for it.
+        """
+        pending_aio_tasks = self.get_aio_tasks(task_id)
+        try:
+            async with asyncio.timeout(30):
+                # return_exceptions=True: a failing operation must not abort the agent nor skip the cleanup below
+                results = await asyncio.gather(*pending_aio_tasks, return_exceptions=True)
+            for result in results:
+                if isinstance(result, BaseException):
+                    LOG.error("Async operation failed", task_id=task_id, exc_info=result)
+        except asyncio.TimeoutError:
+            LOG.error(f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}", task_id=task_id)
+        finally:
+            self.remove_operations(task_id)
+            # also forget the aio tasks, otherwise _aio_tasks grows with every task the agent has ever run
+            self._aio_tasks.pop(task_id, None)
diff --git a/skyvern/webeye/browser_factory.py b/skyvern/webeye/browser_factory.py
index ddf549e6..00ee22d0 100644
--- a/skyvern/webeye/browser_factory.py
+++ b/skyvern/webeye/browser_factory.py
@@ -11,6 +11,7 @@ from playwright._impl._errors import TimeoutError
 from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright
 from pydantic import BaseModel
 
+from skyvern.config import settings
 from skyvern.exceptions import (
     FailedToNavigateToUrl,
     FailedToTakeScreenshot,
@@ -61,7 +62,7 @@ class BrowserContextFactory:
             ],
             "record_har_path": har_dir,
             "record_video_dir": video_dir,
-            "viewport": {"width": 1920, "height": 1080},
+            "viewport": {"width": settings.BROWSER_WIDTH, "height": settings.BROWSER_HEIGHT},
         }
 
     @staticmethod