diff --git a/skyvern/forge/async_operations.py b/skyvern/forge/async_operations.py index 05c41bbb..9648f460 100644 --- a/skyvern/forge/async_operations.py +++ b/skyvern/forge/async_operations.py @@ -96,6 +96,16 @@ class AsyncOperationPool: def get_aio_task(self, task_id: str, operation_type: str) -> asyncio.Task | None: return self._aio_tasks.get(task_id, {}).get(operation_type, None) + async def wait_for_task(self, task_id: str, operation_type: str, timeout: float | None = None) -> None: + running_task = self.get_aio_task(task_id=task_id, operation_type=operation_type) + if running_task and not running_task.done(): + LOG.info( + "wait for the running aio task to be done", + task_id=task_id, + operation_type=operation_type, + ) + await asyncio.wait_for(running_task, timeout) + def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None: # get the operation from the pool operation = self._get_operation(task_id, agent_phase)