shu/fix get run for workflow runs (#2362)
This commit is contained in:
@@ -14,7 +14,6 @@ setup_logger()
|
|||||||
|
|
||||||
|
|
||||||
from skyvern.forge import app # noqa: E402, F401
|
from skyvern.forge import app # noqa: E402, F401
|
||||||
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
|
from skyvern.agent import SkyvernAgent # noqa: E402
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402
|
|
||||||
|
|
||||||
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"]
|
__all__ = ["SkyvernAgent"]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from skyvern.agent.agent import SkyvernAgent
|
from skyvern.agent.agent import SkyvernAgent
|
||||||
from skyvern.agent.client import SkyvernClient
|
|
||||||
|
|
||||||
__all__ = ["SkyvernAgent", "SkyvernClient"]
|
__all__ = ["SkyvernAgent"]
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import Any, cast
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from skyvern.agent.client import SkyvernClient
|
|
||||||
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
||||||
|
from skyvern.client import AsyncSkyvern
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.sdk.core import security, skyvern_context
|
from skyvern.forge.sdk.core import security, skyvern_context
|
||||||
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
|
|||||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
|
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
|
||||||
from skyvern.services import run_service, task_v1_service, task_v2_service
|
from skyvern.services import run_service, task_v1_service, task_v2_service
|
||||||
from skyvern.utils import migrate_db
|
from skyvern.utils import migrate_db
|
||||||
|
|
||||||
@@ -31,10 +31,8 @@ class SkyvernAgent:
|
|||||||
cdp_url: str | None = None,
|
cdp_url: str | None = None,
|
||||||
browser_path: str | None = None,
|
browser_path: str | None = None,
|
||||||
browser_type: str | None = None,
|
browser_type: str | None = None,
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.extra_headers = extra_headers
|
self.client: AsyncSkyvern | None = None
|
||||||
self.client: SkyvernClient | None = None
|
|
||||||
if base_url is None and api_key is None:
|
if base_url is None and api_key is None:
|
||||||
if not os.path.exists(".env"):
|
if not os.path.exists(".env"):
|
||||||
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
|
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
|
||||||
@@ -69,10 +67,9 @@ class SkyvernAgent:
|
|||||||
|
|
||||||
settings.BROWSER_TYPE = browser_type
|
settings.BROWSER_TYPE = browser_type
|
||||||
elif base_url and api_key:
|
elif base_url and api_key:
|
||||||
self.client = SkyvernClient(
|
self.client = AsyncSkyvern(
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
extra_headers=self.extra_headers,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("base_url and api_key must be both provided")
|
raise ValueError("base_url and api_key must be both provided")
|
||||||
@@ -254,7 +251,12 @@ class SkyvernAgent:
|
|||||||
organization = await self.get_organization()
|
organization = await self.get_organization()
|
||||||
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
|
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
|
||||||
|
|
||||||
return await self.client.get_run(run_id)
|
run_obj = await self.client.get_run(run_id)
|
||||||
|
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
|
||||||
|
return TaskRunResponse.model_validate(run_obj.dict())
|
||||||
|
elif run_obj.run_type == RunType.workflow_run:
|
||||||
|
return WorkflowRunResponse.model_validate(run_obj.dict())
|
||||||
|
raise ValueError(f"Invalid run type: {run_obj.run_type}")
|
||||||
|
|
||||||
async def run_task(
|
async def run_task(
|
||||||
self,
|
self,
|
||||||
@@ -272,6 +274,7 @@ class SkyvernAgent:
|
|||||||
wait_for_completion: bool = True,
|
wait_for_completion: bool = True,
|
||||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||||
browser_session_id: str | None = None,
|
browser_session_id: str | None = None,
|
||||||
|
user_agent: str | None = None,
|
||||||
) -> TaskRunResponse:
|
) -> TaskRunResponse:
|
||||||
if not self.client:
|
if not self.client:
|
||||||
if engine == RunEngine.skyvern_v1:
|
if engine == RunEngine.skyvern_v1:
|
||||||
@@ -344,7 +347,7 @@ class SkyvernAgent:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Local mode is not supported for this method")
|
raise ValueError("Local mode is not supported for this method")
|
||||||
|
|
||||||
task_run = await self.client.run_task(
|
task_run = await self.client.agent.run_task(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
engine=engine,
|
engine=engine,
|
||||||
url=url,
|
url=url,
|
||||||
@@ -355,16 +358,17 @@ class SkyvernAgent:
|
|||||||
error_code_mapping=error_code_mapping,
|
error_code_mapping=error_code_mapping,
|
||||||
proxy_location=proxy_location,
|
proxy_location=proxy_location,
|
||||||
max_steps=max_steps,
|
max_steps=max_steps,
|
||||||
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
if wait_for_completion:
|
if wait_for_completion:
|
||||||
async with asyncio.timeout(timeout):
|
async with asyncio.timeout(timeout):
|
||||||
while True:
|
while True:
|
||||||
task_run = await self.client.get_run(task_run.run_id)
|
task_run = await self.client.agent.get_run(task_run.run_id)
|
||||||
if task_run.status.is_final():
|
if task_run.status.is_final():
|
||||||
return task_run
|
break
|
||||||
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
|
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
|
||||||
return task_run
|
return TaskRunResponse.model_validate(task_run.dict())
|
||||||
|
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
self,
|
self,
|
||||||
@@ -380,5 +384,6 @@ class SkyvernAgent:
|
|||||||
wait_for_completion: bool = True,
|
wait_for_completion: bool = True,
|
||||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||||
browser_session_id: str | None = None,
|
browser_session_id: str | None = None,
|
||||||
|
user_agent: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
|
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
|
||||||
|
|||||||
@@ -49,9 +49,8 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]:
|
|||||||
skyvern_agent = SkyvernAgent(
|
skyvern_agent = SkyvernAgent(
|
||||||
base_url=settings.SKYVERN_BASE_URL,
|
base_url=settings.SKYVERN_BASE_URL,
|
||||||
api_key=settings.SKYVERN_API_KEY,
|
api_key=settings.SKYVERN_API_KEY,
|
||||||
extra_headers={"X-User-Agent": "skyvern-mcp"},
|
|
||||||
)
|
)
|
||||||
res = await skyvern_agent.run_task(prompt=prompt, url=url)
|
res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp")
|
||||||
|
|
||||||
# TODO: It would be nice if we could return the task URL here
|
# TODO: It would be nice if we could return the task URL here
|
||||||
output = res.model_dump()["output"]
|
output = res.model_dump()["output"]
|
||||||
|
|||||||
@@ -1549,7 +1549,7 @@ async def run_task(
|
|||||||
failure_reason=None,
|
failure_reason=None,
|
||||||
created_at=task_v2.created_at,
|
created_at=task_v2.created_at,
|
||||||
modified_at=task_v2.modified_at,
|
modified_at=task_v2.modified_at,
|
||||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
||||||
run_request=TaskRunRequest(
|
run_request=TaskRunRequest(
|
||||||
engine=RunEngine.skyvern_v2,
|
engine=RunEngine.skyvern_v2,
|
||||||
prompt=task_v2.prompt,
|
prompt=task_v2.prompt,
|
||||||
@@ -1633,7 +1633,7 @@ async def run_workflow(
|
|||||||
run_request=workflow_run_request,
|
run_request=workflow_run_request,
|
||||||
downloaded_files=None,
|
downloaded_files=None,
|
||||||
recording_url=None,
|
recording_url=None,
|
||||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}",
|
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
|||||||
modified_at=task_v2.modified_at,
|
modified_at=task_v2.modified_at,
|
||||||
recording_url=workflow_run.recording_url if workflow_run else None,
|
recording_url=workflow_run.recording_url if workflow_run else None,
|
||||||
downloaded_files=workflow_run.downloaded_files if workflow_run else None,
|
downloaded_files=workflow_run.downloaded_files if workflow_run else None,
|
||||||
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{task_v2.workflow_permanent_id}/{task_v2.workflow_run_id}",
|
||||||
run_request=TaskRunRequest(
|
run_request=TaskRunRequest(
|
||||||
engine=RunEngine.skyvern_v2,
|
engine=RunEngine.skyvern_v2,
|
||||||
prompt=task_v2.prompt,
|
prompt=task_v2.prompt,
|
||||||
|
|||||||
@@ -73,9 +73,7 @@ async def get_workflow_run_response(
|
|||||||
workflow_run_id=workflow_run.workflow_run_id,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
app_url = (
|
app_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}"
|
||||||
f"{settings.SKYVERN_APP_URL.rstrip('/')}/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}"
|
|
||||||
)
|
|
||||||
return WorkflowRunResponse(
|
return WorkflowRunResponse(
|
||||||
run_id=workflow_run_id,
|
run_id=workflow_run_id,
|
||||||
run_type=RunType.workflow_run,
|
run_type=RunType.workflow_run,
|
||||||
@@ -88,12 +86,12 @@ async def get_workflow_run_response(
|
|||||||
created_at=workflow_run.created_at,
|
created_at=workflow_run.created_at,
|
||||||
modified_at=workflow_run.modified_at,
|
modified_at=workflow_run.modified_at,
|
||||||
run_request=WorkflowRunRequest(
|
run_request=WorkflowRunRequest(
|
||||||
workflow_id=workflow_run.workflow_id,
|
workflow_id=workflow_run.workflow_permanent_id,
|
||||||
title=workflow_run_resp.workflow_title,
|
title=workflow_run_resp.workflow_title,
|
||||||
parameters=workflow_run_resp.parameters,
|
parameters=workflow_run_resp.parameters,
|
||||||
proxy_location=workflow_run.proxy_location,
|
proxy_location=workflow_run.proxy_location,
|
||||||
webhook_url=workflow_run.webhook_callback_url,
|
webhook_url=workflow_run.webhook_callback_url or None,
|
||||||
totp_url=workflow_run.totp_verification_url,
|
totp_url=workflow_run.totp_verification_url or None,
|
||||||
totp_identifier=workflow_run.totp_identifier,
|
totp_identifier=workflow_run.totp_identifier,
|
||||||
# TODO: add browser session id
|
# TODO: add browser session id
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user