shu/fix get run for workflow runs (#2362)
This commit is contained in:
@@ -14,7 +14,6 @@ setup_logger()
|
||||
|
||||
|
||||
from skyvern.forge import app # noqa: E402, F401
|
||||
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402
|
||||
from skyvern.agent import SkyvernAgent # noqa: E402
|
||||
|
||||
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"]
|
||||
__all__ = ["SkyvernAgent"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from skyvern.agent.agent import SkyvernAgent
|
||||
from skyvern.agent.client import SkyvernClient
|
||||
|
||||
__all__ = ["SkyvernAgent", "SkyvernClient"]
|
||||
__all__ = ["SkyvernAgent"]
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any, cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.agent.client import SkyvernClient
|
||||
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
||||
from skyvern.client import AsyncSkyvern
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import security, skyvern_context
|
||||
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
|
||||
from skyvern.services import run_service, task_v1_service, task_v2_service
|
||||
from skyvern.utils import migrate_db
|
||||
|
||||
@@ -31,10 +31,8 @@ class SkyvernAgent:
|
||||
cdp_url: str | None = None,
|
||||
browser_path: str | None = None,
|
||||
browser_type: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.extra_headers = extra_headers
|
||||
self.client: SkyvernClient | None = None
|
||||
self.client: AsyncSkyvern | None = None
|
||||
if base_url is None and api_key is None:
|
||||
if not os.path.exists(".env"):
|
||||
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
|
||||
@@ -69,10 +67,9 @@ class SkyvernAgent:
|
||||
|
||||
settings.BROWSER_TYPE = browser_type
|
||||
elif base_url and api_key:
|
||||
self.client = SkyvernClient(
|
||||
self.client = AsyncSkyvern(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
extra_headers=self.extra_headers,
|
||||
)
|
||||
else:
|
||||
raise ValueError("base_url and api_key must be both provided")
|
||||
@@ -254,7 +251,12 @@ class SkyvernAgent:
|
||||
organization = await self.get_organization()
|
||||
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
|
||||
|
||||
return await self.client.get_run(run_id)
|
||||
run_obj = await self.client.get_run(run_id)
|
||||
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
|
||||
return TaskRunResponse.model_validate(run_obj.dict())
|
||||
elif run_obj.run_type == RunType.workflow_run:
|
||||
return WorkflowRunResponse.model_validate(run_obj.dict())
|
||||
raise ValueError(f"Invalid run type: {run_obj.run_type}")
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
@@ -272,6 +274,7 @@ class SkyvernAgent:
|
||||
wait_for_completion: bool = True,
|
||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> TaskRunResponse:
|
||||
if not self.client:
|
||||
if engine == RunEngine.skyvern_v1:
|
||||
@@ -344,7 +347,7 @@ class SkyvernAgent:
|
||||
else:
|
||||
raise ValueError("Local mode is not supported for this method")
|
||||
|
||||
task_run = await self.client.run_task(
|
||||
task_run = await self.client.agent.run_task(
|
||||
prompt=prompt,
|
||||
engine=engine,
|
||||
url=url,
|
||||
@@ -355,16 +358,17 @@ class SkyvernAgent:
|
||||
error_code_mapping=error_code_mapping,
|
||||
proxy_location=proxy_location,
|
||||
max_steps=max_steps,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
if wait_for_completion:
|
||||
async with asyncio.timeout(timeout):
|
||||
while True:
|
||||
task_run = await self.client.get_run(task_run.run_id)
|
||||
task_run = await self.client.agent.get_run(task_run.run_id)
|
||||
if task_run.status.is_final():
|
||||
return task_run
|
||||
break
|
||||
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
|
||||
return task_run
|
||||
return TaskRunResponse.model_validate(task_run.dict())
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
@@ -380,5 +384,6 @@ class SkyvernAgent:
|
||||
wait_for_completion: bool = True,
|
||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
|
||||
|
||||
@@ -49,9 +49,8 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]:
|
||||
skyvern_agent = SkyvernAgent(
|
||||
base_url=settings.SKYVERN_BASE_URL,
|
||||
api_key=settings.SKYVERN_API_KEY,
|
||||
extra_headers={"X-User-Agent": "skyvern-mcp"},
|
||||
)
|
||||
res = await skyvern_agent.run_task(prompt=prompt, url=url)
|
||||
res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp")
|
||||
|
||||
# TODO: It would be nice if we could return the task URL here
|
||||
output = res.model_dump()["output"]
|
||||
|
||||
@@ -1549,7 +1549,7 @@ async def run_task(
|
||||
failure_reason=None,
|
||||
created_at=task_v2.created_at,
|
||||
modified_at=task_v2.modified_at,
|
||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
||||
run_request=TaskRunRequest(
|
||||
engine=RunEngine.skyvern_v2,
|
||||
prompt=task_v2.prompt,
|
||||
@@ -1633,7 +1633,7 @@ async def run_workflow(
|
||||
run_request=workflow_run_request,
|
||||
downloaded_files=None,
|
||||
recording_url=None,
|
||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}",
|
||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
||||
modified_at=task_v2.modified_at,
|
||||
recording_url=workflow_run.recording_url if workflow_run else None,
|
||||
downloaded_files=workflow_run.downloaded_files if workflow_run else None,
|
||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
||||
run_request=TaskRunRequest(
|
||||
engine=RunEngine.skyvern_v2,
|
||||
prompt=task_v2.prompt,
|
||||
|
||||
@@ -73,9 +73,7 @@ async def get_workflow_run_response(
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
app_url = (
|
||||
f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}"
|
||||
)
|
||||
app_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}"
|
||||
return WorkflowRunResponse(
|
||||
run_id=workflow_run_id,
|
||||
run_type=RunType.workflow_run,
|
||||
@@ -88,12 +86,12 @@ async def get_workflow_run_response(
|
||||
created_at=workflow_run.created_at,
|
||||
modified_at=workflow_run.modified_at,
|
||||
run_request=WorkflowRunRequest(
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_id=workflow_run.workflow_permanent_id,
|
||||
title=workflow_run_resp.workflow_title,
|
||||
parameters=workflow_run_resp.parameters,
|
||||
proxy_location=workflow_run.proxy_location,
|
||||
webhook_url=workflow_run.webhook_callback_url,
|
||||
totp_url=workflow_run.totp_verification_url,
|
||||
webhook_url=workflow_run.webhook_callback_url or None,
|
||||
totp_url=workflow_run.totp_verification_url or None,
|
||||
totp_identifier=workflow_run.totp_identifier,
|
||||
# TODO: add browser session id
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user