shu/fix get run for workflow runs (#2362)

This commit is contained in:
Shuchang Zheng
2025-05-16 02:39:47 -07:00
committed by GitHub
parent 09ed1c1dff
commit 87000f5cc3
7 changed files with 28 additions and 28 deletions

View File

@@ -5,8 +5,8 @@ from typing import Any, cast
from dotenv import load_dotenv
from skyvern.agent.client import SkyvernClient
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.client import AsyncSkyvern
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.utils import migrate_db
@@ -31,10 +31,8 @@ class SkyvernAgent:
cdp_url: str | None = None,
browser_path: str | None = None,
browser_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> None:
self.extra_headers = extra_headers
self.client: SkyvernClient | None = None
self.client: AsyncSkyvern | None = None
if base_url is None and api_key is None:
if not os.path.exists(".env"):
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
@@ -69,10 +67,9 @@ class SkyvernAgent:
settings.BROWSER_TYPE = browser_type
elif base_url and api_key:
self.client = SkyvernClient(
self.client = AsyncSkyvern(
base_url=base_url,
api_key=api_key,
extra_headers=self.extra_headers,
)
else:
raise ValueError("base_url and api_key must be both provided")
@@ -254,7 +251,12 @@ class SkyvernAgent:
organization = await self.get_organization()
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
return await self.client.get_run(run_id)
run_obj = await self.client.get_run(run_id)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
return TaskRunResponse.model_validate(run_obj.dict())
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj.dict())
raise ValueError(f"Invalid run type: {run_obj.run_type}")
async def run_task(
self,
@@ -272,6 +274,7 @@ class SkyvernAgent:
wait_for_completion: bool = True,
timeout: float = DEFAULT_AGENT_TIMEOUT,
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> TaskRunResponse:
if not self.client:
if engine == RunEngine.skyvern_v1:
@@ -344,7 +347,7 @@ class SkyvernAgent:
else:
raise ValueError("Local mode is not supported for this method")
task_run = await self.client.run_task(
task_run = await self.client.agent.run_task(
prompt=prompt,
engine=engine,
url=url,
@@ -355,16 +358,17 @@ class SkyvernAgent:
error_code_mapping=error_code_mapping,
proxy_location=proxy_location,
max_steps=max_steps,
user_agent=user_agent,
)
if wait_for_completion:
async with asyncio.timeout(timeout):
while True:
task_run = await self.client.get_run(task_run.run_id)
task_run = await self.client.agent.get_run(task_run.run_id)
if task_run.status.is_final():
return task_run
break
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
return task_run
return TaskRunResponse.model_validate(task_run.dict())
async def run_workflow(
self,
@@ -380,5 +384,6 @@ class SkyvernAgent:
wait_for_completion: bool = True,
timeout: float = DEFAULT_AGENT_TIMEOUT,
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> None:
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")