shu/task cancel agent logic (#521)
This commit is contained in:
@@ -15,13 +15,15 @@ class StepStatus(StrEnum):
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
canceled = "canceled"
|
||||
|
||||
def can_update_to(self, new_status: StepStatus) -> bool:
|
||||
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
|
||||
StepStatus.created: {StepStatus.running},
|
||||
StepStatus.running: {StepStatus.completed, StepStatus.failed},
|
||||
StepStatus.created: {StepStatus.running, StepStatus.canceled},
|
||||
StepStatus.running: {StepStatus.completed, StepStatus.failed, StepStatus.canceled},
|
||||
StepStatus.failed: set(),
|
||||
StepStatus.completed: set(),
|
||||
StepStatus.canceled: set(),
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
|
||||
@@ -281,6 +281,22 @@ async def get_task(
|
||||
)
|
||||
|
||||
|
||||
@base_router.post("/tasks/{task_id}/cancel")
|
||||
@base_router.post("/tasks/{task_id}/cancel/", include_in_schema=False)
|
||||
async def cancel_task(
|
||||
task_id: str,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> None:
|
||||
analytics.capture("skyvern-oss-agent-task-get")
|
||||
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
||||
if not task_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Task not found {task_id}",
|
||||
)
|
||||
await app.agent.update_task(task_obj, status=TaskStatus.canceled)
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/tasks/{task_id}/retry_webhook",
|
||||
tags=["agent"],
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
|
||||
US_CA = "US-CA"
|
||||
@@ -80,6 +82,7 @@ class TaskStatus(StrEnum):
|
||||
failed = "failed"
|
||||
terminated = "terminated"
|
||||
completed = "completed"
|
||||
canceled = "canceled"
|
||||
|
||||
def is_final(self) -> bool:
|
||||
return self in {
|
||||
@@ -87,6 +90,7 @@ class TaskStatus(StrEnum):
|
||||
TaskStatus.terminated,
|
||||
TaskStatus.completed,
|
||||
TaskStatus.timed_out,
|
||||
TaskStatus.canceled,
|
||||
}
|
||||
|
||||
def can_update_to(self, new_status: TaskStatus) -> bool:
|
||||
@@ -96,22 +100,26 @@ class TaskStatus(StrEnum):
|
||||
TaskStatus.running,
|
||||
TaskStatus.timed_out,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.canceled,
|
||||
},
|
||||
TaskStatus.queued: {
|
||||
TaskStatus.running,
|
||||
TaskStatus.timed_out,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.canceled,
|
||||
},
|
||||
TaskStatus.running: {
|
||||
TaskStatus.completed,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.terminated,
|
||||
TaskStatus.timed_out,
|
||||
TaskStatus.canceled,
|
||||
},
|
||||
TaskStatus.failed: set(),
|
||||
TaskStatus.terminated: set(),
|
||||
TaskStatus.completed: set(),
|
||||
TaskStatus.timed_out: set(),
|
||||
TaskStatus.canceled: {TaskStatus.completed},
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
@@ -175,7 +183,9 @@ class Task(TaskRequest):
|
||||
old_status = self.status
|
||||
|
||||
if not old_status.can_update_to(status):
|
||||
raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id}")
|
||||
if old_status == TaskStatus.canceled:
|
||||
raise TaskAlreadyCanceled(new_status=status, task_id=self.task_id)
|
||||
raise InvalidTaskStatusTransition(old_status=old_status, new_status=status, task_id=self.task_id)
|
||||
|
||||
if status.requires_failure_reason() and failure_reason is None:
|
||||
raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")
|
||||
|
||||
Reference in New Issue
Block a user