shu/task cancel agent logic (#521)

This commit is contained in:
Kerem Yilmaz
2024-06-26 15:25:15 -07:00
committed by GitHub
parent 8155be9ff2
commit bf81b7df53
5 changed files with 131 additions and 34 deletions

View File

@@ -350,3 +350,15 @@ class EmptySelect(SkyvernException):
super().__init__( super().__init__(
f"nothing is selected, try to select again. element_id={element_id}", f"nothing is selected, try to select again. element_id={element_id}",
) )
class TaskAlreadyCanceled(SkyvernHTTPException):
    """Error for a status change requested on a task that is already canceled.

    Args:
        new_status: The status the caller attempted to move the task to.
        task_id: Identifier of the task whose status is already canceled.
    """

    def __init__(self, new_status: str, task_id: str):
        # Build the message first so the base-class call stays on one line.
        message = (
            f"Invalid task status transition to {new_status} for {task_id} because task is already canceled"
        )
        super().__init__(message)
class InvalidTaskStatusTransition(SkyvernHTTPException):
    """Error for a task status change that is not an allowed transition.

    Args:
        old_status: The task's current status.
        new_status: The status the caller attempted to move the task to.
        task_id: Identifier of the task being updated.
    """

    def __init__(self, old_status: str, new_status: str, task_id: str):
        message = f"Invalid task status transition from {old_status} to {new_status} for {task_id}"
        super().__init__(message)

View File

@@ -18,11 +18,13 @@ from skyvern.exceptions import (
FailedToNavigateToUrl, FailedToNavigateToUrl,
FailedToSendWebhook, FailedToSendWebhook,
FailedToTakeScreenshot, FailedToTakeScreenshot,
InvalidTaskStatusTransition,
InvalidWorkflowTaskURLState, InvalidWorkflowTaskURLState,
MissingBrowserStatePage, MissingBrowserStatePage,
SkyvernException, SkyvernException,
StepTerminationError, StepTerminationError,
StepUnableToExecuteError, StepUnableToExecuteError,
TaskAlreadyCanceled,
TaskNotFound, TaskNotFound,
) )
from skyvern.forge import app from skyvern.forge import app
@@ -204,6 +206,24 @@ class ForgeAgent:
# if a download happens during the step execution. # if a download happens during the step execution.
complete_on_download: bool = False, complete_on_download: bool = False,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id)
if refreshed_task:
task = refreshed_task
if task.status == TaskStatus.canceled:
LOG.info(
"Task is canceled, stopping execution",
task_id=task.task_id,
)
step = await self.update_step(
step,
status=StepStatus.canceled,
is_last=True,
)
# We don't send task response for now as the task is canceled
# TODO: shall we send task response here?
return step, None, None
next_step: Step | None = None next_step: Step | None = None
detailed_output: DetailedAgentStepOutput | None = None detailed_output: DetailedAgentStepOutput | None = None
num_files_before = 0 num_files_before = 0
@@ -348,14 +368,14 @@ class ForgeAgent:
step_id=step.step_id, step_id=step.step_id,
exc_info=True, exc_info=True,
) )
await self.fail_task(task, step, e.message) is_task_marked_as_failed = await self.fail_task(task, step, e.message)
await self.send_task_response( if is_task_marked_as_failed:
task=task, await self.send_task_response(
last_step=step, task=task,
api_key=api_key, last_step=step,
close_browser_on_completion=close_browser_on_completion, api_key=api_key,
skip_cleanup=True, close_browser_on_completion=close_browser_on_completion,
) )
return step, detailed_output, None return step, detailed_output, None
except FailedToSendWebhook: except FailedToSendWebhook:
LOG.exception( LOG.exception(
@@ -376,15 +396,31 @@ class ForgeAgent:
error_message=e.error_message, error_message=e.error_message,
) )
failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}" failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}"
await self.fail_task(task, step, failure_reason) is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
await self.send_task_response( if is_task_marked_as_failed:
task=task, await self.send_task_response(
last_step=step, task=task,
api_key=api_key, last_step=step,
close_browser_on_completion=close_browser_on_completion, api_key=api_key,
skip_artifacts=True, close_browser_on_completion=close_browser_on_completion,
) skip_artifacts=True,
)
return step, detailed_output, next_step return step, detailed_output, next_step
except TaskAlreadyCanceled:
LOG.info(
"Task is already canceled, stopping execution",
task_id=task.task_id,
)
# We don't send task response for now as the task is canceled
return step, detailed_output, None
except InvalidTaskStatusTransition:
LOG.warning(
"Invalid task status transition",
task_id=task.task_id,
step_id=step.step_id,
)
# TODO: shall we send task response here?
return step, detailed_output, None
except Exception as e: except Exception as e:
LOG.exception( LOG.exception(
"Got an unexpected exception in step, fail the task", "Got an unexpected exception in step, fail the task",
@@ -396,29 +432,44 @@ class ForgeAgent:
if isinstance(e, SkyvernException): if isinstance(e, SkyvernException):
failure_reason = f"unexpected SkyvernException({e.__class__.__name__})" failure_reason = f"unexpected SkyvernException({e.__class__.__name__})"
await self.fail_task(task, step, failure_reason) is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
await self.send_task_response( if is_task_marked_as_failed:
task=task, await self.send_task_response(
last_step=step, task=task,
api_key=api_key, last_step=step,
close_browser_on_completion=close_browser_on_completion, api_key=api_key,
) close_browser_on_completion=close_browser_on_completion,
)
return step, detailed_output, None return step, detailed_output, None
async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> None: async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> bool:
try: try:
if step is not None and step.status.can_update_to(StepStatus.failed): if step is not None:
await self.update_step( await self.update_step(
step=step, step=step,
status=StepStatus.failed, status=StepStatus.failed,
) )
if task.status.can_update_to(TaskStatus.failed): await self.update_task(
await self.update_task( task,
task, status=TaskStatus.failed,
status=TaskStatus.failed, failure_reason=reason,
failure_reason=reason, )
) return True
except TaskAlreadyCanceled:
LOG.info(
"Task is already canceled. Can't fail the task.",
task_id=task.task_id,
step_id=step.step_id if step else "",
)
return False
except InvalidTaskStatusTransition:
LOG.warning(
"Invalid task status transition while failing a task",
task_id=task.task_id,
step_id=step.step_id if step else "",
)
return False
except Exception: except Exception:
LOG.exception( LOG.exception(
"Failed to update status and failure reason in database. Task might going to be time_out", "Failed to update status and failure reason in database. Task might going to be time_out",
@@ -426,6 +477,7 @@ class ForgeAgent:
step_id=step.step_id if step else "", step_id=step.step_id if step else "",
reason=reason, reason=reason,
) )
return True
async def agent_step( async def agent_step(
self, self,
@@ -1294,6 +1346,11 @@ class ForgeAgent:
extracted_information: dict[str, Any] | list | str | None = None, extracted_information: dict[str, Any] | list | str | None = None,
failure_reason: str | None = None, failure_reason: str | None = None,
) -> Task: ) -> Task:
# refresh task from db to get the latest status
task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
if task_from_db:
task = task_from_db
task.validate_update(status, extracted_information, failure_reason) task.validate_update(status, extracted_information, failure_reason)
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if status is not None: if status is not None:

View File

@@ -15,13 +15,15 @@ class StepStatus(StrEnum):
running = "running" running = "running"
failed = "failed" failed = "failed"
completed = "completed" completed = "completed"
canceled = "canceled"
def can_update_to(self, new_status: StepStatus) -> bool: def can_update_to(self, new_status: StepStatus) -> bool:
allowed_transitions: dict[StepStatus, set[StepStatus]] = { allowed_transitions: dict[StepStatus, set[StepStatus]] = {
StepStatus.created: {StepStatus.running}, StepStatus.created: {StepStatus.running, StepStatus.canceled},
StepStatus.running: {StepStatus.completed, StepStatus.failed}, StepStatus.running: {StepStatus.completed, StepStatus.failed, StepStatus.canceled},
StepStatus.failed: set(), StepStatus.failed: set(),
StepStatus.completed: set(), StepStatus.completed: set(),
StepStatus.canceled: set(),
} }
return new_status in allowed_transitions[self] return new_status in allowed_transitions[self]

View File

@@ -281,6 +281,22 @@ async def get_task(
) )
@base_router.post("/tasks/{task_id}/cancel")
@base_router.post("/tasks/{task_id}/cancel/", include_in_schema=False)
async def cancel_task(
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> None:
    """Mark a task belonging to the caller's organization as canceled.

    Args:
        task_id: Identifier of the task to cancel.
        current_org: Organization resolved by the auth dependency; the task
            lookup is scoped to this organization.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists for the
            organization.
    """
    # Emit a cancel-specific analytics event so cancel calls are not counted
    # as task reads (the "task-get" event name belongs to the get endpoint).
    analytics.capture("skyvern-oss-agent-task-cancel")

    task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task_obj:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )

    # update_task validates the status transition before persisting it, so a
    # task in a state that cannot move to canceled raises from this call.
    await app.agent.update_task(task_obj, status=TaskStatus.canceled)
@base_router.post( @base_router.post(
"/tasks/{task_id}/retry_webhook", "/tasks/{task_id}/retry_webhook",
tags=["agent"], tags=["agent"],

View File

@@ -6,6 +6,8 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled
class ProxyLocation(StrEnum): class ProxyLocation(StrEnum):
US_CA = "US-CA" US_CA = "US-CA"
@@ -80,6 +82,7 @@ class TaskStatus(StrEnum):
failed = "failed" failed = "failed"
terminated = "terminated" terminated = "terminated"
completed = "completed" completed = "completed"
canceled = "canceled"
def is_final(self) -> bool: def is_final(self) -> bool:
return self in { return self in {
@@ -87,6 +90,7 @@ class TaskStatus(StrEnum):
TaskStatus.terminated, TaskStatus.terminated,
TaskStatus.completed, TaskStatus.completed,
TaskStatus.timed_out, TaskStatus.timed_out,
TaskStatus.canceled,
} }
def can_update_to(self, new_status: TaskStatus) -> bool: def can_update_to(self, new_status: TaskStatus) -> bool:
@@ -96,22 +100,26 @@ class TaskStatus(StrEnum):
TaskStatus.running, TaskStatus.running,
TaskStatus.timed_out, TaskStatus.timed_out,
TaskStatus.failed, TaskStatus.failed,
TaskStatus.canceled,
}, },
TaskStatus.queued: { TaskStatus.queued: {
TaskStatus.running, TaskStatus.running,
TaskStatus.timed_out, TaskStatus.timed_out,
TaskStatus.failed, TaskStatus.failed,
TaskStatus.canceled,
}, },
TaskStatus.running: { TaskStatus.running: {
TaskStatus.completed, TaskStatus.completed,
TaskStatus.failed, TaskStatus.failed,
TaskStatus.terminated, TaskStatus.terminated,
TaskStatus.timed_out, TaskStatus.timed_out,
TaskStatus.canceled,
}, },
TaskStatus.failed: set(), TaskStatus.failed: set(),
TaskStatus.terminated: set(), TaskStatus.terminated: set(),
TaskStatus.completed: set(), TaskStatus.completed: set(),
TaskStatus.timed_out: set(), TaskStatus.timed_out: set(),
TaskStatus.canceled: {TaskStatus.completed},
} }
return new_status in allowed_transitions[self] return new_status in allowed_transitions[self]
@@ -175,7 +183,9 @@ class Task(TaskRequest):
old_status = self.status old_status = self.status
if not old_status.can_update_to(status): if not old_status.can_update_to(status):
raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id}") if old_status == TaskStatus.canceled:
raise TaskAlreadyCanceled(new_status=status, task_id=self.task_id)
raise InvalidTaskStatusTransition(old_status=old_status, new_status=status, task_id=self.task_id)
if status.requires_failure_reason() and failure_reason is None: if status.requires_failure_reason() and failure_reason is None:
raise ValueError(f"status_requires_failure_reason({status},{self.task_id}") raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")