shu/test discriminated run response schema (#2046)

This commit is contained in:
Shuchang Zheng
2025-03-30 21:14:52 -07:00
committed by GitHub
parent 8253738c7b
commit ea67502161
6 changed files with 158 additions and 112 deletions

View File

@@ -1,7 +1,6 @@
from skyvern.client.client import AsyncSkyvern from skyvern.client.client import AsyncSkyvern
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
class SkyvernClient: class SkyvernClient:
@@ -28,7 +27,7 @@ class SkyvernClient:
max_steps: int | None = None, max_steps: int | None = None,
browser_session_id: str | None = None, browser_session_id: str | None = None,
publish_workflow: bool = False, publish_workflow: bool = False,
) -> RunResponse: ) -> TaskRunResponse:
task_run_obj = await self.client.agent.run_task( task_run_obj = await self.client.agent.run_task(
goal=prompt, goal=prompt,
url=url, url=url,
@@ -43,7 +42,7 @@ class SkyvernClient:
browser_session_id=browser_session_id, browser_session_id=browser_session_id,
publish_workflow=publish_workflow, publish_workflow=publish_workflow,
) )
return RunResponse.model_validate(task_run_obj) return TaskRunResponse.model_validate(task_run_obj)
async def run_workflow( async def run_workflow(
self, self,
@@ -55,7 +54,7 @@ class SkyvernClient:
totp_url: str | None = None, totp_url: str | None = None,
browser_session_id: str | None = None, browser_session_id: str | None = None,
template: bool = False, template: bool = False,
) -> RunWorkflowResponse: ) -> WorkflowRunResponse:
workflow_run_obj = await self.client.agent.run_workflow( workflow_run_obj = await self.client.agent.run_workflow(
workflow_id=workflow_id, workflow_id=workflow_id,
data=workflow_input, data=workflow_input,
@@ -66,11 +65,15 @@ class SkyvernClient:
browser_session_id=browser_session_id, browser_session_id=browser_session_id,
template=template, template=template,
) )
return RunWorkflowResponse.model_validate(workflow_run_obj) return WorkflowRunResponse.model_validate(workflow_run_obj)
async def get_run( async def get_run(
self, self,
run_id: str, run_id: str,
) -> RunResponse: ) -> RunResponse:
run_obj = await self.client.agent.get_run(run_id=run_id) run_obj = await self.client.agent.get_run(run_id=run_id)
return RunResponse.model_validate(run_obj) if run_obj.run_type in [RunType.task_v1, RunType.task_v2]:
return TaskRunResponse.model_validate(run_obj)
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj)
raise ValueError(f"Invalid run type: {run_obj.run_type}")

View File

