AsyncOperation: support for running asynchronous jobs while the agent is running (#111)
This commit is contained in:
@@ -46,6 +46,8 @@ class Settings(BaseSettings):
|
|||||||
# browser settings
|
# browser settings
|
||||||
BROWSER_LOCALE: str = "en-US"
|
BROWSER_LOCALE: str = "en-US"
|
||||||
BROWSER_TIMEZONE: str = "America/New_York"
|
BROWSER_TIMEZONE: str = "America/New_York"
|
||||||
|
BROWSER_WIDTH: int = 1920
|
||||||
|
BROWSER_HEIGHT: int = 1080
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# LLM Configuration #
|
# LLM Configuration #
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, Tuple
|
|||||||
import requests
|
import requests
|
||||||
import structlog
|
import structlog
|
||||||
from playwright._impl._errors import TargetClosedError
|
from playwright._impl._errors import TargetClosedError
|
||||||
|
from playwright.async_api import Page
|
||||||
|
|
||||||
from skyvern import analytics
|
from skyvern import analytics
|
||||||
from skyvern.exceptions import (
|
from skyvern.exceptions import (
|
||||||
@@ -17,6 +18,7 @@ from skyvern.exceptions import (
|
|||||||
TaskNotFound,
|
TaskNotFound,
|
||||||
)
|
)
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
from skyvern.forge.sdk.agent import Agent
|
from skyvern.forge.sdk.agent import Agent
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
@@ -64,6 +66,7 @@ class ForgeAgent(Agent):
|
|||||||
long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
|
long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
|
||||||
debug_mode=SettingsManager.get_settings().DEBUG_MODE,
|
debug_mode=SettingsManager.get_settings().DEBUG_MODE,
|
||||||
)
|
)
|
||||||
|
self.async_operation_pool = AsyncOperationPool()
|
||||||
|
|
||||||
async def validate_step_execution(
|
async def validate_step_execution(
|
||||||
self,
|
self,
|
||||||
@@ -193,6 +196,12 @@ class ForgeAgent(Agent):
|
|||||||
)
|
)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None:
    """
    Build and register the async operations for a task.

    Uses the app-level ``app.generate_async_operations`` factory (set by the
    deployment; may be None) to create the operations for this organization,
    task and playwright page, then registers them in the agent's
    AsyncOperationPool keyed by task_id.

    No-op when ``app.generate_async_operations`` is not configured.
    """
    if not app.generate_async_operations:
        return
    operations = app.generate_async_operations(organization, task, page)
    self.async_operation_pool.add_operations(task.task_id, operations)
