shu/test discriminated run response schema (#2046)

This commit is contained in:
Shuchang Zheng
2025-03-30 21:14:52 -07:00
committed by GitHub
parent 8253738c7b
commit ea67502161
6 changed files with 158 additions and 112 deletions

View File

@@ -1,7 +1,6 @@
from skyvern.client.client import AsyncSkyvern
from skyvern.config import settings
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
class SkyvernClient:
@@ -28,7 +27,7 @@ class SkyvernClient:
max_steps: int | None = None,
browser_session_id: str | None = None,
publish_workflow: bool = False,
) -> RunResponse:
) -> TaskRunResponse:
task_run_obj = await self.client.agent.run_task(
goal=prompt,
url=url,
@@ -43,7 +42,7 @@ class SkyvernClient:
browser_session_id=browser_session_id,
publish_workflow=publish_workflow,
)
return RunResponse.model_validate(task_run_obj)
return TaskRunResponse.model_validate(task_run_obj)
async def run_workflow(
self,
@@ -55,7 +54,7 @@ class SkyvernClient:
totp_url: str | None = None,
browser_session_id: str | None = None,
template: bool = False,
) -> RunWorkflowResponse:
) -> WorkflowRunResponse:
workflow_run_obj = await self.client.agent.run_workflow(
workflow_id=workflow_id,
data=workflow_input,
@@ -66,11 +65,15 @@ class SkyvernClient:
browser_session_id=browser_session_id,
template=template,
)
return RunWorkflowResponse.model_validate(workflow_run_obj)
return WorkflowRunResponse.model_validate(workflow_run_obj)
async def get_run(
self,
run_id: str,
) -> RunResponse:
run_obj = await self.client.agent.get_run(run_id=run_id)
return RunResponse.model_validate(run_obj)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2]:
return TaskRunResponse.model_validate(run_obj)
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj)
raise ValueError(f"Invalid run type: {run_obj.run_type}")