Move the code over from private repository (#3)
This commit is contained in:
181
skyvern/forge/sdk/schemas/tasks.py
Normal file
181
skyvern/forge/sdk/schemas/tasks.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
|
||||
US_CA = "US-CA"
|
||||
US_NY = "US-NY"
|
||||
US_TX = "US-TX"
|
||||
US_FL = "US-FL"
|
||||
US_WA = "US-WA"
|
||||
RESIDENTIAL = "RESIDENTIAL"
|
||||
NONE = "NONE"
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
url: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Starting URL for the task.",
|
||||
examples=["https://www.geico.com"],
|
||||
)
|
||||
# TODO: use HttpUrl instead of str
|
||||
webhook_callback_url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL to call when the task is completed.",
|
||||
examples=["https://my-webhook.com"],
|
||||
)
|
||||
navigation_goal: str | None = Field(
|
||||
default=None,
|
||||
description="The user's goal for the task.",
|
||||
examples=["Get a quote for car insurance"],
|
||||
)
|
||||
data_extraction_goal: str | None = Field(
|
||||
default=None,
|
||||
description="The user's goal for data extraction.",
|
||||
examples=["Extract the quote price"],
|
||||
)
|
||||
navigation_payload: dict[str, Any] | list | str | None = Field(
|
||||
None,
|
||||
description="The user's details needed to achieve the task.",
|
||||
examples=[{"name": "John Doe", "email": "john@doe.com"}],
|
||||
)
|
||||
proxy_location: ProxyLocation | None = Field(
|
||||
None,
|
||||
description="The location of the proxy to use for the task.",
|
||||
examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
|
||||
)
|
||||
extracted_information_schema: dict[str, Any] | list | str | None = Field(
|
||||
None,
|
||||
description="The requested schema of the extracted information.",
|
||||
)
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
terminated = "terminated"
|
||||
completed = "completed"
|
||||
|
||||
def is_final(self) -> bool:
|
||||
return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed}
|
||||
|
||||
def can_update_to(self, new_status: TaskStatus) -> bool:
|
||||
allowed_transitions: dict[TaskStatus, set[TaskStatus]] = {
|
||||
TaskStatus.created: {TaskStatus.running},
|
||||
TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated},
|
||||
TaskStatus.failed: set(),
|
||||
TaskStatus.completed: set(),
|
||||
}
|
||||
return new_status in allowed_transitions[self]
|
||||
|
||||
def requires_extracted_info(self) -> bool:
|
||||
status_requires_extracted_information = {TaskStatus.completed}
|
||||
return self in status_requires_extracted_information
|
||||
|
||||
def cant_have_extracted_info(self) -> bool:
|
||||
status_cant_have_extracted_information = {
|
||||
TaskStatus.created,
|
||||
TaskStatus.running,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.terminated,
|
||||
}
|
||||
return self in status_cant_have_extracted_information
|
||||
|
||||
def requires_failure_reason(self) -> bool:
|
||||
status_requires_failure_reason = {TaskStatus.failed, TaskStatus.terminated}
|
||||
return self in status_requires_failure_reason
|
||||
|
||||
|
||||
class Task(TaskRequest):
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
description="The creation datetime of the task.",
|
||||
examples=["2023-01-01T00:00:00Z"],
|
||||
)
|
||||
modified_at: datetime = Field(
|
||||
...,
|
||||
description="The modification datetime of the task.",
|
||||
examples=["2023-01-01T00:00:00Z"],
|
||||
)
|
||||
task_id: str = Field(
|
||||
...,
|
||||
description="The ID of the task.",
|
||||
examples=["50da533e-3904-4401-8a07-c49adf88b5eb"],
|
||||
)
|
||||
status: TaskStatus = Field(..., description="The status of the task.", examples=["created"])
|
||||
extracted_information: dict[str, Any] | list | str | None = Field(
|
||||
None,
|
||||
description="The extracted information from the task.",
|
||||
)
|
||||
failure_reason: str | None = Field(
|
||||
None,
|
||||
description="The reason for the task failure.",
|
||||
)
|
||||
organization_id: str | None = None
|
||||
workflow_run_id: str | None = None
|
||||
order: int | None = None
|
||||
retry: int | None = None
|
||||
|
||||
def validate_update(
|
||||
self,
|
||||
status: TaskStatus,
|
||||
extracted_information: dict[str, Any] | list | str | None,
|
||||
failure_reason: str | None = None,
|
||||
) -> None:
|
||||
old_status = self.status
|
||||
|
||||
if not old_status.can_update_to(status):
|
||||
raise ValueError(f"invalid_status_transition({old_status},{status},{self.task_id}")
|
||||
|
||||
if status.requires_failure_reason() and failure_reason is None:
|
||||
raise ValueError(f"status_requires_failure_reason({status},{self.task_id}")
|
||||
|
||||
if status.requires_extracted_info() and self.data_extraction_goal and extracted_information is None:
|
||||
raise ValueError(f"status_requires_extracted_information({status},{self.task_id}")
|
||||
|
||||
if status.cant_have_extracted_info() and extracted_information is not None:
|
||||
raise ValueError(f"status_cant_have_extracted_information({self.task_id})")
|
||||
|
||||
if self.extracted_information is not None and extracted_information is not None:
|
||||
raise ValueError(f"cant_override_extracted_information({self.task_id})")
|
||||
|
||||
if self.failure_reason is not None and failure_reason is not None:
|
||||
raise ValueError(f"cant_override_failure_reason({self.task_id})")
|
||||
|
||||
def to_task_response(
|
||||
self, screenshot_url: str | None = None, recording_url: str | None = None, failure_reason: str | None = None
|
||||
) -> TaskResponse:
|
||||
return TaskResponse(
|
||||
request=self,
|
||||
task_id=self.task_id,
|
||||
status=self.status,
|
||||
created_at=self.created_at,
|
||||
modified_at=self.modified_at,
|
||||
extracted_information=self.extracted_information,
|
||||
failure_reason=failure_reason or self.failure_reason,
|
||||
screenshot_url=screenshot_url,
|
||||
recording_url=recording_url,
|
||||
)
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
request: TaskRequest
|
||||
task_id: str
|
||||
status: TaskStatus
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
extracted_information: list | dict[str, Any] | str | None = None
|
||||
screenshot_url: str | None = None
|
||||
recording_url: str | None = None
|
||||
failure_reason: str | None = None
|
||||
|
||||
|
||||
class CreateTaskResponse(BaseModel):
|
||||
task_id: str
|
||||
Reference in New Issue
Block a user