add default timeout to wait_for_task (#1122)
This commit is contained in:
@@ -337,7 +337,6 @@ class ForgeAgent:
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
close_browser_on_completion=close_browser_on_completion,
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
)
|
)
|
||||||
await self.async_operation_pool.remove_task(task.task_id)
|
|
||||||
return step, detailed_output, None
|
return step, detailed_output, None
|
||||||
elif step.status == StepStatus.completed:
|
elif step.status == StepStatus.completed:
|
||||||
# TODO (kerem): keep the task object uptodate at all times so that clean_up_task can just use it
|
# TODO (kerem): keep the task object uptodate at all times so that clean_up_task can just use it
|
||||||
@@ -1402,6 +1401,7 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
await self.async_operation_pool.remove_task(task.task_id)
|
||||||
await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task)
|
await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task)
|
||||||
|
|
||||||
# Wait for all tasks to complete before generating the links for the artifacts
|
# Wait for all tasks to complete before generating the links for the artifacts
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from enum import StrEnum
|
|||||||
import structlog
|
import structlog
|
||||||
from playwright.async_api import Page
|
from playwright.async_api import Page
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.core.asyncio_helper import is_aio_task_running
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -48,7 +50,7 @@ class AsyncOperation:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def run(self) -> asyncio.Task | None:
|
def run(self) -> asyncio.Task | None:
|
||||||
if self.aio_task is not None and not self.aio_task.done():
|
if self.aio_task is not None and is_aio_task_running(self.aio_task):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Task already running",
|
"Task already running",
|
||||||
task_id=self.task_id,
|
task_id=self.task_id,
|
||||||
@@ -80,10 +82,10 @@ class AsyncOperationPool:
|
|||||||
for operation in operations:
|
for operation in operations:
|
||||||
self._add_operation(task_id, operation)
|
self._add_operation(task_id, operation)
|
||||||
|
|
||||||
def _get_operation(self, task_id: str, operation_type: AgentPhase) -> AsyncOperation | None:
|
def _get_operation(self, task_id: str, agent_phase: AgentPhase) -> AsyncOperation | None:
|
||||||
return self._operations.get(task_id, {}).get(operation_type, None)
|
return self._operations.get(task_id, {}).get(agent_phase, None)
|
||||||
|
|
||||||
def remove_operations(self, task_id: str) -> None:
|
def _remove_operations(self, task_id: str) -> None:
|
||||||
if task_id in self._operations:
|
if task_id in self._operations:
|
||||||
del self._operations[task_id]
|
del self._operations[task_id]
|
||||||
|
|
||||||
@@ -91,20 +93,37 @@ class AsyncOperationPool:
|
|||||||
"""
|
"""
|
||||||
Get all the running/pending aio tasks for the given task_id
|
Get all the running/pending aio tasks for the given task_id
|
||||||
"""
|
"""
|
||||||
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()]
|
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if is_aio_task_running(aio_task)]
|
||||||
|
|
||||||
def get_aio_task(self, task_id: str, operation_type: str) -> asyncio.Task | None:
|
def get_aio_task(self, task_id: str, operation_type: str) -> asyncio.Task | None:
|
||||||
return self._aio_tasks.get(task_id, {}).get(operation_type, None)
|
return self._aio_tasks.get(task_id, {}).get(operation_type, None)
|
||||||
|
|
||||||
async def wait_for_task(self, task_id: str, operation_type: str, timeout: float | None = None) -> None:
|
def _remove_aio_tasks(self, task_id: str) -> None:
|
||||||
|
if task_id in self._aio_tasks:
|
||||||
|
del self._aio_tasks[task_id]
|
||||||
|
|
||||||
|
async def wait_for_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
operation_type: str,
|
||||||
|
timeout: float | None = 5,
|
||||||
|
) -> None:
|
||||||
running_task = self.get_aio_task(task_id=task_id, operation_type=operation_type)
|
running_task = self.get_aio_task(task_id=task_id, operation_type=operation_type)
|
||||||
if running_task and not running_task.done():
|
if running_task is None or not is_aio_task_running(running_task):
|
||||||
|
return
|
||||||
|
LOG.info(
|
||||||
|
"wait for the running aio task to be done",
|
||||||
|
task_id=task_id,
|
||||||
|
operation_type=operation_type,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(running_task, timeout)
|
||||||
|
except TimeoutError:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"wait for the running aio task to be done",
|
f"Timeout ({timeout}s) while waiting for the running aio task to be done",
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
operation_type=operation_type,
|
operation_type=operation_type,
|
||||||
)
|
)
|
||||||
await asyncio.wait_for(running_task, timeout)
|
|
||||||
|
|
||||||
def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
|
def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
|
||||||
# get the operation from the pool
|
# get the operation from the pool
|
||||||
@@ -121,7 +140,7 @@ class AsyncOperationPool:
|
|||||||
aio_task: asyncio.Task | None = None
|
aio_task: asyncio.Task | None = None
|
||||||
if operation_type in self._aio_tasks[task_id]:
|
if operation_type in self._aio_tasks[task_id]:
|
||||||
aio_task = self._aio_tasks[task_id][operation_type]
|
aio_task = self._aio_tasks[task_id][operation_type]
|
||||||
if not aio_task.done():
|
if is_aio_task_running(aio_task):
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"aio task already running",
|
"aio task already running",
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@@ -138,11 +157,15 @@ class AsyncOperationPool:
|
|||||||
async def remove_task(self, task_id: str) -> None:
|
async def remove_task(self, task_id: str) -> None:
|
||||||
try:
|
try:
|
||||||
async with asyncio.timeout(30):
|
async with asyncio.timeout(30):
|
||||||
await asyncio.gather(*[aio_task for aio_task in self.get_aio_tasks(task_id) if not aio_task.done()])
|
await asyncio.gather(
|
||||||
except asyncio.TimeoutError:
|
*[aio_task for aio_task in self.get_aio_tasks(task_id) if is_aio_task_running(aio_task)]
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}",
|
f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}",
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.remove_operations(task_id)
|
self._remove_aio_tasks(task_id)
|
||||||
|
self._remove_operations(task_id)
|
||||||
|
LOG.info("Successfully removed aio tasks and async operations", task_id=task_id)
|
||||||
|
|||||||
5
skyvern/forge/sdk/core/asyncio_helper.py
Normal file
5
skyvern/forge/sdk/core/asyncio_helper.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def is_aio_task_running(aio_task: asyncio.Task) -> bool:
|
||||||
|
return not aio_task.done() and not aio_task.cancelled()
|
||||||
Reference in New Issue
Block a user