migrate library facing code to one interface from skyvern import Skyvern (#2368)

This commit is contained in:
Shuchang Zheng
2025-05-16 17:55:46 -07:00
committed by GitHub
parent 1da95bee93
commit 89fd604022
11 changed files with 92 additions and 152 deletions

View File

@@ -14,6 +14,6 @@ setup_logger()
from skyvern.forge import app # noqa: E402, F401 from skyvern.forge import app # noqa: E402, F401
from skyvern.agent import SkyvernAgent # noqa: E402 from skyvern.library import Skyvern # noqa: E402
__all__ = ["SkyvernAgent"] __all__ = ["Skyvern"]

View File

@@ -1,3 +0,0 @@
from skyvern.agent.agent import SkyvernAgent
__all__ = ["SkyvernAgent"]

View File

@@ -1,88 +0,0 @@
from skyvern.client.client import AsyncSkyvern
from skyvern.config import settings
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
class SkyvernClient:
def __init__(
self,
base_url: str = settings.SKYVERN_BASE_URL,
api_key: str = settings.SKYVERN_API_KEY,
extra_headers: dict[str, str] | None = None,
) -> None:
self.base_url = base_url
self.api_key = api_key
self.client = AsyncSkyvern(base_url=base_url, api_key=api_key)
self.extra_headers = extra_headers or {}
self.user_agent = None
if "X-User-Agent" in self.extra_headers:
self.user_agent = self.extra_headers["X-User-Agent"]
elif "x-user-agent" in self.extra_headers:
self.user_agent = self.extra_headers["x-user-agent"]
async def run_task(
self,
prompt: str,
url: str | None = None,
title: str | None = None,
engine: RunEngine = RunEngine.skyvern_v2,
webhook_url: str | None = None,
totp_identifier: str | None = None,
totp_url: str | None = None,
error_code_mapping: dict[str, str] | None = None,
proxy_location: ProxyLocation | None = None,
max_steps: int | None = None,
browser_session_id: str | None = None,
publish_workflow: bool = False,
) -> TaskRunResponse:
task_run_obj = await self.client.agent.run_task(
prompt=prompt,
url=url,
title=title,
engine=engine,
webhook_url=webhook_url,
totp_identifier=totp_identifier,
totp_url=totp_url,
error_code_mapping=error_code_mapping,
proxy_location=proxy_location,
max_steps=max_steps,
browser_session_id=browser_session_id,
publish_workflow=publish_workflow,
user_agent=self.user_agent,
)
return TaskRunResponse.model_validate(task_run_obj.dict())
async def run_workflow(
self,
workflow_id: str,
workflow_input: dict | None = None,
webhook_url: str | None = None,
proxy_location: ProxyLocation | None = None,
totp_identifier: str | None = None,
totp_url: str | None = None,
browser_session_id: str | None = None,
template: bool = False,
) -> WorkflowRunResponse:
workflow_run_obj = await self.client.agent.run_workflow(
workflow_id=workflow_id,
data=workflow_input,
webhook_callback_url=webhook_url,
proxy_location=proxy_location,
totp_identifier=totp_identifier,
totp_url=totp_url,
browser_session_id=browser_session_id,
template=template,
user_agent=self.user_agent,
)
return WorkflowRunResponse.model_validate(workflow_run_obj.dict())
async def get_run(
self,
run_id: str,
) -> RunResponse:
run_obj = await self.client.agent.get_run(run_id=run_id)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
return TaskRunResponse.model_validate(run_obj.dict())
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj.dict())
raise ValueError(f"Invalid run type: {run_obj.run_type}")

View File

@@ -15,10 +15,10 @@ import uvicorn
from dotenv import load_dotenv, set_key from dotenv import load_dotenv, set_key
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from skyvern.agent import SkyvernAgent
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.library import Skyvern
from skyvern.utils import detect_os, get_windows_appdata_roaming, migrate_db from skyvern.utils import detect_os, get_windows_appdata_roaming, migrate_db
load_dotenv() load_dotenv()
@@ -46,7 +46,7 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]:
NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job") NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job")
url: The starting URL of the website where the task should be performed url: The starting URL of the website where the task should be performed
""" """
skyvern_agent = SkyvernAgent( skyvern_agent = Skyvern(
base_url=settings.SKYVERN_BASE_URL, base_url=settings.SKYVERN_BASE_URL,
api_key=settings.SKYVERN_API_KEY, api_key=settings.SKYVERN_API_KEY,
) )
@@ -562,7 +562,7 @@ async def _setup_local_organization() -> str:
""" """
Returns the API key for the local organization generated Returns the API key for the local organization generated
""" """
skyvern_agent = SkyvernAgent( skyvern_agent = Skyvern(
base_url=settings.SKYVERN_BASE_URL, base_url=settings.SKYVERN_BASE_URL,
api_key=settings.SKYVERN_API_KEY, api_key=settings.SKYVERN_API_KEY,
) )

