diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 1eb9e8b2..adef5b61 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -337,7 +337,6 @@ class ForgeAgent: api_key=api_key, close_browser_on_completion=close_browser_on_completion, ) - await self.async_operation_pool.remove_task(task.task_id) return step, detailed_output, None elif step.status == StepStatus.completed: # TODO (kerem): keep the task object uptodate at all times so that clean_up_task can just use it @@ -1402,6 +1401,7 @@ class ForgeAgent: ) return + await self.async_operation_pool.remove_task(task.task_id) await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task) # Wait for all tasks to complete before generating the links for the artifacts diff --git a/skyvern/forge/async_operations.py b/skyvern/forge/async_operations.py index 9648f460..db825887 100644 --- a/skyvern/forge/async_operations.py +++ b/skyvern/forge/async_operations.py @@ -4,6 +4,8 @@ from enum import StrEnum import structlog from playwright.async_api import Page +from skyvern.forge.sdk.core.asyncio_helper import is_aio_task_running + LOG = structlog.get_logger() @@ -48,7 +50,7 @@ class AsyncOperation: return def run(self) -> asyncio.Task | None: - if self.aio_task is not None and not self.aio_task.done(): + if self.aio_task is not None and is_aio_task_running(self.aio_task): LOG.warning( "Task already running", task_id=self.task_id, @@ -80,10 +82,10 @@ class AsyncOperationPool: for operation in operations: self._add_operation(task_id, operation) - def _get_operation(self, task_id: str, operation_type: AgentPhase) -> AsyncOperation | None: - return self._operations.get(task_id, {}).get(operation_type, None) + def _get_operation(self, task_id: str, agent_phase: AgentPhase) -> AsyncOperation | None: + return self._operations.get(task_id, {}).get(agent_phase, None) - def remove_operations(self, task_id: str) -> None: + def _remove_operations(self, task_id: str) -> None: if task_id in self._operations: del self._operations[task_id] @@ -91,20 +93,37 @@ class AsyncOperationPool: """ Get all the running/pending aio tasks for the given task_id """ - return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()] + return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if is_aio_task_running(aio_task)] def get_aio_task(self, task_id: str, operation_type: str) -> asyncio.Task | None: return self._aio_tasks.get(task_id, {}).get(operation_type, None) - async def wait_for_task(self, task_id: str, operation_type: str, timeout: float | None = None) -> None: + def _remove_aio_tasks(self, task_id: str) -> None: + if task_id in self._aio_tasks: + del self._aio_tasks[task_id] + + async def wait_for_task( + self, + task_id: str, + operation_type: str, + timeout: float | None = 5, + ) -> None: running_task = self.get_aio_task(task_id=task_id, operation_type=operation_type) - if running_task and not running_task.done(): + if running_task is None or not is_aio_task_running(running_task): + return + LOG.info( + "wait for the running aio task to be done", + task_id=task_id, + operation_type=operation_type, + ) + try: + await asyncio.wait_for(running_task, timeout) + except TimeoutError: LOG.info( - "wait for the running aio task to be done", + f"Timeout ({timeout}s) while waiting for the running aio task to be done", task_id=task_id, operation_type=operation_type, ) - await asyncio.wait_for(running_task, timeout) def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None: # get the operation from the pool @@ -121,7 +140,7 @@ class AsyncOperationPool: aio_task: asyncio.Task | None = None if operation_type in self._aio_tasks[task_id]: aio_task = self._aio_tasks[task_id][operation_type] - if not aio_task.done(): + if is_aio_task_running(aio_task): LOG.info( "aio task already running", task_id=task_id, @@ -138,11 +157,15 @@ class AsyncOperationPool: async def remove_task(self, task_id: str) -> None: try: async with asyncio.timeout(30): - await asyncio.gather(*[aio_task for aio_task in self.get_aio_tasks(task_id) if not aio_task.done()]) - except asyncio.TimeoutError: + await asyncio.gather( + *[aio_task for aio_task in self.get_aio_tasks(task_id) if is_aio_task_running(aio_task)] + ) + except TimeoutError: LOG.error( f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}", task_id=task_id, ) - self.remove_operations(task_id) + self._remove_aio_tasks(task_id) + self._remove_operations(task_id) + LOG.info("Successfully removed aio tasks and async operations", task_id=task_id) diff --git a/skyvern/forge/sdk/core/asyncio_helper.py b/skyvern/forge/sdk/core/asyncio_helper.py new file mode 100644 index 00000000..1aafe86a --- /dev/null +++ b/skyvern/forge/sdk/core/asyncio_helper.py @@ -0,0 +1,5 @@ +import asyncio + + +def is_aio_task_running(aio_task: asyncio.Task) -> bool: + return not aio_task.done() and not aio_task.cancelled()