shu/task cancel agent logic (#521)
This commit is contained in:
@@ -350,3 +350,15 @@ class EmptySelect(SkyvernException):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
f"nothing is selected, try to select again. element_id={element_id}",
|
f"nothing is selected, try to select again. element_id={element_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskAlreadyCanceled(SkyvernHTTPException):
    """Signal that a status update was attempted on a task that is already canceled.

    Args:
        new_status: The status the caller tried to move the task to.
        task_id: Identifier of the canceled task.
    """

    def __init__(self, new_status: str, task_id: str):
        # Build the message first so the super() call stays on one short line.
        message = (
            f"Invalid task status transition to {new_status} for {task_id} because task is already canceled"
        )
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTaskStatusTransition(SkyvernHTTPException):
    """Signal that a task cannot move from its current status to the requested one.

    Args:
        old_status: The status the task currently has.
        new_status: The status the caller tried to move the task to.
        task_id: Identifier of the task whose update was rejected.
    """

    def __init__(self, old_status: str, new_status: str, task_id: str):
        message = f"Invalid task status transition from {old_status} to {new_status} for {task_id}"
        super().__init__(message)
|
||||||
|
|||||||
@@ -18,11 +18,13 @@ from skyvern.exceptions import (
|
|||||||
FailedToNavigateToUrl,
|
FailedToNavigateToUrl,
|
||||||
FailedToSendWebhook,
|
FailedToSendWebhook,
|
||||||
FailedToTakeScreenshot,
|
FailedToTakeScreenshot,
|
||||||
|
InvalidTaskStatusTransition,
|
||||||
InvalidWorkflowTaskURLState,
|
InvalidWorkflowTaskURLState,
|
||||||
MissingBrowserStatePage,
|
MissingBrowserStatePage,
|
||||||
SkyvernException,
|
SkyvernException,
|
||||||
StepTerminationError,
|
StepTerminationError,
|
||||||
StepUnableToExecuteError,
|
StepUnableToExecuteError,
|
||||||
|
TaskAlreadyCanceled,
|
||||||
TaskNotFound,
|
TaskNotFound,
|
||||||
)
|
)
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
@@ -204,6 +206,24 @@ class ForgeAgent:
|
|||||||
# if a download happens during the step execution.
|
# if a download happens during the step execution.
|
||||||
complete_on_download: bool = False,
|
complete_on_download: bool = False,
|
||||||
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
||||||
|
refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id)
|
||||||
|
if refreshed_task:
|
||||||
|
task = refreshed_task
|
||||||
|
|
||||||
|
if task.status == TaskStatus.canceled:
|
||||||
|
LOG.info(
|
||||||
|
"Task is canceled, stopping execution",
|
||||||
|
task_id=task.task_id,
|
||||||
|
)
|
||||||
|
step = await self.update_step(
|
||||||
|
step,
|
||||||
|
status=StepStatus.canceled,
|
||||||
|
is_last=True,
|
||||||
|
)
|
||||||
|
# We don't send task response for now as the task is canceled
|
||||||
|
# TODO: shall we send task response here?
|
||||||
|
return step, None, None
|
||||||
|
|
||||||
next_step: Step | None = None
|
next_step: Step | None = None
|
||||||
detailed_output: DetailedAgentStepOutput | None = None
|
detailed_output: DetailedAgentStepOutput | None = None
|
||||||
num_files_before = 0
|
num_files_before = 0
|
||||||
@@ -348,14 +368,14 @@ class ForgeAgent:
|
|||||||
step_id=step.step_id,
|
step_id=step.step_id,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
await self.fail_task(task, step, e.message)
|
is_task_marked_as_failed = await self.fail_task(task, step, e.message)
|
||||||
await self.send_task_response(
|
if is_task_marked_as_failed:
|
||||||
task=task,
|
await self.send_task_response(
|
||||||
last_step=step,
|
task=task,
|
||||||
api_key=api_key,
|
last_step=step,
|
||||||
close_browser_on_completion=close_browser_on_completion,
|
api_key=api_key,
|
||||||
skip_cleanup=True,
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
)
|
)
|
||||||
return step, detailed_output, None
|
return step, detailed_output, None
|
||||||
except FailedToSendWebhook:
|
except FailedToSendWebhook:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
@@ -376,15 +396,31 @@ class ForgeAgent:
|
|||||||
error_message=e.error_message,
|
error_message=e.error_message,
|
||||||
)
|
)
|
||||||
failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}"
|
failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}"
|
||||||
await self.fail_task(task, step, failure_reason)
|
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
|
||||||
await self.send_task_response(
|
if is_task_marked_as_failed:
|
||||||
task=task,
|
await self.send_task_response(
|
||||||
last_step=step,
|
task=task,
|
||||||
api_key=api_key,
|
last_step=step,
|
||||||
close_browser_on_completion=close_browser_on_completion,
|
api_key=api_key,
|
||||||
skip_artifacts=True,
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
)
|
skip_artifacts=True,
|
||||||
|
)
|
||||||
return step, detailed_output, next_step
|
return step, detailed_output, next_step
|
||||||
|
except TaskAlreadyCanceled:
|
||||||
|
LOG.info(
|
||||||
|
"Task is already canceled, stopping execution",
|
||||||
|
task_id=task.task_id,
|
||||||
|
)
|
||||||
|
# We don't send task response for now as the task is canceled
|
||||||
|
return step, detailed_output, None
|
||||||
|
except InvalidTaskStatusTransition:
|
||||||
|
LOG.warning(
|
||||||
|
"Invalid task status transition",
|
||||||
|
task_id=task.task_id,
|
||||||
|
step_id=step.step_id,
|
||||||
|
)
|
||||||
|
# TODO: shall we send task response here?
|
||||||
|
return step, detailed_output, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"Got an unexpected exception in step, fail the task",
|
"Got an unexpected exception in step, fail the task",
|
||||||
@@ -396,29 +432,44 @@ class ForgeAgent:
|
|||||||
if isinstance(e, SkyvernException):
|
if isinstance(e, SkyvernException):
|
||||||
failure_reason = f"unexpected SkyvernException({e.__class__.__name__})"
|
failure_reason = f"unexpected SkyvernException({e.__class__.__name__})"
|
||||||
|
|
||||||
await self.fail_task(task, step, failure_reason)
|
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
|
||||||
await self.send_task_response(
|
if is_task_marked_as_failed:
|
||||||
task=task,
|
await self.send_task_response(
|
||||||
last_step=step,
|
task=task,
|
||||||
api_key=api_key,
|
last_step=step,
|
||||||
close_browser_on_completion=close_browser_on_completion,
|
api_key=api_key,
|
||||||
)
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
|
)
|
||||||
return step, detailed_output, None
|
return step, detailed_output, None
|
||||||
|
|
||||||
async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> None:
|
async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> bool:
|
||||||
try:
|
try:
|
||||||
if step is not None and step.status.can_update_to(StepStatus.failed):
|
if step is not None:
|
||||||
await self.update_step(
|
await self.update_step(
|
||||||
step=step,
|
step=step,
|
||||||
status=StepStatus.failed,
|
status=StepStatus.failed,
|
||||||
)
|
)
|
||||||
|
|
||||||
if task.status.can_update_to(TaskStatus.failed):
|
await self.update_task(
|
||||||
await self.update_task(
|
task,
|
||||||
task,
|
status=TaskStatus.failed,
|
||||||
status=TaskStatus.failed,
|
failure_reason=reason,
|
||||||
failure_reason=reason,
|
)
|
||||||
)
|
return True
|
||||||
|
except TaskAlreadyCanceled:
|
||||||
|
LOG.info(
|
||||||
|
"Task is already canceled. Can't fail the task.",
|
||||||
|
task_id=task.task_id,
|
||||||
|
step_id=step.step_id if step else "",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except InvalidTaskStatusTransition:
|
||||||
|
LOG.warning(
|
||||||
|
"Invalid task status transition while failing a task",
|
||||||
|
task_id=task.task_id,
|
||||||
|
step_id=step.step_id if step else "",
|
||||||
|
)
|
||||||
|
return False
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"Failed to update status and failure reason in database. Task might going to be time_out",
|
"Failed to update status and failure reason in database. Task might going to be time_out",
|
||||||
@@ -426,6 +477,7 @@ class ForgeAgent:
|
|||||||
step_id=step.step_id if step else "",
|
step_id=step.step_id if step else "",
|
||||||
reason=reason,
|
reason=reason,
|
||||||
)
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
async def agent_step(
|
async def agent_step(
|
||||||
self,
|
self,
|
||||||
@@ -1294,6 +1346,11 @@ class ForgeAgent:
|
|||||||
extracted_information: dict[str, Any] | list | str | None = None,
|
extracted_information: dict[str, Any] | list | str | None = None,
|
||||||
failure_reason: str | None = None,
|
failure_reason: str | None = None,
|
||||||
) -> Task:
|
) -> Task:
|
||||||
|
# refresh task from db to get the latest status
|
||||||
|
task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
|
||||||
|
if task_from_db:
|
||||||
|
task = task_from_db
|
||||||
|
|
||||||
task.validate_update(status, extracted_information, failure_reason)
|
task.validate_update(status, extracted_information, failure_reason)
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if status is not None:
|
if status is not None:
|
||||||
|
|||||||
@@ -15,13 +15,15 @@ class StepStatus(StrEnum):
|
|||||||
running = "running"
|
running = "running"
|
||||||
failed = "failed"
|
failed = "failed"
|
||||||
completed = "completed"
|
completed = "completed"
|
||||||
|
canceled = "canceled"
|
||||||
|
|
||||||
def can_update_to(self, new_status: StepStatus) -> bool:
|
def can_update_to(self, new_status: StepStatus) -> bool:
|
||||||
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
|
allowed_transitions: dict[StepStatus, set[StepStatus]] = {
|
||||||
StepStatus.created: {StepStatus.running},
|
StepStatus.created: {StepStatus.running, StepStatus.canceled},
|
||||||
StepStatus.running: {StepStatus.completed, StepStatus.failed},
|
StepStatus.running: {StepStatus.completed, StepStatus.failed, StepStatus.canceled},
|
||||||
StepStatus.failed: set(),
|
StepStatus.failed: set(),
|
||||||
StepStatus.completed: set(),
|
StepStatus.completed: set(),
|
||||||
|
StepStatus.canceled: set(),
|
||||||
}
|
}
|
||||||
return new_status in allowed_transitions[self]
|
return new_status in allowed_transitions[self]
|
||||||
|
|
||||||
|
|||||||
@@ -281,6 +281,22 @@ async def get_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.post("/tasks/{task_id}/cancel")
@base_router.post("/tasks/{task_id}/cancel/", include_in_schema=False)
async def cancel_task(
    task_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> None:
    """Mark a task belonging to the caller's organization as canceled.

    Args:
        task_id: Identifier of the task to cancel (taken from the URL path).
        current_org: Organization resolved from the request's credentials;
            the task lookup is scoped to this organization.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists for the
            organization.

    NOTE(review): ``app.agent.update_task`` presumably validates the status
    transition and may raise ``TaskAlreadyCanceled`` or
    ``InvalidTaskStatusTransition`` for tasks that are already in a final
    state; those are not handled here and propagate to the caller - confirm
    that this is the intended HTTP behavior.
    """
    # Use a dedicated event name: the previous value ("...-task-get") was
    # shared with the task read endpoint, so cancellations were
    # indistinguishable from reads in analytics.
    analytics.capture("skyvern-oss-agent-task-cancel")

    task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    if not task_obj:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task not found {task_id}",
        )

    await app.agent.update_task(task_obj, status=TaskStatus.canceled)
