update SkyvernClient using generated client code (#2044)
This commit is contained in:
@@ -1,10 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from skyvern.client.client import AsyncSkyvern
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import SkyvernClientException
|
||||
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
|
||||
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
|
||||
|
||||
|
||||
@@ -16,25 +12,38 @@ class SkyvernClient:
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.client = AsyncSkyvern(base_url=base_url, api_key=api_key)
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
goal: str,
|
||||
engine: RunEngine = RunEngine.skyvern_v1,
|
||||
prompt: str,
|
||||
url: str | None = None,
|
||||
title: str | None = None,
|
||||
engine: RunEngine = RunEngine.skyvern_v2,
|
||||
webhook_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
title: str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
max_steps: int | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
publish_workflow: bool = False,
|
||||
) -> RunResponse:
|
||||
if engine == RunEngine.skyvern_v1:
|
||||
return RunResponse()
|
||||
elif engine == RunEngine.skyvern_v2:
|
||||
return RunResponse()
|
||||
raise ValueError(f"Invalid engine: {engine}")
|
||||
task_run_obj = await self.client.agent.run_task(
|
||||
goal=prompt,
|
||||
url=url,
|
||||
title=title,
|
||||
engine=engine,
|
||||
webhook_url=webhook_url,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_url=totp_url,
|
||||
error_code_mapping=error_code_mapping,
|
||||
proxy_location=proxy_location,
|
||||
max_steps=max_steps,
|
||||
browser_session_id=browser_session_id,
|
||||
publish_workflow=publish_workflow,
|
||||
)
|
||||
return RunResponse.model_validate(task_run_obj)
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
@@ -44,47 +53,24 @@ class SkyvernClient:
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
template: bool = False,
|
||||
) -> RunWorkflowResponse:
|
||||
data: dict[str, Any] = {
|
||||
"webhook_callback_url": webhook_url,
|
||||
"proxy_location": proxy_location,
|
||||
"totp_identifier": totp_identifier,
|
||||
"totp_url": totp_url,
|
||||
}
|
||||
if workflow_input:
|
||||
data["data"] = workflow_input
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/workflows/{workflow_id}/run",
|
||||
headers={"x-api-key": self.api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise SkyvernClientException(
|
||||
f"Failed to run workflow: {response.text}",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
return RunWorkflowResponse.model_validate(response.json())
|
||||
workflow_run_obj = await self.client.agent.run_workflow(
|
||||
workflow_id=workflow_id,
|
||||
data=workflow_input,
|
||||
webhook_callback_url=webhook_url,
|
||||
proxy_location=proxy_location,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_url=totp_url,
|
||||
browser_session_id=browser_session_id,
|
||||
template=template,
|
||||
)
|
||||
return RunWorkflowResponse.model_validate(workflow_run_obj)
|
||||
|
||||
async def get_run(
|
||||
self,
|
||||
run_id: str,
|
||||
) -> RunResponse:
|
||||
return RunResponse()
|
||||
|
||||
async def get_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
) -> WorkflowRunResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/workflows/runs/{workflow_run_id}",
|
||||
headers={"x-api-key": self.api_key},
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise SkyvernClientException(
|
||||
f"Failed to get workflow run: {response.text}",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
return WorkflowRunResponse.model_validate(response.json())
|
||||
run_obj = await self.client.agent.get_run(run_id=run_id)
|
||||
return RunResponse.model_validate(run_obj)
|
||||
|
||||
Reference in New Issue
Block a user