shu/fix get run for workflow runs (#2362)

This commit is contained in:
Shuchang Zheng
2025-05-16 02:39:47 -07:00
committed by GitHub
parent 09ed1c1dff
commit 87000f5cc3
7 changed files with 28 additions and 28 deletions

View File

@@ -14,7 +14,6 @@ setup_logger()
from skyvern.forge import app # noqa: E402, F401
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402
from skyvern.agent import SkyvernAgent # noqa: E402
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"]
__all__ = ["SkyvernAgent"]

View File

@@ -1,4 +1,3 @@
from skyvern.agent.agent import SkyvernAgent
from skyvern.agent.client import SkyvernClient
__all__ = ["SkyvernAgent", "SkyvernClient"]
__all__ = ["SkyvernAgent"]

View File

@@ -5,8 +5,8 @@ from typing import Any, cast
from dotenv import load_dotenv
from skyvern.agent.client import SkyvernClient
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.client import AsyncSkyvern
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.utils import migrate_db
@@ -31,10 +31,8 @@ class SkyvernAgent:
cdp_url: str | None = None,
browser_path: str | None = None,
browser_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> None:
self.extra_headers = extra_headers
self.client: SkyvernClient | None = None
self.client: AsyncSkyvern | None = None
if base_url is None and api_key is None:
if not os.path.exists(".env"):
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
@@ -69,10 +67,9 @@ class SkyvernAgent:
settings.BROWSER_TYPE = browser_type
elif base_url and api_key:
self.client = SkyvernClient(
self.client = AsyncSkyvern(
base_url=base_url,
api_key=api_key,
extra_headers=self.extra_headers,
)
else:
raise ValueError("base_url and api_key must be both provided")
@@ -254,7 +251,12 @@ class SkyvernAgent:
organization = await self.get_organization()
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
return await self.client.get_run(run_id)
run_obj = await self.client.get_run(run_id)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
return TaskRunResponse.model_validate(run_obj.dict())
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj.dict())
raise ValueError(f"Invalid run type: {run_obj.run_type}")
async def run_task(
self,
@@ -272,6 +274,7 @@ class SkyvernAgent:
wait_for_completion: bool = True,
timeout: float = DEFAULT_AGENT_TIMEOUT,
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> TaskRunResponse:
if not self.client:
if engine == RunEngine.skyvern_v1:
@@ -344,7 +347,7 @@ class SkyvernAgent:
else:
raise ValueError("Local mode is not supported for this method")
task_run = await self.client.run_task(
task_run = await self.client.agent.run_task(
prompt=prompt,
engine=engine,
url=url,
@@ -355,16 +358,17 @@ class SkyvernAgent:
error_code_mapping=error_code_mapping,
proxy_location=proxy_location,
max_steps=max_steps,
user_agent=user_agent,
)
if wait_for_completion:
async with asyncio.timeout(timeout):
while True:
task_run = await self.client.get_run(task_run.run_id)
task_run = await self.client.agent.get_run(task_run.run_id)
if task_run.status.is_final():
return task_run
break
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
return task_run
return TaskRunResponse.model_validate(task_run.dict())
async def run_workflow(
self,
@@ -380,5 +384,6 @@ class SkyvernAgent:
wait_for_completion: bool = True,
timeout: float = DEFAULT_AGENT_TIMEOUT,
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> None:
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")

View File

@@ -49,9 +49,8 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]:
skyvern_agent = SkyvernAgent(
base_url=settings.SKYVERN_BASE_URL,
api_key=settings.SKYVERN_API_KEY,
extra_headers={"X-User-Agent": "skyvern-mcp"},
)
res = await skyvern_agent.run_task(prompt=prompt, url=url)
res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp")
# TODO: It would be nice if we could return the task URL here
output = res.model_dump()["output"]

View File

@@ -1549,7 +1549,7 @@ async def run_task(
failure_reason=None,
created_at=task_v2.created_at,
modified_at=task_v2.modified_at,
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v2,
prompt=task_v2.prompt,
@@ -1633,7 +1633,7 @@ async def run_workflow(
run_request=workflow_run_request,
downloaded_files=None,
recording_url=None,
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}",
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}",
)

View File

@@ -75,7 +75,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
modified_at=task_v2.modified_at,
recording_url=workflow_run.recording_url if workflow_run else None,
downloaded_files=workflow_run.downloaded_files if workflow_run else None,
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
run_request=TaskRunRequest(
engine=RunEngine.skyvern_v2,
prompt=task_v2.prompt,

View File

@@ -73,9 +73,7 @@ async def get_workflow_run_response(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=organization_id,
)
app_url = (
f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}"
)
app_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}"
return WorkflowRunResponse(
run_id=workflow_run_id,
run_type=RunType.workflow_run,
@@ -88,12 +86,12 @@ async def get_workflow_run_response(
created_at=workflow_run.created_at,
modified_at=workflow_run.modified_at,
run_request=WorkflowRunRequest(
workflow_id=workflow_run.workflow_id,
workflow_id=workflow_run.workflow_permanent_id,
title=workflow_run_resp.workflow_title,
parameters=workflow_run_resp.parameters,
proxy_location=workflow_run.proxy_location,
webhook_url=workflow_run.webhook_callback_url,
totp_url=workflow_run.totp_verification_url,
webhook_url=workflow_run.webhook_callback_url or None,
totp_url=workflow_run.totp_verification_url or None,
totp_identifier=workflow_run.totp_identifier,
# TODO: add browser session id
),