add default timeout to wait_for_task (#1122)
This commit is contained in:
@@ -337,7 +337,6 @@ class ForgeAgent:
|
||||
api_key=api_key,
|
||||
close_browser_on_completion=close_browser_on_completion,
|
||||
)
|
||||
await self.async_operation_pool.remove_task(task.task_id)
|
||||
return step, detailed_output, None
|
||||
elif step.status == StepStatus.completed:
|
||||
# TODO (kerem): keep the task object uptodate at all times so that clean_up_task can just use it
|
||||
@@ -1402,6 +1401,7 @@ class ForgeAgent:
|
||||
)
|
||||
return
|
||||
|
||||
await self.async_operation_pool.remove_task(task.task_id)
|
||||
await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task)
|
||||
|
||||
# Wait for all tasks to complete before generating the links for the artifacts
|
||||
|
||||
@@ -4,6 +4,8 @@ from enum import StrEnum
|
||||
import structlog
|
||||
from playwright.async_api import Page
|
||||
|
||||
from skyvern.forge.sdk.core.asyncio_helper import is_aio_task_running
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -48,7 +50,7 @@ class AsyncOperation:
|
||||
return
|
||||
|
||||
def run(self) -> asyncio.Task | None:
|
||||
if self.aio_task is not None and not self.aio_task.done():
|
||||
if self.aio_task is not None and is_aio_task_running(self.aio_task):
|
||||
LOG.warning(
|
||||
"Task already running",
|
||||
task_id=self.task_id,
|
||||
@@ -80,10 +82,10 @@ class AsyncOperationPool:
|
||||
for operation in operations:
|
||||
self._add_operation(task_id, operation)
|
||||
|
||||
def _get_operation(self, task_id: str, operation_type: AgentPhase) -> AsyncOperation | None:
|
||||
return self._operations.get(task_id, {}).get(operation_type, None)
|
||||
def _get_operation(self, task_id: str, agent_phase: AgentPhase) -> AsyncOperation | None:
|
||||
return self._operations.get(task_id, {}).get(agent_phase, None)
|
||||
|
||||
def remove_operations(self, task_id: str) -> None:
|
||||
def _remove_operations(self, task_id: str) -> None:
|
||||
if task_id in self._operations:
|
||||
del self._operations[task_id]
|
||||
|
||||
@@ -91,20 +93,37 @@ class AsyncOperationPool:
|
||||
"""
|
||||
Get all the running/pending aio tasks for the given task_id
|
||||
"""
|
||||
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()]
|
||||
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if is_aio_task_running(aio_task)]
|
||||
|
||||
def get_aio_task(self, task_id: str, operation_type: str) -> asyncio.Task | None:
|
||||
return self._aio_tasks.get(task_id, {}).get(operation_type, None)
|
||||
|
||||
async def wait_for_task(self, task_id: str, operation_type: str, timeout: float | None = None) -> None:
|
||||
def _remove_aio_tasks(self, task_id: str) -> None:
|
||||
if task_id in self._aio_tasks:
|
||||
del self._aio_tasks[task_id]
|
||||
|
||||
async def wait_for_task(
|
||||
self,
|
||||
task_id: str,
|
||||
operation_type: str,
|
||||
timeout: float | None = 5,
|
||||
) -> None:
|
||||
running_task = self.get_aio_task(task_id=task_id, operation_type=operation_type)
|
||||
if running_task and not running_task.done():
|
||||
if running_task is None or not is_aio_task_running(running_task):
|
||||
return
|
||||
LOG.info(
|
||||
"wait for the running aio task to be done",
|
||||
task_id=task_id,
|
||||
operation_type=operation_type,
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(running_task, timeout)
|
||||
except TimeoutError:
|
||||
LOG.info(
|
||||
"wait for the running aio task to be done",
|
||||
f"Timeout ({timeout}s) while waiting for the running aio task to be done",
|
||||
task_id=task_id,
|
||||
operation_type=operation_type,
|
||||
)
|
||||
await asyncio.wait_for(running_task, timeout)
|
||||
|
||||
def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
|
||||
# get the operation from the pool
|
||||
@@ -121,7 +140,7 @@ class AsyncOperationPool:
|
||||
aio_task: asyncio.Task | None = None
|
||||
if operation_type in self._aio_tasks[task_id]:
|
||||
aio_task = self._aio_tasks[task_id][operation_type]
|
||||
if not aio_task.done():
|
||||
if is_aio_task_running(aio_task):
|
||||
LOG.info(
|
||||
"aio task already running",
|
||||
task_id=task_id,
|
||||
@@ -138,11 +157,15 @@ class AsyncOperationPool:
|
||||
async def remove_task(self, task_id: str) -> None:
|
||||
try:
|
||||
async with asyncio.timeout(30):
|
||||
await asyncio.gather(*[aio_task for aio_task in self.get_aio_tasks(task_id) if not aio_task.done()])
|
||||
except asyncio.TimeoutError:
|
||||
await asyncio.gather(
|
||||
*[aio_task for aio_task in self.get_aio_tasks(task_id) if is_aio_task_running(aio_task)]
|
||||
)
|
||||
except TimeoutError:
|
||||
LOG.error(
|
||||
f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}",
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
self.remove_operations(task_id)
|
||||
self._remove_aio_tasks(task_id)
|
||||
self._remove_operations(task_id)
|
||||
LOG.info("Successfully removed aio tasks and async operations", task_id=task_id)
|
||||
|
||||
5
skyvern/forge/sdk/core/asyncio_helper.py
Normal file
5
skyvern/forge/sdk/core/asyncio_helper.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
def is_aio_task_running(aio_task: asyncio.Task) -> bool:
|
||||
return not aio_task.done() and not aio_task.cancelled()
|
||||
Reference in New Issue
Block a user