diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 81c73d05..90ff7ba1 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -350,3 +350,15 @@ class EmptySelect(SkyvernException): super().__init__( f"nothing is selected, try to select again. element_id={element_id}", ) + + +class TaskAlreadyCanceled(SkyvernHTTPException): + def __init__(self, new_status: str, task_id: str): + super().__init__( + f"Invalid task status transition to {new_status} for {task_id} because task is already canceled" + ) + + +class InvalidTaskStatusTransition(SkyvernHTTPException): + def __init__(self, old_status: str, new_status: str, task_id: str): + super().__init__(f"Invalid task status transition from {old_status} to {new_status} for {task_id}") diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 9ce678a0..efdae702 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -18,11 +18,13 @@ from skyvern.exceptions import ( FailedToNavigateToUrl, FailedToSendWebhook, FailedToTakeScreenshot, + InvalidTaskStatusTransition, InvalidWorkflowTaskURLState, MissingBrowserStatePage, SkyvernException, StepTerminationError, StepUnableToExecuteError, + TaskAlreadyCanceled, TaskNotFound, ) from skyvern.forge import app @@ -204,6 +206,24 @@ class ForgeAgent: # if a download happens during the step execution. complete_on_download: bool = False, ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: + refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id) + if refreshed_task: + task = refreshed_task + + if task.status == TaskStatus.canceled: + LOG.info( + "Task is canceled, stopping execution", + task_id=task.task_id, + ) + step = await self.update_step( + step, + status=StepStatus.canceled, + is_last=True, + ) + # We don't send task response for now as the task is canceled + # TODO: shall we send task response here? + return step, None, None + next_step: Step | None = None detailed_output: DetailedAgentStepOutput | None = None num_files_before = 0 @@ -348,14 +368,14 @@ class ForgeAgent: step_id=step.step_id, exc_info=True, ) - await self.fail_task(task, step, e.message) - await self.send_task_response( - task=task, - last_step=step, - api_key=api_key, - close_browser_on_completion=close_browser_on_completion, - skip_cleanup=True, - ) + is_task_marked_as_failed = await self.fail_task(task, step, e.message) + if is_task_marked_as_failed: + await self.send_task_response( + task=task, + last_step=step, + api_key=api_key, + close_browser_on_completion=close_browser_on_completion, + ) return step, detailed_output, None except FailedToSendWebhook: LOG.exception( @@ -376,15 +396,31 @@ class ForgeAgent: error_message=e.error_message, ) failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}" - await self.fail_task(task, step, failure_reason) - await self.send_task_response( - task=task, - last_step=step, - api_key=api_key, - close_browser_on_completion=close_browser_on_completion, - skip_artifacts=True, - ) + is_task_marked_as_failed = await self.fail_task(task, step, failure_reason) + if is_task_marked_as_failed: + await self.send_task_response( + task=task, + last_step=step, + api_key=api_key, + close_browser_on_completion=close_browser_on_completion, + skip_artifacts=True, + ) return step, detailed_output, next_step + except TaskAlreadyCanceled: + LOG.info( + "Task is already canceled, stopping execution", + task_id=task.task_id, + ) + # We don't send task response for now as the task is canceled + return step, detailed_output, None + except InvalidTaskStatusTransition: + LOG.warning( + "Invalid task status transition", + task_id=task.task_id, + step_id=step.step_id, + ) + # TODO: shall we send task response here? + return step, detailed_output, None except Exception as e: LOG.exception( "Got an unexpected exception in step, fail the task", @@ -396,29 +432,44 @@ class ForgeAgent: if isinstance(e, SkyvernException): failure_reason = f"unexpected SkyvernException({e.__class__.__name__})" - await self.fail_task(task, step, failure_reason) - await self.send_task_response( - task=task, - last_step=step, - api_key=api_key, - close_browser_on_completion=close_browser_on_completion, - ) + is_task_marked_as_failed = await self.fail_task(task, step, failure_reason) + if is_task_marked_as_failed: + await self.send_task_response( + task=task, + last_step=step, + api_key=api_key, + close_browser_on_completion=close_browser_on_completion, + ) return step, detailed_output, None - async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> None: + async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> bool: try: - if step is not None and step.status.can_update_to(StepStatus.failed): + if step is not None: await self.update_step( step=step, status=StepStatus.failed, ) - if task.status.can_update_to(TaskStatus.failed): - await self.update_task( - task, - status=TaskStatus.failed, - failure_reason=reason, - ) + await self.update_task( + task, + status=TaskStatus.failed, + failure_reason=reason, + ) + return True + except TaskAlreadyCanceled: + LOG.info( + "Task is already canceled. Can't fail the task.", + task_id=task.task_id, + step_id=step.step_id if step else "", + ) + return False + except InvalidTaskStatusTransition: + LOG.warning( + "Invalid task status transition while failing a task", + task_id=task.task_id, + step_id=step.step_id if step else "", + ) + return False except Exception: LOG.exception( "Failed to update status and failure reason in database. Task might going to be time_out", @@ -426,6 +477,7 @@ class ForgeAgent: step_id=step.step_id if step else "", reason=reason, ) + return True async def agent_step( self, @@ -1294,6 +1346,11 @@ class ForgeAgent: extracted_information: dict[str, Any] | list | str | None = None, failure_reason: str | None = None, ) -> Task: + # refresh task from db to get the latest status + task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id) + if task_from_db: + task = task_from_db + task.validate_update(status, extracted_information, failure_reason) updates: dict[str, Any] = {} if status is not None: diff --git a/skyvern/forge/sdk/models.py b/skyvern/forge/sdk/models.py index a7feca1e..4d17a93c 100644 --- a/skyvern/forge/sdk/models.py +++ b/skyvern/forge/sdk/models.py @@ -15,13 +15,15 @@ class StepStatus(StrEnum): running = "running" failed = "failed" completed = "completed" + canceled = "canceled" def can_update_to(self, new_status: StepStatus) -> bool: allowed_transitions: dict[StepStatus, set[StepStatus]] = { - StepStatus.created: {StepStatus.running}, - StepStatus.running: {StepStatus.completed, StepStatus.failed}, + StepStatus.created: {StepStatus.running, StepStatus.canceled}, + StepStatus.running: {StepStatus.completed, StepStatus.failed, StepStatus.canceled}, StepStatus.failed: set(), StepStatus.completed: set(), + StepStatus.canceled: set(), } return new_status in allowed_transitions[self] diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index e1816761..cab8126f 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -281,6 +281,22 @@ async def get_task( ) +@base_router.post("/tasks/{task_id}/cancel") +@base_router.post("/tasks/{task_id}/cancel/", include_in_schema=False) +async def cancel_task( + task_id: str, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> None: + analytics.capture("skyvern-oss-agent-task-get") + task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id) + if not task_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Task not found {task_id}", + ) + await app.agent.update_task(task_obj, status=TaskStatus.canceled) + + @base_router.post( "/tasks/{task_id}/retry_webhook", tags=["agent"], diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 6ecd7560..34c28111 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -6,6 +6,8 @@ from typing import Any from pydantic import BaseModel, Field +from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled + class ProxyLocation(StrEnum): US_CA = "US-CA" @@ -80,6 +82,7 @@ class TaskStatus(StrEnum): failed = "failed" terminated = "terminated" completed = "completed" + canceled = "canceled" def is_final(self) -> bool: return self in { @@ -87,6 +90,7 @@ class TaskStatus(StrEnum): TaskStatus.terminated, TaskStatus.completed, TaskStatus.timed_out, + TaskStatus.canceled, } def can_update_to(self, new_status: TaskStatus) -> bool: @@ -96,22 +100,26 @@ class TaskStatus(StrEnum): TaskStatus.running, TaskStatus.timed_out, TaskStatus.failed, + TaskStatus.canceled, }, TaskStatus.queued: { TaskStatus.running, TaskStatus.timed_out, TaskStatus.failed, + TaskStatus.canceled, }, TaskStatus.running: { TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated, TaskStatus.timed_out, + TaskStatus.canceled, }, TaskStatus.failed: set(), TaskStatus.terminated: set(), TaskStatus.completed: set(), TaskStatus.timed_out: set(), + TaskStatus.canceled: {TaskStatus.completed}, } return new_status in allowed_transitions[self] @@ -175,7 +183,9 @@ class Task(TaskRequest): old_status = self.status if not old_status.can_update_to(status): - raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id}") + if old_status == TaskStatus.canceled: + raise TaskAlreadyCanceled(new_status=status, task_id=self.task_id) + raise InvalidTaskStatusTransition(old_status=old_status, new_status=status, task_id=self.task_id) if status.requires_failure_reason() and failure_reason is None: raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")