diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc38a2af..f97588b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,8 @@ repos: language_version: python3.11 exclude: | (?x)( - ^skyvern/client/.* + ^skyvern/client/.*| + ^skyvern/__init__.py ) - repo: https://github.com/pre-commit/pygrep-hooks diff --git a/skyvern/__init__.py b/skyvern/__init__.py index c3d97aac..43eb4242 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -13,6 +13,7 @@ tracer.configure( setup_logger() +from skyvern.forge import app # noqa: E402, F401 from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402 from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402 diff --git a/skyvern/agent/client.py b/skyvern/agent/client.py index 0d3e4c9d..1a5bd4f7 100644 --- a/skyvern/agent/client.py +++ b/skyvern/agent/client.py @@ -5,7 +5,7 @@ import httpx from skyvern.config import settings from skyvern.exceptions import SkyvernClientException from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse -from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse +from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse class SkyvernClient: @@ -29,11 +29,11 @@ class SkyvernClient: error_code_mapping: dict[str, str] | None = None, proxy_location: ProxyLocation | None = None, max_steps: int | None = None, - ) -> TaskRunResponse: + ) -> RunResponse: if engine == RunEngine.skyvern_v1: - return TaskRunResponse() + return RunResponse() elif engine == RunEngine.skyvern_v2: - return TaskRunResponse() + return RunResponse() raise ValueError(f"Invalid engine: {engine}") async def run_workflow( @@ -69,8 +69,8 @@ class SkyvernClient: async def get_run( self, run_id: str, - ) -> TaskRunResponse: - return TaskRunResponse() + ) -> RunResponse: + return RunResponse() async def get_workflow_run( self, diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 5e397752..c416d226 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -66,8 +66,8 @@ from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession +from skyvern.forge.sdk.schemas.runs import TaskRun from skyvern.forge.sdk.schemas.task_generations import TaskGeneration -from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus from skyvern.forge.sdk.schemas.totp_codes import TOTPCode @@ -91,7 +91,7 @@ from skyvern.forge.sdk.workflow.models.workflow import ( WorkflowRunStatus, WorkflowStatus, ) -from skyvern.schemas.runs import ProxyLocation +from skyvern.schemas.runs import ProxyLocation, RunType from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.models import AgentStepOutput @@ -2783,7 +2783,7 @@ class AgentDB: async def create_task_run( self, - task_run_type: TaskRunType, + task_run_type: RunType, organization_id: str, run_id: str, title: str | None = None, @@ -2931,7 +2931,7 @@ class AgentDB: raise NotFoundError(f"TaskRun {run_id} not found") async def get_cached_task_run( - self, task_run_type: TaskRunType, url_hash: str | None = None, organization_id: str | None = None + self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None ) -> TaskRun | None: async with self.Session() as session: query = select(TaskRunModel) diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index d36ad317..7766aa7f 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -41,7 +41,6 @@ from skyvern.forge.sdk.schemas.organizations import ( OrganizationUpdate, ) from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration -from skyvern.forge.sdk.schemas.task_runs import TaskRunType from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.tasks import ( CreateTaskResponse, @@ -71,7 +70,7 @@ from skyvern.forge.sdk.workflow.models.workflow import ( WorkflowStatus, ) from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest -from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse +from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest from skyvern.services import task_v1_service, task_v2_service from skyvern.webeye.actions.actions import Action from skyvern.webeye.schemas import BrowserSessionResponse @@ -445,7 +444,7 @@ async def get_runs( @base_router.get( "/runs/{run_id}", tags=["agent"], - response_model=TaskRunResponse, + response_model=RunResponse, openapi_extra={ "x-fern-sdk-group-name": "agent", "x-fern-sdk-method-name": "get_run", @@ -453,13 +452,13 @@ async def get_runs( ) @base_router.get( "/runs/{run_id}/", - response_model=TaskRunResponse, + response_model=RunResponse, include_in_schema=False, ) async def get_run( run_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), -) -> TaskRunResponse: +) -> RunResponse: task_run_response = await task_run_service.get_task_run_response( run_id, organization_id=current_org.organization_id ) @@ -683,7 +682,7 @@ async def run_workflow( version=version, ) await app.DATABASE.create_task_run( - task_run_type=TaskRunType.workflow_run, + task_run_type=RunType.workflow_run, organization_id=current_org.organization_id, run_id=workflow_run.workflow_run_id, title=workflow.title, @@ -1512,7 +1511,7 @@ async def run_task( run_request: TaskRunRequest, current_org: Organization = Depends(org_auth_service.get_current_org), x_api_key: Annotated[str | None, Header()] = None, -) -> TaskRunResponse: +) -> RunResponse: analytics.capture("skyvern-oss-run-task", data={"url": run_request.url}) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id) @@ -1555,7 +1554,7 @@ async def run_task( background_tasks=background_tasks, ) # build the task run response - return TaskRunResponse( + return RunResponse( run_id=task_v1_response.task_id, title=task_v1_response.title, status=str(task_v1_response.status), @@ -1603,7 +1602,7 @@ async def run_task( max_steps_override=run_request.max_steps, browser_session_id=run_request.browser_session_id, ) - return TaskRunResponse( + return RunResponse( run_id=task_v2.observer_cruise_id, title=run_request.title, status=str(task_v2.status), diff --git a/skyvern/forge/sdk/schemas/task_runs.py b/skyvern/forge/sdk/schemas/runs.py similarity index 68% rename from skyvern/forge/sdk/schemas/task_runs.py rename to skyvern/forge/sdk/schemas/runs.py index bbb344af..1361baa7 100644 --- a/skyvern/forge/sdk/schemas/task_runs.py +++ b/skyvern/forge/sdk/schemas/runs.py @@ -1,20 +1,15 @@ from datetime import datetime -from enum import StrEnum from pydantic import BaseModel, ConfigDict - -class TaskRunType(StrEnum): - task_v1 = "task_v1" - task_v2 = "task_v2" - workflow_run = "workflow_run" +from skyvern.schemas.runs import RunType class TaskRun(BaseModel): model_config = ConfigDict(from_attributes=True) task_run_id: str - task_run_type: TaskRunType + task_run_type: RunType run_id: str organization_id: str | None = None title: str | None = None diff --git a/skyvern/forge/sdk/services/task_run_service.py b/skyvern/forge/sdk/services/task_run_service.py index 71994314..21002270 100644 --- a/skyvern/forge/sdk/services/task_run_service.py +++ b/skyvern/forge/sdk/services/task_run_service.py @@ -1,23 +1,23 @@ from skyvern.forge import app -from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType -from skyvern.schemas.runs import RunEngine, TaskRunResponse +from skyvern.forge.sdk.schemas.runs import TaskRun +from skyvern.schemas.runs import RunEngine, RunResponse, RunType async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None: return await app.DATABASE.get_task_run(run_id, organization_id=organization_id) -async def get_task_run_response(run_id: str, organization_id: str | None = None) -> TaskRunResponse | None: +async def get_task_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None: task_run = await get_task_run(run_id, organization_id=organization_id) if not task_run: return None - if task_run.task_run_type == TaskRunType.task_v1: + if task_run.task_run_type == RunType.task_v1: # fetch task v1 from db and transform to task run response task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id) if not task_v1: return None - return TaskRunResponse( + return RunResponse( run_id=task_run.run_id, engine=RunEngine.skyvern_v1, status=task_v1.status, @@ -32,11 +32,11 @@ async def get_task_run_response(run_id: str, organization_id: str | None = None) created_at=task_v1.created_at, modified_at=task_v1.modified_at, ) - elif task_run.task_run_type == TaskRunType.task_v2: + elif task_run.task_run_type == RunType.task_v2: task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id) if not task_v2: return None - return TaskRunResponse( + return RunResponse( run_id=task_run.run_id, engine=RunEngine.skyvern_v2, status=task_v2.status, diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index bd0671e2..98c26b30 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -86,6 +86,12 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None: return None +class RunType(StrEnum): + task_v1 = "task_v1" + task_v2 = "task_v2" + workflow_run = "workflow_run" + + class RunEngine(StrEnum): skyvern_v1 = "skyvern-1.0" skyvern_v2 = "skyvern-2.0" @@ -101,12 +107,15 @@ class TaskRunStatus(StrEnum): completed = "completed" canceled = "canceled" + def is_final(self) -> bool: + return self in [self.failed, self.terminated, self.canceled, self.timed_out, self.completed] + class TaskRunRequest(BaseModel): goal: str url: str | None = None title: str | None = None - engine: RunEngine = RunEngine.skyvern_v1 + engine: RunEngine = RunEngine.skyvern_v2 proxy_location: ProxyLocation | None = None data_extraction_schema: dict | list | str | None = None error_code_mapping: dict[str, str] | None = None @@ -126,7 +135,7 @@ class TaskRunRequest(BaseModel): return validate_url(url) -class TaskRunResponse(BaseModel): +class RunResponse(BaseModel): run_id: str engine: RunEngine = RunEngine.skyvern_v1 status: TaskRunStatus diff --git a/skyvern/services/task_v1_service.py b/skyvern/services/task_v1_service.py index 3c6e1c66..41e06bb8 100644 --- a/skyvern/services/task_v1_service.py +++ b/skyvern/services/task_v1_service.py @@ -12,8 +12,8 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase -from skyvern.forge.sdk.schemas.task_runs import TaskRunType from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest +from skyvern.schemas.runs import RunType LOG = structlog.get_logger() @@ -84,7 +84,7 @@ async def run_task( created_task = await app.agent.create_task(task, organization.organization_id) url_hash = generate_url_hash(task.url) await app.DATABASE.create_task_run( - task_run_type=TaskRunType.task_v1, + task_run_type=RunType.task_v1, organization_id=organization.organization_id, run_id=created_task.task_id, title=task.title, diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index efabf3db..2c1c808d 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -20,7 +20,6 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.schemas.organizations import Organization -from skyvern.forge.sdk.schemas.task_runs import TaskRunType from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline, WorkflowRunTimelineType from skyvern.forge.sdk.workflow.models.block import ( @@ -53,7 +52,7 @@ from skyvern.forge.sdk.workflow.models.yaml import ( WorkflowCreateYAMLRequest, WorkflowDefinitionYAML, ) -from skyvern.schemas.runs import ProxyLocation +from skyvern.schemas.runs import ProxyLocation, RunType from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website from skyvern.webeye.utils.page import SkyvernFrame @@ -196,7 +195,7 @@ async def initialize_task_v2( ) if create_task_run: await app.DATABASE.create_task_run( - task_run_type=TaskRunType.task_v2, + task_run_type=RunType.task_v2, organization_id=organization.organization_id, run_id=task_v2.observer_cruise_id, title=new_workflow.title,