@@ -2945,7 +2945,7 @@ class AgentDB:
task_run = (await session.scalars(query)).first() task_run = (await session.scalars(query)).first()
return Run.model_validate(task_run) if task_run else None return Run.model_validate(task_run) if task_run else None
async def get_task_run( async def get_run(
self, self,
run_id: str, run_id: str,
organization_id: str | None = None, organization_id: str | None = None,

View File

@@ -52,7 +52,7 @@ from skyvern.forge.sdk.schemas.tasks import (
TaskStatus, TaskStatus,
) )
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
from skyvern.forge.sdk.services import org_auth_service, task_run_service from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow, FailedToCreateWorkflow,
FailedToUpdateWorkflow, FailedToUpdateWorkflow,
@@ -70,8 +70,8 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowStatus, WorkflowStatus,
) )
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
from skyvern.services import task_v1_service, task_v2_service from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
from skyvern.webeye.schemas import BrowserSessionResponse from skyvern.webeye.schemas import BrowserSessionResponse
@@ -441,7 +441,7 @@ async def get_runs(
return ORJSONResponse([run.model_dump() for run in runs]) return ORJSONResponse([run.model_dump() for run in runs])
@base_router.get( @official_api_router.get(
"/runs/{run_id}", "/runs/{run_id}",
tags=["agent"], tags=["agent"],
response_model=RunResponse, response_model=RunResponse,
@@ -450,7 +450,7 @@ async def get_runs(
"x-fern-sdk-method-name": "get_run", "x-fern-sdk-method-name": "get_run",
}, },
) )
@base_router.get( @official_api_router.get(
"/runs/{run_id}/", "/runs/{run_id}/",
response_model=RunResponse, response_model=RunResponse,
include_in_schema=False, include_in_schema=False,
@@ -459,15 +459,13 @@ async def get_run(
run_id: str, run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
) -> RunResponse: ) -> RunResponse:
task_run_response = await task_run_service.get_task_run_response( run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
run_id, organization_id=current_org.organization_id if not run_response:
)
if not task_run_response:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task run not found {run_id}", detail=f"Task run not found {run_id}",
) )
return task_run_response return run_response
@base_router.get( @base_router.get(
@@ -1511,7 +1509,7 @@ async def run_task(
run_request: TaskRunRequest, run_request: TaskRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None, x_api_key: Annotated[str | None, Header()] = None,
) -> RunResponse: ) -> TaskRunResponse:
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url}) analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
@@ -1554,24 +1552,27 @@ async def run_task(
background_tasks=background_tasks, background_tasks=background_tasks,
) )
# build the task run response # build the task run response
return RunResponse( return TaskRunResponse(
run_id=task_v1_response.task_id, run_id=task_v1_response.task_id,
title=task_v1_response.title, run_type=RunType.task_v1,
status=str(task_v1_response.status), status=str(task_v1_response.status),
created_at=task_v1_response.created_at,
updated_at=task_v1_response.modified_at,
engine=RunEngine.skyvern_v1,
goal=task_v1_response.navigation_goal,
url=task_v1_response.url,
output=task_v1_response.extracted_information, output=task_v1_response.extracted_information,
failure_reason=task_v1_response.failure_reason, failure_reason=task_v1_response.failure_reason,
data_extraction_schema=task_v1_response.extracted_information_schema, created_at=task_v1_response.created_at,
error_code_mapping=task_v1_response.error_code_mapping, modified_at=task_v1_response.modified_at,
proxy_location=task_v1_response.proxy_location, run_request=TaskRunRequest(
totp_identifier=task_v1_response.totp_identifier, engine=RunEngine.skyvern_v1,
totp_url=task_v1_response.totp_verification_url, prompt=task_v1_response.navigation_goal,
webhook_url=task_v1_response.webhook_callback_url, url=task_v1_response.url,
max_steps=task_v1_response.max_steps_per_run, webhook_url=task_v1_response.webhook_callback_url,
totp_identifier=task_v1_response.totp_identifier,
totp_url=task_v1_response.totp_verification_url,
proxy_location=task_v1_response.proxy_location,
max_steps=task_v1_response.max_steps_per_run,
data_extraction_schema=task_v1_response.extracted_information_schema,
error_code_mapping=task_v1_response.error_code_mapping,
browser_session_id=run_request.browser_session_id,
),
) )
if run_request.engine == RunEngine.skyvern_v2: if run_request.engine == RunEngine.skyvern_v2:
# create task v2 # create task v2
@@ -1602,22 +1603,27 @@ async def run_task(
max_steps_override=run_request.max_steps, max_steps_override=run_request.max_steps,
browser_session_id=run_request.browser_session_id, browser_session_id=run_request.browser_session_id,
) )
return RunResponse( return TaskRunResponse(
run_id=task_v2.observer_cruise_id, run_id=task_v2.observer_cruise_id,
title=run_request.title, run_type=RunType.task_v2,
status=str(task_v2.status), status=str(task_v2.status),
engine=RunEngine.skyvern_v2,
goal=task_v2.prompt,
url=task_v2.url,
output=task_v2.output, output=task_v2.output,
failure_reason=task_v2.failure_reason, failure_reason=task_v2.failure_reason,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
error_code_mapping=task_v2.error_code_mapping,
data_extraction_schema=task_v2.extracted_information_schema,
created_at=task_v2.created_at, created_at=task_v2.created_at,
modified_at=task_v2.modified_at, modified_at=task_v2.modified_at,
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v2,
prompt=task_v2.prompt,
url=task_v2.url,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
max_steps=run_request.max_steps,
browser_session_id=run_request.browser_session_id,
error_code_mapping=task_v2.error_code_mapping,
data_extraction_schema=task_v2.extracted_information_schema,
publish_workflow=run_request.publish_workflow,
),
) )
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}") raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")

View File

@@ -1,54 +0,0 @@
from skyvern.forge import app
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.schemas.runs import RunEngine, RunResponse, RunType
async def get_task_run(run_id: str, organization_id: str | None = None) -> Run | None:
return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)
async def get_task_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
task_run = await get_task_run(run_id, organization_id=organization_id)
if not task_run:
return None
if task_run.task_run_type == RunType.task_v1:
# fetch task v1 from db and transform to task run response
task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id)
if not task_v1:
return None
return RunResponse(
run_id=task_run.run_id,
engine=RunEngine.skyvern_v1,
status=task_v1.status,
goal=task_v1.navigation_goal,
url=task_v1.url,
output=task_v1.extracted_information,
failure_reason=task_v1.failure_reason,
webhook_url=task_v1.webhook_callback_url,
totp_identifier=task_v1.totp_identifier,
totp_url=task_v1.totp_verification_url,
proxy_location=task_v1.proxy_location,
created_at=task_v1.created_at,
modified_at=task_v1.modified_at,
)
elif task_run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id)
if not task_v2:
return None
return RunResponse(
run_id=task_run.run_id,
engine=RunEngine.skyvern_v2,
status=task_v2.status,
goal=task_v2.prompt,
url=task_v2.url,
output=task_v2.output,
failure_reason=task_v2.failure_reason,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
created_at=task_v2.created_at,
modified_at=task_v2.modified_at,
)
raise ValueError(f"Invalid task run type: {task_run.task_run_type}")

