shu/fix get run for workflow runs (#2362)
This commit is contained in:
@@ -5,8 +5,8 @@ from typing import Any, cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.agent.client import SkyvernClient
|
||||
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
||||
from skyvern.client import AsyncSkyvern
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import security, skyvern_context
|
||||
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
|
||||
from skyvern.services import run_service, task_v1_service, task_v2_service
|
||||
from skyvern.utils import migrate_db
|
||||
|
||||
@@ -31,10 +31,8 @@ class SkyvernAgent:
|
||||
cdp_url: str | None = None,
|
||||
browser_path: str | None = None,
|
||||
browser_type: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.extra_headers = extra_headers
|
||||
self.client: SkyvernClient | None = None
|
||||
self.client: AsyncSkyvern | None = None
|
||||
if base_url is None and api_key is None:
|
||||
if not os.path.exists(".env"):
|
||||
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
|
||||
@@ -69,10 +67,9 @@ class SkyvernAgent:
|
||||
|
||||
settings.BROWSER_TYPE = browser_type
|
||||
elif base_url and api_key:
|
||||
self.client = SkyvernClient(
|
||||
self.client = AsyncSkyvern(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
extra_headers=self.extra_headers,
|
||||
)
|
||||
else:
|
||||
raise ValueError("base_url and api_key must be both provided")
|
||||
@@ -254,7 +251,12 @@ class SkyvernAgent:
|
||||
organization = await self.get_organization()
|
||||
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
|
||||
|
||||
return await self.client.get_run(run_id)
|
||||
run_obj = await self.client.get_run(run_id)
|
||||
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
|
||||
return TaskRunResponse.model_validate(run_obj.dict())
|
||||
elif run_obj.run_type == RunType.workflow_run:
|
||||
return WorkflowRunResponse.model_validate(run_obj.dict())
|
||||
raise ValueError(f"Invalid run type: {run_obj.run_type}")
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
@@ -272,6 +274,7 @@ class SkyvernAgent:
|
||||
wait_for_completion: bool = True,
|
||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> TaskRunResponse:
|
||||
if not self.client:
|
||||
if engine == RunEngine.skyvern_v1:
|
||||
@@ -344,7 +347,7 @@ class SkyvernAgent:
|
||||
else:
|
||||
raise ValueError("Local mode is not supported for this method")
|
||||
|
||||
task_run = await self.client.run_task(
|
||||
task_run = await self.client.agent.run_task(
|
||||
prompt=prompt,
|
||||
engine=engine,
|
||||
url=url,
|
||||
@@ -355,16 +358,17 @@ class SkyvernAgent:
|
||||
error_code_mapping=error_code_mapping,
|
||||
proxy_location=proxy_location,
|
||||
max_steps=max_steps,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
if wait_for_completion:
|
||||
async with asyncio.timeout(timeout):
|
||||
while True:
|
||||
task_run = await self.client.get_run(task_run.run_id)
|
||||
task_run = await self.client.agent.get_run(task_run.run_id)
|
||||
if task_run.status.is_final():
|
||||
return task_run
|
||||
break
|
||||
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
|
||||
return task_run
|
||||
return TaskRunResponse.model_validate(task_run.dict())
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
@@ -380,5 +384,6 @@ class SkyvernAgent:
|
||||
wait_for_completion: bool = True,
|
||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
|
||||
|
||||
Reference in New Issue
Block a user