View File

@@ -346,7 +346,7 @@ class AgentClient:
api_key="YOUR_API_KEY", api_key="YOUR_API_KEY",
authorization="YOUR_AUTHORIZATION", authorization="YOUR_AUTHORIZATION",
) )
client.agent.run_task( await client.agent.run_task(
prompt="prompt", prompt="prompt",
) )
""" """

View File

@@ -23,7 +23,7 @@ from skyvern.webeye.schemas import BrowserSessionResponse
responses={ responses={
200: {"description": "Successfully retrieved browser session details"}, 200: {"description": "Successfully retrieved browser session details"},
404: {"description": "Browser session not found"}, 404: {"description": "Browser session not found"},
401: {"description": "Unauthorized - Invalid or missing authentication"}, 403: {"description": "Unauthorized - Invalid or missing authentication"},
}, },
) )
async def get_browser_session( async def get_browser_session(
@@ -52,7 +52,7 @@ async def get_browser_session(
summary="Get all active browser sessions", summary="Get all active browser sessions",
responses={ responses={
200: {"description": "Successfully retrieved all active browser sessions"}, 200: {"description": "Successfully retrieved all active browser sessions"},
401: {"description": "Unauthorized - Invalid or missing authentication"}, 403: {"description": "Unauthorized - Invalid or missing authentication"},
}, },
) )
async def get_browser_sessions( async def get_browser_sessions(
@@ -76,7 +76,7 @@ async def get_browser_sessions(
summary="Create a new browser session", summary="Create a new browser session",
responses={ responses={
200: {"description": "Successfully created browser session"}, 200: {"description": "Successfully created browser session"},
401: {"description": "Unauthorized - Invalid or missing authentication"}, 403: {"description": "Unauthorized - Invalid or missing authentication"},
}, },
) )
async def create_browser_session( async def create_browser_session(
@@ -101,7 +101,7 @@ async def create_browser_session(
summary="Close a browser session", summary="Close a browser session",
responses={ responses={
200: {"description": "Successfully closed browser session"}, 200: {"description": "Successfully closed browser session"},
401: {"description": "Unauthorized - Invalid or missing authentication"}, 403: {"description": "Unauthorized - Invalid or missing authentication"},
}, },
) )
async def close_browser_session( async def close_browser_session(

View File

@@ -0,0 +1,3 @@
from skyvern.library.skyvern import Skyvern
__all__ = ["Skyvern"]

View File

@@ -1,12 +1,17 @@
import asyncio import asyncio
import os import os
import subprocess import subprocess
from typing import Any, cast import typing
from typing import Any
import httpx
from dotenv import load_dotenv from dotenv import load_dotenv
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.client import AsyncSkyvern from skyvern.client import AsyncSkyvern
from skyvern.client.agent.types.agent_get_run_response import AgentGetRunResponse
from skyvern.client.core.pydantic_utilities import parse_obj_as
from skyvern.client.environment import SkyvernEnvironment
from skyvern.client.types.task_run_response import TaskRunResponse
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context from skyvern.forge.sdk.core import security, skyvern_context
@@ -18,21 +23,34 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.schemas.runs import CUA_ENGINES, ProxyLocation, RunEngine, RunType
from skyvern.services import run_service, task_v1_service, task_v2_service from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.utils import migrate_db from skyvern.utils import migrate_db
class SkyvernAgent: class Skyvern(AsyncSkyvern):
def __init__( def __init__(
self, self,
*,
base_url: str | None = None, base_url: str | None = None,
api_key: str | None = None, api_key: str | None = None,
cdp_url: str | None = None, cdp_url: str | None = None,
browser_path: str | None = None, browser_path: str | None = None,
browser_type: str | None = None, browser_type: str | None = None,
environment: SkyvernEnvironment = SkyvernEnvironment.PRODUCTION,
timeout: float | None = None,
follow_redirects: bool | None = True,
httpx_client: httpx.AsyncClient | None = None,
) -> None: ) -> None:
self.client: AsyncSkyvern | None = None super().__init__(
base_url=base_url,
api_key=api_key,
environment=environment,
timeout=timeout,
follow_redirects=follow_redirects,
httpx_client=httpx_client,
)
if base_url is None and api_key is None: if base_url is None and api_key is None:
if not os.path.exists(".env"): if not os.path.exists(".env"):
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.") raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
@@ -40,7 +58,10 @@ class SkyvernAgent:
load_dotenv(".env") load_dotenv(".env")
migrate_db() migrate_db()
self.cdp_url = cdp_url self._api_key = api_key
self._cdp_url = cdp_url
self._browser_path = browser_path
self._browser_type = browser_type
if browser_path: if browser_path:
# TODO validate browser_path # TODO validate browser_path
# Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox # Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox
@@ -51,9 +72,9 @@ class SkyvernAgent:
if browser_process.poll() is not None: if browser_process.poll() is not None:
raise Exception(f"Failed to open browser. browser_path: {browser_path}") raise Exception(f"Failed to open browser. browser_path: {browser_path}")
self.cdp_url = "http://127.0.0.1:9222" self._cdp_url = "http://127.0.0.1:9222"
settings.BROWSER_TYPE = "cdp-connect" settings.BROWSER_TYPE = "cdp-connect"
settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url settings.BROWSER_REMOTE_DEBUGGING_URL = self._cdp_url
else: else:
raise ValueError( raise ValueError(
f"Unsupported browser or invalid path: {browser_path}. " f"Unsupported browser or invalid path: {browser_path}. "
@@ -65,14 +86,12 @@ class SkyvernAgent:
# raise Exception("browser type is missing") # raise Exception("browser type is missing")
browser_type = "chromium-headful" browser_type = "chromium-headful"
self._browser_type = browser_type
settings.BROWSER_TYPE = browser_type settings.BROWSER_TYPE = browser_type
elif base_url and api_key: elif api_key:
self.client = AsyncSkyvern( self._api_key = api_key
base_url=base_url,
api_key=api_key,
)
else: else:
raise ValueError("base_url and api_key must be both provided") raise ValueError("Initializing Skyvern failed: api_key must be provided")
async def get_organization(self) -> Organization: async def get_organization(self) -> Organization:
organization = await app.DATABASE.get_organization_by_domain("skyvern.local") organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
@@ -95,7 +114,13 @@ class SkyvernAgent:
) )
return organization return organization
async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None: async def _run_task(
self,
organization: Organization,
task: Task,
max_steps: int | None = None,
engine: RunEngine = RunEngine.skyvern_v1,
) -> None:
org_auth_token = await app.DATABASE.get_valid_org_auth_token( org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=organization.organization_id, organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api, token_type=OrganizationAuthTokenType.api,
@@ -127,6 +152,7 @@ class SkyvernAgent:
task=updated_task, task=updated_task,
step=step, step=step,
api_key=org_auth_token.token if org_auth_token else None, api_key=org_auth_token.token if org_auth_token else None,
engine=engine,
) )
finally: finally:
skyvern_context.reset() skyvern_context.reset()
@@ -246,17 +272,23 @@ class SkyvernAgent:
await asyncio.sleep(1) await asyncio.sleep(1)
############### officially supported interfaces ############### ############### officially supported interfaces ###############
async def get_run(self, run_id: str) -> RunResponse | None: async def get_run(self, run_id: str) -> AgentGetRunResponse | None:
if not self.client: if not self._api_key:
organization = await self.get_organization() organization = await self.get_organization()
return await run_service.get_run_response(run_id, organization_id=organization.organization_id) get_run_internal_resp = await run_service.get_run_response(
run_id, organization_id=organization.organization_id
)
if not get_run_internal_resp:
return None
return typing.cast(
AgentGetRunResponse,
parse_obj_as(
type_=AgentGetRunResponse, # type: ignore
object_=get_run_internal_resp.model_dump(),
),
)
run_obj = await self.client.get_run(run_id) return await self.agent.get_run(run_id)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
return TaskRunResponse.model_validate(run_obj.dict())
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj.dict())
raise ValueError(f"Invalid run type: {run_obj.run_type}")
async def run_task( async def run_task(
self, self,
@@ -276,8 +308,8 @@ class SkyvernAgent:
browser_session_id: str | None = None, browser_session_id: str | None = None,
user_agent: str | None = None, user_agent: str | None = None,
) -> TaskRunResponse: ) -> TaskRunResponse:
if not self.client: if not self._api_key:
if engine == RunEngine.skyvern_v1: if engine == RunEngine.skyvern_v1 or engine in CUA_ENGINES:
data_extraction_goal = None data_extraction_goal = None
navigation_goal = prompt navigation_goal = prompt
navigation_payload = None navigation_payload = None
@@ -316,13 +348,15 @@ class SkyvernAgent:
url_hash=url_hash, url_hash=url_hash,
) )
try: try:
await self._run_task(organization, created_task) await self._run_task(organization, created_task, engine=engine)
run_obj = await self.get_run(run_id=created_task.task_id) run_obj = await self.get_run(run_id=created_task.task_id)
return cast(TaskRunResponse, run_obj)
except Exception: except Exception:
# TODO: better error handling and logging # TODO: better error handling and logging
run_obj = await self.get_run(run_id=created_task.task_id) run_obj = await self.get_run(run_id=created_task.task_id)
return cast(TaskRunResponse, run_obj) if not run_obj:
raise Exception("Failed to get the task run after creating the task.")
return from_run_to_task_run_response(run_obj)
elif engine == RunEngine.skyvern_v2: elif engine == RunEngine.skyvern_v2:
# initialize task v2 # initialize task v2
organization = await self.get_organization() organization = await self.get_organization()
@@ -343,11 +377,13 @@ class SkyvernAgent:
await self._run_task_v2(organization, task_v2) await self._run_task_v2(organization, task_v2)
run_obj = await self.get_run(run_id=task_v2.observer_cruise_id) run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
return cast(TaskRunResponse, run_obj) if not run_obj:
raise Exception("Failed to get the task run after creating the task.")
return from_run_to_task_run_response(run_obj)
else: else:
raise ValueError("Local mode is not supported for this method") raise ValueError("Local mode is not supported for this method")
task_run = await self.client.agent.run_task( task_run = await self.agent.run_task(
prompt=prompt, prompt=prompt,
engine=engine, engine=engine,
url=url, url=url,
@@ -364,26 +400,12 @@ class SkyvernAgent:
if wait_for_completion: if wait_for_completion:
async with asyncio.timeout(timeout): async with asyncio.timeout(timeout):
while True: while True:
task_run = await self.client.agent.get_run(task_run.run_id) task_run = await self.agent.get_run(task_run.run_id)
if task_run.status.is_final(): if task_run.status.is_final():
break break
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL) await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
return TaskRunResponse.model_validate(task_run.dict()) return TaskRunResponse.model_validate(task_run.dict())
async def run_workflow(
self, def from_run_to_task_run_response(run_obj: AgentGetRunResponse) -> TaskRunResponse:
workflow_id: str, return TaskRunResponse.model_validate(run_obj.model_dump())
parameters: dict[str, Any],
webhook_url: str | None = None,
totp_identifier: str | None = None,
totp_url: str | None = None,
title: str | None = None,
error_code_mapping: dict[str, str] | None = None,
proxy_location: ProxyLocation | None = None,
max_steps: int | None = None,
wait_for_completion: bool = True,
timeout: float = DEFAULT_AGENT_TIMEOUT,
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> None:
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")

View File

@@ -255,9 +255,9 @@ class WorkflowRunRequest(BaseModel):
workflow_id: str = Field( workflow_id: str = Field(
description="ID of the workflow to run. Workflow ID starts with `wpid_`.", examples=["wpid_123"] description="ID of the workflow to run. Workflow ID starts with `wpid_`.", examples=["wpid_123"]
) )
title: str | None = Field(default=None, description="The title for this workflow run")
parameters: dict[str, Any] = Field(default={}, description="Parameters to pass to the workflow") parameters: dict[str, Any] = Field(default={}, description="Parameters to pass to the workflow")
proxy_location: ProxyLocation = Field( title: str | None = Field(default=None, description="The title for this workflow run")
proxy_location: ProxyLocation | None = Field(
default=ProxyLocation.RESIDENTIAL, description="Location of proxy to use for this workflow run" default=ProxyLocation.RESIDENTIAL, description="Location of proxy to use for this workflow run"
) )
webhook_url: str | None = Field( webhook_url: str | None = Field(

View File

@@ -11,6 +11,12 @@ from skyvern.services import task_v1_service, task_v2_service, workflow_service
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None: async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id) run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
if not run:
# try to see if it's a workflow run id for task v2
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
if task_v2:
run = await app.DATABASE.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
if not run: if not run:
return None return None