From 87000f5cc34280923a17c7c2599c075461fca3f4 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Fri, 16 May 2025 02:39:47 -0700 Subject: [PATCH] shu/fix get run for workflow runs (#2362) --- skyvern/__init__.py | 5 ++-- skyvern/agent/__init__.py | 3 +-- skyvern/agent/agent.py | 29 +++++++++++++--------- skyvern/cli/commands.py | 3 +-- skyvern/forge/sdk/routes/agent_protocol.py | 4 +-- skyvern/services/run_service.py | 2 +- skyvern/services/workflow_service.py | 10 +++----- 7 files changed, 28 insertions(+), 28 deletions(-) diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 84c82dbf..81048db1 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -14,7 +14,6 @@ setup_logger() from skyvern.forge import app # noqa: E402, F401 -from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402 -from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402 +from skyvern.agent import SkyvernAgent # noqa: E402 -__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"] +__all__ = ["SkyvernAgent"] diff --git a/skyvern/agent/__init__.py b/skyvern/agent/__init__.py index 10085206..4e071704 100644 --- a/skyvern/agent/__init__.py +++ b/skyvern/agent/__init__.py @@ -1,4 +1,3 @@ from skyvern.agent.agent import SkyvernAgent -from skyvern.agent.client import SkyvernClient -__all__ = ["SkyvernAgent", "SkyvernClient"] +__all__ = ["SkyvernAgent"] diff --git a/skyvern/agent/agent.py b/skyvern/agent/agent.py index aa7b0528..834b34dd 100644 --- a/skyvern/agent/agent.py +++ b/skyvern/agent/agent.py @@ -5,8 +5,8 @@ from typing import Any, cast from dotenv import load_dotenv -from skyvern.agent.client import SkyvernClient from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT +from skyvern.client import AsyncSkyvern from skyvern.config import settings from skyvern.forge import app from skyvern.forge.sdk.core import security, skyvern_context @@ -18,7 +18,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus -from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse +from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse from skyvern.services import run_service, task_v1_service, task_v2_service from skyvern.utils import migrate_db @@ -31,10 +31,8 @@ class SkyvernAgent: cdp_url: str | None = None, browser_path: str | None = None, browser_type: str | None = None, - extra_headers: dict[str, str] | None = None, ) -> None: - self.extra_headers = extra_headers - self.client: SkyvernClient | None = None + self.client: AsyncSkyvern | None = None if base_url is None and api_key is None: if not os.path.exists(".env"): raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.") @@ -69,10 +67,9 @@ class SkyvernAgent: settings.BROWSER_TYPE = browser_type elif base_url and api_key: - self.client = SkyvernClient( + self.client = AsyncSkyvern( base_url=base_url, api_key=api_key, - extra_headers=self.extra_headers, ) else: raise ValueError("base_url and api_key must be both provided") @@ -254,7 +251,12 @@ class SkyvernAgent: organization = await self.get_organization() return await run_service.get_run_response(run_id, organization_id=organization.organization_id) - return await self.client.get_run(run_id) + run_obj = await self.client.get_run(run_id) + if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]: + return TaskRunResponse.model_validate(run_obj.dict()) + elif run_obj.run_type == RunType.workflow_run: + return WorkflowRunResponse.model_validate(run_obj.dict()) + raise ValueError(f"Invalid run type: {run_obj.run_type}") async def run_task( self, @@ -272,6 +274,7 @@ class SkyvernAgent: wait_for_completion: bool = True, timeout: float = DEFAULT_AGENT_TIMEOUT, browser_session_id: str | None = None, + user_agent: str | None = None, ) -> TaskRunResponse: if not self.client: if engine == RunEngine.skyvern_v1: @@ -344,7 +347,7 @@ class SkyvernAgent: else: raise ValueError("Local mode is not supported for this method") - task_run = await self.client.run_task( + task_run = await self.client.agent.run_task( prompt=prompt, engine=engine, url=url, @@ -355,16 +358,17 @@ class SkyvernAgent: error_code_mapping=error_code_mapping, proxy_location=proxy_location, max_steps=max_steps, + user_agent=user_agent, ) if wait_for_completion: async with asyncio.timeout(timeout): while True: - task_run = await self.client.get_run(task_run.run_id) + task_run = await self.client.agent.get_run(task_run.run_id) if task_run.status.is_final(): - return task_run + break await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL) - return task_run + return TaskRunResponse.model_validate(task_run.dict()) async def run_workflow( self, @@ -380,5 +384,6 @@ class SkyvernAgent: wait_for_completion: bool = True, timeout: float = DEFAULT_AGENT_TIMEOUT, browser_session_id: str | None = None, + user_agent: str | None = None, ) -> None: raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.") diff --git a/skyvern/cli/commands.py b/skyvern/cli/commands.py index 657aae19..0d585f6b 100644 --- a/skyvern/cli/commands.py +++ b/skyvern/cli/commands.py @@ -49,9 +49,8 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]: skyvern_agent = SkyvernAgent( base_url=settings.SKYVERN_BASE_URL, api_key=settings.SKYVERN_API_KEY, - extra_headers={"X-User-Agent": "skyvern-mcp"}, ) - res = await skyvern_agent.run_task(prompt=prompt, url=url) + res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp") # TODO: It would be nice if we could return the task URL here output = res.model_dump()["output"] diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 7268b87c..21f0590b 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1549,7 +1549,7 @@ async def run_task( failure_reason=None, created_at=task_v2.created_at, modified_at=task_v2.modified_at, - app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}", + app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}", run_request=TaskRunRequest( engine=RunEngine.skyvern_v2, prompt=task_v2.prompt, @@ -1633,7 +1633,7 @@ async def run_workflow( run_request=workflow_run_request, downloaded_files=None, recording_url=None, - app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}", + app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}", ) diff --git a/skyvern/services/run_service.py b/skyvern/services/run_service.py index 86449119..af1692bb 100644 --- a/skyvern/services/run_service.py +++ b/skyvern/services/run_service.py @@ -75,7 +75,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R modified_at=task_v2.modified_at, recording_url=workflow_run.recording_url if workflow_run else None, downloaded_files=workflow_run.downloaded_files if workflow_run else None, - app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}", + app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}", run_request=TaskRunRequest( engine=RunEngine.skyvern_v2, prompt=task_v2.prompt, diff --git a/skyvern/services/workflow_service.py b/skyvern/services/workflow_service.py index 79899d16..75218781 100644 --- a/skyvern/services/workflow_service.py +++ b/skyvern/services/workflow_service.py @@ -73,9 +73,7 @@ async def get_workflow_run_response( workflow_run_id=workflow_run.workflow_run_id, organization_id=organization_id, ) - app_url = ( - f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}" - ) + app_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}" return WorkflowRunResponse( run_id=workflow_run_id, run_type=RunType.workflow_run, @@ -88,12 +86,12 @@ async def get_workflow_run_response( created_at=workflow_run.created_at, modified_at=workflow_run.modified_at, run_request=WorkflowRunRequest( - workflow_id=workflow_run.workflow_id, + workflow_id=workflow_run.workflow_permanent_id, title=workflow_run_resp.workflow_title, parameters=workflow_run_resp.parameters, proxy_location=workflow_run.proxy_location, - webhook_url=workflow_run.webhook_callback_url, - totp_url=workflow_run.totp_verification_url, + webhook_url=workflow_run.webhook_callback_url or None, + totp_url=workflow_run.totp_verification_url or None, totp_identifier=workflow_run.totp_identifier, # TODO: add browser session id ),