add default timeout to wait_for_task (#1122)

This commit is contained in:
Shuchang Zheng
2024-11-03 22:19:55 -08:00
committed by GitHub
parent 2ac8a1a7d0
commit d0a35622a7
3 changed files with 42 additions and 14 deletions

View File

@@ -337,7 +337,6 @@ class ForgeAgent:
api_key=api_key,
close_browser_on_completion=close_browser_on_completion,
)
await self.async_operation_pool.remove_task(task.task_id)
return step, detailed_output, None
elif step.status == StepStatus.completed:
# TODO (kerem): keep the task object up to date at all times so that clean_up_task can just use it
@@ -1402,6 +1401,7 @@ class ForgeAgent:
)
return
await self.async_operation_pool.remove_task(task.task_id)
await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task)
# Wait for all tasks to complete before generating the links for the artifacts

View File

@@ -4,6 +4,8 @@ from enum import StrEnum
import structlog
from playwright.async_api import Page
from skyvern.forge.sdk.core.asyncio_helper import is_aio_task_running
LOG = structlog.get_logger()
@@ -48,7 +50,7 @@ class AsyncOperation:
return
def run(self) -> asyncio.Task | None:
if self.aio_task is not None and not self.aio_task.done():
if self.aio_task is not None and is_aio_task_running(self.aio_task):
LOG.warning(
"Task already running",
task_id=self.task_id,
@@ -80,10 +82,10 @@ class AsyncOperationPool:
for operation in operations:
self._add_operation(task_id, operation)
def _get_operation(self, task_id: str, operation_type: AgentPhase) -> AsyncOperation | None:
return self._operations.get(task_id, {}).get(operation_type, None)
def _get_operation(self, task_id: str, agent_phase: AgentPhase) -> AsyncOperation | None:
    """Return the operation stored for ``task_id`` under ``agent_phase``.

    Yields ``None`` when either the task id or the agent phase has no entry.
    """
    operations_by_phase = self._operations.get(task_id, {})
    return operations_by_phase.get(agent_phase, None)
def remove_operations(self, task_id: str) -> None:
def _remove_operations(self, task_id: str) -> None:
    """Drop the operations entry for ``task_id``; do nothing when it is absent."""
    if task_id not in self._operations:
        return
    del self._operations[task_id]
@@ -91,20 +93,37 @@ class AsyncOperationPool:
"""
Get all the running/pending aio tasks for the given task_id
"""
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()]
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if is_aio_task_running(aio_task)]
def get_aio_task(self, task_id: str, operation_type: str) -> asyncio.Task | None:
return self._aio_tasks.get(task_id, {}).get(operation_type, None)
async def wait_for_task(self, task_id: str, operation_type: str, timeout: float | None = None) -> None:
def _remove_aio_tasks(self, task_id: str) -> None:
    """Forget every aio task recorded for ``task_id``; a missing id is a no-op."""
    self._aio_tasks.pop(task_id, None)
async def wait_for_task(
self,
task_id: str,
operation_type: str,
timeout: float | None = 5,
) -> None:
running_task = self.get_aio_task(task_id=task_id, operation_type=operation_type)
if running_task and not running_task.done():
if running_task is None or not is_aio_task_running(running_task):
return
LOG.info(
"wait for the running aio task to be done",
task_id=task_id,
operation_type=operation_type,
)
try:
await asyncio.wait_for(running_task, timeout)
except TimeoutError:
LOG.info(
"wait for the running aio task to be done",
f"Timeout ({timeout}s) while waiting for the running aio task to be done",
task_id=task_id,
operation_type=operation_type,
)
await asyncio.wait_for(running_task, timeout)
def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
# get the operation from the pool
@@ -121,7 +140,7 @@ class AsyncOperationPool:
aio_task: asyncio.Task | None = None
if operation_type in self._aio_tasks[task_id]:
aio_task = self._aio_tasks[task_id][operation_type]
if not aio_task.done():
if is_aio_task_running(aio_task):
LOG.info(
"aio task already running",
task_id=task_id,
@@ -138,11 +157,15 @@ class AsyncOperationPool:
async def remove_task(self, task_id: str) -> None:
try:
async with asyncio.timeout(30):
await asyncio.gather(*[aio_task for aio_task in self.get_aio_tasks(task_id) if not aio_task.done()])
except asyncio.TimeoutError:
await asyncio.gather(
*[aio_task for aio_task in self.get_aio_tasks(task_id) if is_aio_task_running(aio_task)]
)
except TimeoutError:
LOG.error(
f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}",
task_id=task_id,
)
self.remove_operations(task_id)
self._remove_aio_tasks(task_id)
self._remove_operations(task_id)
LOG.info("Successfully removed aio tasks and async operations", task_id=task_id)

View File

@@ -0,0 +1,5 @@
import asyncio
def is_aio_task_running(aio_task: asyncio.Task) -> bool:
    """Tell whether ``aio_task`` is still in flight.

    Returns:
        ``True`` only when the task reports neither ``done()`` nor
        ``cancelled()``; ``False`` otherwise.
    """
    finished_or_cancelled = aio_task.done() or aio_task.cancelled()
    return not finished_or_cancelled