|
||||||
|
|
||||||
async def execute_step(
|
async def execute_step(
|
||||||
self,
|
self,
|
||||||
organization: Organization,
|
organization: Organization,
|
||||||
@@ -208,6 +217,10 @@ class ForgeAgent(Agent):
|
|||||||
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
||||||
await self.validate_step_execution(task, step)
|
await self.validate_step_execution(task, step)
|
||||||
step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)
|
step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)
|
||||||
|
|
||||||
|
if browser_state.page:
|
||||||
|
self.register_async_operations(organization, task, browser_state.page)
|
||||||
|
|
||||||
step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
|
step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
|
||||||
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
|
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
|
||||||
retry = False
|
retry = False
|
||||||
@@ -226,6 +239,7 @@ class ForgeAgent(Agent):
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
close_browser_on_completion=close_browser_on_completion,
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
)
|
)
|
||||||
|
await self.async_operation_pool.remove_task(task.task_id)
|
||||||
return step, detailed_output, None
|
return step, detailed_output, None
|
||||||
elif step.status == StepStatus.completed:
|
elif step.status == StepStatus.completed:
|
||||||
# TODO (kerem): keep the task object uptodate at all times so that send_task_response can just use it
|
# TODO (kerem): keep the task object uptodate at all times so that send_task_response can just use it
|
||||||
@@ -332,6 +346,7 @@ class ForgeAgent(Agent):
|
|||||||
json_response = None
|
json_response = None
|
||||||
actions: list[Action]
|
actions: list[Action]
|
||||||
if task.navigation_goal:
|
if task.navigation_goal:
|
||||||
|
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
||||||
json_response = await app.LLM_API_HANDLER(
|
json_response = await app.LLM_API_HANDLER(
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
step=step,
|
step=step,
|
||||||
@@ -403,6 +418,7 @@ class ForgeAgent(Agent):
|
|||||||
break
|
break
|
||||||
web_action_element_ids.add(action.element_id)
|
web_action_element_ids.add(action.element_id)
|
||||||
|
|
||||||
|
self.async_operation_pool.run_operation(task.task_id, AgentPhase.action)
|
||||||
results = await ActionHandler.handle_action(scraped_page, task, step, browser_state, action)
|
results = await ActionHandler.handle_action(scraped_page, task, step, browser_state, action)
|
||||||
detailed_agent_step_output.actions_and_results[action_idx] = (action, results)
|
detailed_agent_step_output.actions_and_results[action_idx] = (action, results)
|
||||||
# wait random time between actions to avoid detection
|
# wait random time between actions to avoid detection
|
||||||
@@ -559,6 +575,9 @@ class ForgeAgent(Agent):
|
|||||||
step: Step,
|
step: Step,
|
||||||
browser_state: BrowserState,
|
browser_state: BrowserState,
|
||||||
) -> tuple[ScrapedPage, str]:
|
) -> tuple[ScrapedPage, str]:
|
||||||
|
# start the async tasks while running scrape_website
|
||||||
|
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
|
||||||
|
|
||||||
# Scrape the web page and get the screenshot and the elements
|
# Scrape the web page and get the screenshot and the elements
|
||||||
scraped_page = await scrape_website(
|
scraped_page = await scrape_website(
|
||||||
browser_state,
|
browser_state,
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
from ddtrace import tracer
|
from ddtrace import tracer
|
||||||
from ddtrace.filters import FilterRequestsOnUrl
|
from ddtrace.filters import FilterRequestsOnUrl
|
||||||
|
from playwright.async_api import Page
|
||||||
|
|
||||||
from skyvern.forge.agent import ForgeAgent
|
from skyvern.forge.agent import ForgeAgent
|
||||||
|
from skyvern.forge.async_operations import AsyncOperation
|
||||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||||
from skyvern.forge.sdk.db.client import AgentDB
|
from skyvern.forge.sdk.db.client import AgentDB
|
||||||
from skyvern.forge.sdk.forge_log import setup_logger
|
from skyvern.forge.sdk.forge_log import setup_logger
|
||||||
|
from skyvern.forge.sdk.models import Organization
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import Task
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
|
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
|
||||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||||
@@ -31,6 +37,8 @@ BROWSER_MANAGER = BrowserManager()
|
|||||||
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
||||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||||
WORKFLOW_SERVICE = WorkflowService()
|
WORKFLOW_SERVICE = WorkflowService()
|
||||||
|
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
|
||||||
|
|
||||||
agent = ForgeAgent()
|
agent = ForgeAgent()
|
||||||
|
|
||||||
app = agent.get_agent_app()
|
app = agent.get_agent_app()
|
||||||
|
|||||||
132
skyvern/forge/async_operations.py
Normal file
132
skyvern/forge/async_operations.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from playwright.async_api import Page
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AgentPhase(StrEnum):
|
||||||
|
"""
|
||||||
|
Phase of agent when async execution events are happening
|
||||||
|
"""
|
||||||
|
|
||||||
|
action = "action"
|
||||||
|
scrape = "scrape"
|
||||||
|
llm = "llm"
|
||||||
|
|
||||||
|
|
||||||
|
VALID_AGENT_PHASES = [phase.value for phase in AgentPhase]
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncOperation:
    """
    AsyncOperation can take async actions on the page while agent is performing the task.

    Examples:
    - collect info based on the html/DOM and send data to your server

    Subclasses implement :meth:`execute`; :meth:`run` schedules it as an
    asyncio task, allowing at most one in-flight task per operation.
    """

    def __init__(self, task_id: str, operation_type: str, agent_phase: AgentPhase, page: Page) -> None:
        """
        :param task_id: task_id of the task
        :param operation_type: it's the custom type of the operation.
            there will only be up to one aio task running per operation_type
        :param agent_phase: AgentPhase type. phase of the agent when the operation is running
        :param page: playwright page for the task
        """
        self.task_id = task_id
        self.type = operation_type
        self.agent_phase = agent_phase
        # Handle of the most recently scheduled asyncio task, if any.
        self.aio_task: asyncio.Task | None = None

        # playwright page could be used by the operation to take actions
        self.page = page

    async def execute(self) -> None:
        """The actual async work; the base implementation is a no-op."""
        return

    def run(self) -> asyncio.Task | None:
        """
        Schedule :meth:`execute` as an asyncio task.

        Returns the created task, or None when a previous run of this
        operation is still in flight (only one task at a time).
        """
        if self.aio_task is not None and not self.aio_task.done():
            # fix: message was an f-string with no placeholders (ruff F541);
            # the structured fields already carry the context.
            LOG.warning(
                "Task already running",
                task_id=self.task_id,
                operation_type=self.type,
                agent_phase=self.agent_phase,
            )
            return None
        self.aio_task = asyncio.create_task(self.execute())
        return self.aio_task
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncOperationPool:
    """
    Tracks the AsyncOperations registered per task and the asyncio tasks
    spawned for them, guaranteeing at most one in-flight aio task per
    (task_id, operation_type).
    """

    def __init__(self) -> None:
        # fix: these were mutable CLASS attributes, so every pool instance
        # shared the same dicts; instance attributes isolate each pool's state.
        self._operations: dict[str, dict[AgentPhase, AsyncOperation]] = {}  # task_id: {agent_phase: operation}

        # use _aio_tasks to ensure we're only executing one aio task for the same operation_type
        self._aio_tasks: dict[str, dict[str, asyncio.Task]] = {}  # task_id: {operation_type: aio_task}

    def _add_operation(self, task_id: str, operation: AsyncOperation) -> None:
        """Register a single operation for task_id, validating its agent phase."""
        if operation.agent_phase not in VALID_AGENT_PHASES:
            raise ValueError(f"operation's agent phase {operation.agent_phase} is not valid")
        if task_id not in self._operations:
            self._operations[task_id] = {}
        self._operations[task_id][operation.agent_phase] = operation

    def add_operations(self, task_id: str, operations: list[AsyncOperation]) -> None:
        """Register operations for a task; no-op if the task already has any."""
        if task_id in self._operations:
            # already exists
            return
        for operation in operations:
            self._add_operation(task_id, operation)

    def _get_operation(self, task_id: str, operation_type: AgentPhase) -> AsyncOperation | None:
        return self._operations.get(task_id, {}).get(operation_type, None)

    def remove_operations(self, task_id: str) -> None:
        """Drop all registered operations and aio-task bookkeeping for task_id."""
        if task_id in self._operations:
            del self._operations[task_id]
        # fix: _aio_tasks entries were never removed, so finished asyncio
        # tasks accumulated for every completed task (memory leak).
        self._aio_tasks.pop(task_id, None)

    def get_aio_tasks(self, task_id: str) -> list[asyncio.Task]:
        """
        Get all the running/pending aio tasks for the given task_id
        """
        return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()]

    def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
        """
        Run the operation registered for (task_id, agent_phase), if any.

        Silently returns when no operation is registered, and refuses to
        start a second aio task while one of the same operation_type is
        still in flight.
        """
        # get the operation from the pool
        operation = self._get_operation(task_id, agent_phase)
        if operation is None:
            return

        operation_type = operation.type
        task_aio_tasks = self._aio_tasks.setdefault(task_id, {})

        # if the aio task is already running, don't run it again
        existing = task_aio_tasks.get(operation_type)
        if existing is not None and not existing.done():
            # fix: message was an f-string with no placeholders (ruff F541)
            LOG.info(
                "aio task already running",
                task_id=task_id,
                operation_type=operation_type,
                agent_phase=agent_phase,
            )
            return

        # run the operation if the aio task is not running
        aio_task = operation.run()
        if aio_task:
            task_aio_tasks[operation_type] = aio_task

    async def remove_task(self, task_id: str) -> None:
        """
        Wait (up to 30s) for the task's pending aio tasks to finish, then
        drop all operations and bookkeeping for task_id.
        """
        try:
            async with asyncio.timeout(30):
                # get_aio_tasks already filters out completed tasks
                await asyncio.gather(*self.get_aio_tasks(task_id))
        except asyncio.TimeoutError:
            LOG.error(f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}", task_id=task_id)

        self.remove_operations(task_id)
|
||||||
@@ -11,6 +11,7 @@ from playwright._impl._errors import TimeoutError
|
|||||||
from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright
|
from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import (
|
from skyvern.exceptions import (
|
||||||
FailedToNavigateToUrl,
|
FailedToNavigateToUrl,
|
||||||
FailedToTakeScreenshot,
|
FailedToTakeScreenshot,
|
||||||
@@ -61,7 +62,7 @@ class BrowserContextFactory:
|
|||||||
],
|
],
|
||||||
"record_har_path": har_dir,
|
"record_har_path": har_dir,
|
||||||
"record_video_dir": video_dir,
|
"record_video_dir": video_dir,
|
||||||
"viewport": {"width": 1920, "height": 1080},
|
"viewport": {"width": settings.BROWSER_WIDTH, "height": settings.BROWSER_HEIGHT},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user