View File

@@ -1,8 +1,9 @@
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Annotated, Any, Literal, Union
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from pydantic import BaseModel, field_validator from pydantic import BaseModel, Field, field_validator
from skyvern.utils.url_validators import validate_url from skyvern.utils.url_validators import validate_url
@@ -112,7 +113,7 @@ class RunStatus(StrEnum):
class TaskRunRequest(BaseModel): class TaskRunRequest(BaseModel):
goal: str prompt: str
url: str | None = None url: str | None = None
title: str | None = None title: str | None = None
engine: RunEngine = RunEngine.skyvern_v2 engine: RunEngine = RunEngine.skyvern_v2
@@ -135,21 +136,40 @@ class TaskRunRequest(BaseModel):
return validate_url(url) return validate_url(url)
class RunResponse(BaseModel): class WorkflowRunRequest(BaseModel):
title: str | None = None
parameters: dict[str, Any] | None = None
proxy_location: ProxyLocation | None = None
webhook_url: str | None = None
totp_url: str | None = None
totp_identifier: str | None = None
browser_session_id: str | None = None
@field_validator("webhook_url", "totp_url")
@classmethod
def validate_urls(cls, url: str | None) -> str | None:
if url is None:
return None
return validate_url(url)
class BaseRunResponse(BaseModel):
run_id: str run_id: str
engine: RunEngine = RunEngine.skyvern_v1
status: RunStatus status: RunStatus
goal: str | None = None
url: str | None = None
output: dict | list | str | None = None output: dict | list | str | None = None
failure_reason: str | None = None failure_reason: str | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
proxy_location: ProxyLocation | None = None
error_code_mapping: dict[str, str] | None = None
data_extraction_schema: dict | list | str | None = None
title: str | None = None
max_steps: int | None = None
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime
class TaskRunResponse(BaseRunResponse):
run_type: Literal[RunType.task_v1, RunType.task_v2]
run_request: TaskRunRequest | None = None
class WorkflowRunResponse(BaseRunResponse):
run_type: Literal[RunType.workflow_run]
run_request: WorkflowRunRequest | None = None
RunResponse = Annotated[Union[TaskRunResponse, WorkflowRunResponse], Field(discriminator="run_type")]

View File

@@ -0,0 +1,71 @@
from skyvern.forge import app
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
if not run:
return None
if run.task_run_type == RunType.task_v1:
# fetch task v1 from db and transform to task run response
task_v1 = await app.DATABASE.get_task(run.task_v1_id, organization_id=organization_id)
if not task_v1:
return None
return TaskRunResponse(
run_id=run.run_id,
run_type=run.task_run_type,
status=str(task_v1.status),
output=task_v1.extracted_information,
failure_reason=task_v1.failure_reason,
created_at=task_v1.created_at,
modified_at=task_v1.modified_at,
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v1,
prompt=task_v1.navigation_goal,
url=task_v1.url,
webhook_url=task_v1.webhook_callback_url,
totp_identifier=task_v1.totp_identifier,
totp_url=task_v1.totp_verification_url,
proxy_location=task_v1.proxy_location,
max_steps=task_v1.max_steps_per_run,
data_extraction_schema=task_v1.extracted_information_schema,
error_code_mapping=task_v1.error_code_mapping,
),
)
elif run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run.task_v2_id, organization_id=organization_id)
if not task_v2:
return None
return TaskRunResponse(
run_id=run.run_id,
run_type=run.task_run_type,
status=task_v2.status,
output=task_v2.output,
failure_reason=task_v2.failure_reason,
created_at=task_v2.created_at,
modified_at=task_v2.modified_at,
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v2,
prompt=task_v2.prompt,
url=task_v2.url,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
data_extraction_schema=task_v2.data_extraction_schema,
error_code_mapping=task_v2.error_code_mapping,
),
)
elif run.task_run_type == RunType.workflow_run:
raise NotImplementedError("Workflow run response not implemented")
# return WorkflowRunResponse(
# run_id=run.run_id,
# run_type=run.task_run_type,
# status=run.status,
# output=run.output,
# parameters=None,
# created_at=run.created_at,
# modified_at=run.modified_at,
# )
raise ValueError(f"Invalid task run type: {run.task_run_type}")