diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index b0db6f8e..7711fb6c 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -2905,3 +2905,15 @@ class AgentDB: query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc()) task_run = (await session.scalars(query)).first() return TaskRun.model_validate(task_run) if task_run else None + + async def get_task_run( + self, + run_id: str, + organization_id: str | None = None, + ) -> TaskRun | None: + async with self.Session() as session: + query = select(TaskRunModel).filter_by(run_id=run_id) + if organization_id: + query = query.filter_by(organization_id=organization_id) + task_run = (await session.scalars(query)).first() + return TaskRun.model_validate(task_run) if task_run else None diff --git a/skyvern/forge/sdk/db/id.py b/skyvern/forge/sdk/db/id.py index f78aec18..08ac1d61 100644 --- a/skyvern/forge/sdk/db/id.py +++ b/skyvern/forge/sdk/db/id.py @@ -37,7 +37,7 @@ BITWARDEN_SENSITIVE_INFORMATION_PARAMETER_PREFIX = "bsi" CREDENTIAL_PARAMETER_PREFIX = "cp" CREDENTIAL_PREFIX = "cred" ORGANIZATION_BITWARDEN_COLLECTION_PREFIX = "obc" -TASK_V2_ID = "oc" +TASK_V2_ID = "tsk_v2" THOUGHT_ID = "ot" ORGANIZATION_AUTH_TOKEN_PREFIX = "oat" ORG_PREFIX = "o" diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index c42d91f8..683f1230 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -45,7 +45,7 @@ from skyvern.forge.sdk.schemas.organizations import ( OrganizationUpdate, ) from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase -from skyvern.forge.sdk.schemas.task_runs import TaskRunType +from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.tasks import ( CreateTaskResponse, @@ -57,7 +57,7 @@ from skyvern.forge.sdk.schemas.tasks import ( TaskStatus, ) from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline -from skyvern.forge.sdk.services import org_auth_service, task_v2_service +from skyvern.forge.sdk.services import org_auth_service, task_run_service, task_v2_service from skyvern.forge.sdk.workflow.exceptions import ( FailedToCreateWorkflow, FailedToUpdateWorkflow, @@ -391,6 +391,23 @@ async def get_runs( return ORJSONResponse([run.model_dump() for run in runs]) +@base_router.get("/runs/{run_id}", response_model=TaskRunResponse) +@base_router.get("/runs/{run_id}/", response_model=TaskRunResponse, include_in_schema=False) +async def get_run( + run_id: str, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> TaskRunResponse: + task_run_response = await task_run_service.get_task_run_response( + run_id, organization_id=current_org.organization_id + ) + if not task_run_response: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Task run not found {run_id}", + ) + return task_run_response + + @base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step]) @base_router.get( "/tasks/{task_id}/steps/", diff --git a/skyvern/forge/sdk/schemas/task_runs.py b/skyvern/forge/sdk/schemas/task_runs.py index bbb344af..216046c1 100644 --- a/skyvern/forge/sdk/schemas/task_runs.py +++ b/skyvern/forge/sdk/schemas/task_runs.py @@ -3,6 +3,24 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict +from skyvern.forge.sdk.schemas.tasks import ProxyLocation + + +class TaskRunStatus(StrEnum): + created = "created" + queued = "queued" + running = "running" + timed_out = "timed_out" + failed = "failed" + terminated = "terminated" + completed = "completed" + canceled = "canceled" + + +class RunEngine(StrEnum): + skyvern_v1 = "skyvern-1.0" + skyvern_v2 = "skyvern-2.0" + class TaskRunType(StrEnum): task_v1 = "task_v1" @@ -22,3 +40,22 @@ class TaskRun(BaseModel): cached: bool = False created_at: datetime modified_at: datetime + + +class TaskRunResponse(BaseModel): + run_id: str + engine: RunEngine = RunEngine.skyvern_v1 + status: TaskRunStatus + goal: str | None = None + url: str | None = None + output: dict | list | str | None = None + failure_reason: str | None = None + webhook_url: str | None = None + totp_identifier: str | None = None + totp_url: str | None = None + proxy_location: ProxyLocation | None = None + error_code_mapping: dict[str, str] | None = None + title: str | None = None + max_steps: int | None = None + created_at: datetime + modified_at: datetime diff --git a/skyvern/forge/sdk/services/task_run_service.py b/skyvern/forge/sdk/services/task_run_service.py new file mode 100644 index 00000000..64599a2d --- /dev/null +++ b/skyvern/forge/sdk/services/task_run_service.py @@ -0,0 +1,53 @@ +from skyvern.forge import app +from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType + + +async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None: + return await app.DATABASE.get_task_run(run_id, organization_id=organization_id) + + +async def get_task_run_response(run_id: str, organization_id: str | None = None) -> TaskRunResponse | None: + task_run = await get_task_run(run_id, organization_id=organization_id) + if not task_run: + return None + + if task_run.task_run_type == TaskRunType.task_v1: + # fetch task v1 from db and transform to task run response + task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id) + if not task_v1: + return None + return TaskRunResponse( + run_id=task_run.run_id, + engine=RunEngine.skyvern_v1, + status=task_v1.status, + goal=task_v1.navigation_goal, + url=task_v1.url, + output=task_v1.extracted_information, + failure_reason=task_v1.failure_reason, + webhook_url=task_v1.webhook_callback_url, + totp_identifier=task_v1.totp_identifier, + totp_url=task_v1.totp_verification_url, + proxy_location=task_v1.proxy_location, + created_at=task_v1.created_at, + modified_at=task_v1.modified_at, + ) + elif task_run.task_run_type == TaskRunType.task_v2: + task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id) + if not task_v2: + return None + return TaskRunResponse( + run_id=task_run.run_id, + engine=RunEngine.skyvern_v2, + status=task_v2.status, + goal=task_v2.prompt, + url=task_v2.url, + output=task_v2.output, + failure_reason=task_v2.failure_reason, + webhook_url=task_v2.webhook_callback_url, + totp_identifier=task_v2.totp_identifier, + totp_url=task_v2.totp_verification_url, + proxy_location=task_v2.proxy_location, + created_at=task_v2.created_at, + modified_at=task_v2.modified_at, + ) + raise ValueError(f"Invalid task run type: {task_run.task_run_type}")