shu/test discriminated run response schema (#2046)

This commit is contained in:
Shuchang Zheng
2025-03-30 21:14:52 -07:00
committed by GitHub
parent 8253738c7b
commit ea67502161
6 changed files with 158 additions and 112 deletions

View File

@@ -2945,7 +2945,7 @@ class AgentDB:
task_run = (await session.scalars(query)).first()
return Run.model_validate(task_run) if task_run else None
async def get_task_run(
async def get_run(
self,
run_id: str,
organization_id: str | None = None,

View File

@@ -52,7 +52,7 @@ from skyvern.forge.sdk.schemas.tasks import (
TaskStatus,
)
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
from skyvern.forge.sdk.services import org_auth_service, task_run_service
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow,
FailedToUpdateWorkflow,
@@ -70,8 +70,8 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowStatus,
)
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest
from skyvern.services import task_v1_service, task_v2_service
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.webeye.actions.actions import Action
from skyvern.webeye.schemas import BrowserSessionResponse
@@ -441,7 +441,7 @@ async def get_runs(
return ORJSONResponse([run.model_dump() for run in runs])
@base_router.get(
@official_api_router.get(
"/runs/{run_id}",
tags=["agent"],
response_model=RunResponse,
@@ -450,7 +450,7 @@ async def get_runs(
"x-fern-sdk-method-name": "get_run",
},
)
@base_router.get(
@official_api_router.get(
"/runs/{run_id}/",
response_model=RunResponse,
include_in_schema=False,
@@ -459,15 +459,13 @@ async def get_run(
run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> RunResponse:
task_run_response = await task_run_service.get_task_run_response(
run_id, organization_id=current_org.organization_id
)
if not task_run_response:
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
if not run_response:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task run not found {run_id}",
)
return task_run_response
return run_response
@base_router.get(
@@ -1511,7 +1509,7 @@ async def run_task(
run_request: TaskRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
) -> RunResponse:
) -> TaskRunResponse:
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
@@ -1554,24 +1552,27 @@ async def run_task(
background_tasks=background_tasks,
)
# build the task run response
return RunResponse(
return TaskRunResponse(
run_id=task_v1_response.task_id,
title=task_v1_response.title,
run_type=RunType.task_v1,
status=str(task_v1_response.status),
created_at=task_v1_response.created_at,
updated_at=task_v1_response.modified_at,
engine=RunEngine.skyvern_v1,
goal=task_v1_response.navigation_goal,
url=task_v1_response.url,
output=task_v1_response.extracted_information,
failure_reason=task_v1_response.failure_reason,
data_extraction_schema=task_v1_response.extracted_information_schema,
error_code_mapping=task_v1_response.error_code_mapping,
proxy_location=task_v1_response.proxy_location,
totp_identifier=task_v1_response.totp_identifier,
totp_url=task_v1_response.totp_verification_url,
webhook_url=task_v1_response.webhook_callback_url,
max_steps=task_v1_response.max_steps_per_run,
created_at=task_v1_response.created_at,
modified_at=task_v1_response.modified_at,
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v1,
prompt=task_v1_response.navigation_goal,
url=task_v1_response.url,
webhook_url=task_v1_response.webhook_callback_url,
totp_identifier=task_v1_response.totp_identifier,
totp_url=task_v1_response.totp_verification_url,
proxy_location=task_v1_response.proxy_location,
max_steps=task_v1_response.max_steps_per_run,
data_extraction_schema=task_v1_response.extracted_information_schema,
error_code_mapping=task_v1_response.error_code_mapping,
browser_session_id=run_request.browser_session_id,
),
)
if run_request.engine == RunEngine.skyvern_v2:
# create task v2
@@ -1602,22 +1603,27 @@ async def run_task(
max_steps_override=run_request.max_steps,
browser_session_id=run_request.browser_session_id,
)
return RunResponse(
return TaskRunResponse(
run_id=task_v2.observer_cruise_id,
title=run_request.title,
run_type=RunType.task_v2,
status=str(task_v2.status),
engine=RunEngine.skyvern_v2,
goal=task_v2.prompt,
url=task_v2.url,
output=task_v2.output,
failure_reason=task_v2.failure_reason,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
error_code_mapping=task_v2.error_code_mapping,
data_extraction_schema=task_v2.extracted_information_schema,
created_at=task_v2.created_at,
modified_at=task_v2.modified_at,
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v2,
prompt=task_v2.prompt,
url=task_v2.url,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
max_steps=run_request.max_steps,
browser_session_id=run_request.browser_session_id,
error_code_mapping=task_v2.error_code_mapping,
data_extraction_schema=task_v2.extracted_information_schema,
publish_workflow=run_request.publish_workflow,
),
)
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")

View File

@@ -1,54 +0,0 @@
from skyvern.forge import app
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.schemas.runs import RunEngine, RunResponse, RunType
async def get_task_run(run_id: str, organization_id: str | None = None) -> Run | None:
return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)
async def get_task_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
task_run = await get_task_run(run_id, organization_id=organization_id)
if not task_run:
return None
if task_run.task_run_type == RunType.task_v1:
# fetch task v1 from db and transform to task run response
task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id)
if not task_v1:
return None
return RunResponse(
run_id=task_run.run_id,
engine=RunEngine.skyvern_v1,
status=task_v1.status,
goal=task_v1.navigation_goal,
url=task_v1.url,
output=task_v1.extracted_information,
failure_reason=task_v1.failure_reason,
webhook_url=task_v1.webhook_callback_url,
totp_identifier=task_v1.totp_identifier,
totp_url=task_v1.totp_verification_url,
proxy_location=task_v1.proxy_location,
created_at=task_v1.created_at,
modified_at=task_v1.modified_at,
)
elif task_run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id)
if not task_v2:
return None
return RunResponse(
run_id=task_run.run_id,
engine=RunEngine.skyvern_v2,
status=task_v2.status,
goal=task_v2.prompt,
url=task_v2.url,
output=task_v2.output,
failure_reason=task_v2.failure_reason,
webhook_url=task_v2.webhook_callback_url,
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
created_at=task_v2.created_at,
modified_at=task_v2.modified_at,
)
raise ValueError(f"Invalid task run type: {task_run.task_run_type}")