|
||||||
|
|
||||||
|
|
||||||
@base_router.post(
|
@base_router.post(
|
||||||
"/tasks/{task_id}/retry_webhook",
|
"/tasks/{task_id}/retry_webhook",
|
||||||
tags=["agent"],
|
tags=["agent"],
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled
|
||||||
|
|
||||||
|
|
||||||
class ProxyLocation(StrEnum):
|
class ProxyLocation(StrEnum):
|
||||||
US_CA = "US-CA"
|
US_CA = "US-CA"
|
||||||
@@ -80,6 +82,7 @@ class TaskStatus(StrEnum):
|
|||||||
failed = "failed"
|
failed = "failed"
|
||||||
terminated = "terminated"
|
terminated = "terminated"
|
||||||
completed = "completed"
|
completed = "completed"
|
||||||
|
canceled = "canceled"
|
||||||
|
|
||||||
def is_final(self) -> bool:
|
def is_final(self) -> bool:
|
||||||
return self in {
|
return self in {
|
||||||
@@ -87,6 +90,7 @@ class TaskStatus(StrEnum):
|
|||||||
TaskStatus.terminated,
|
TaskStatus.terminated,
|
||||||
TaskStatus.completed,
|
TaskStatus.completed,
|
||||||
TaskStatus.timed_out,
|
TaskStatus.timed_out,
|
||||||
|
TaskStatus.canceled,
|
||||||
}
|
}
|
||||||
|
|
||||||
def can_update_to(self, new_status: TaskStatus) -> bool:
|
def can_update_to(self, new_status: TaskStatus) -> bool:
|
||||||
@@ -96,22 +100,26 @@ class TaskStatus(StrEnum):
|
|||||||
TaskStatus.running,
|
TaskStatus.running,
|
||||||
TaskStatus.timed_out,
|
TaskStatus.timed_out,
|
||||||
TaskStatus.failed,
|
TaskStatus.failed,
|
||||||
|
TaskStatus.canceled,
|
||||||
},
|
},
|
||||||
TaskStatus.queued: {
|
TaskStatus.queued: {
|
||||||
TaskStatus.running,
|
TaskStatus.running,
|
||||||
TaskStatus.timed_out,
|
TaskStatus.timed_out,
|
||||||
TaskStatus.failed,
|
TaskStatus.failed,
|
||||||
|
TaskStatus.canceled,
|
||||||
},
|
},
|
||||||
TaskStatus.running: {
|
TaskStatus.running: {
|
||||||
TaskStatus.completed,
|
TaskStatus.completed,
|
||||||
TaskStatus.failed,
|
TaskStatus.failed,
|
||||||
TaskStatus.terminated,
|
TaskStatus.terminated,
|
||||||
TaskStatus.timed_out,
|
TaskStatus.timed_out,
|
||||||
|
TaskStatus.canceled,
|
||||||
},
|
},
|
||||||
TaskStatus.failed: set(),
|
TaskStatus.failed: set(),
|
||||||
TaskStatus.terminated: set(),
|
TaskStatus.terminated: set(),
|
||||||
TaskStatus.completed: set(),
|
TaskStatus.completed: set(),
|
||||||
TaskStatus.timed_out: set(),
|
TaskStatus.timed_out: set(),
|
||||||
|
TaskStatus.canceled: {TaskStatus.completed},
|
||||||
}
|
}
|
||||||
return new_status in allowed_transitions[self]
|
return new_status in allowed_transitions[self]
|
||||||
|
|
||||||
@@ -175,7 +183,9 @@ class Task(TaskRequest):
|
|||||||
old_status = self.status
|
old_status = self.status
|
||||||
|
|
||||||
if not old_status.can_update_to(status):
|
if not old_status.can_update_to(status):
|
||||||
raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id}")
|
if old_status == TaskStatus.canceled:
|
||||||
|
raise TaskAlreadyCanceled(new_status=status, task_id=self.task_id)
|
||||||
|
raise InvalidTaskStatusTransition(old_status=old_status, new_status=status, task_id=self.task_id)
|
||||||
|
|
||||||
if status.requires_failure_reason() and failure_reason is None:
|
if status.requires_failure_reason() and failure_reason is None:
|
||||||
raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")
|
raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")
|
||||||
|
|||||||
Reference in New Issue
Block a user