From 89fd604022da4e5fc906e6d6959575bc7a05e304 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Fri, 16 May 2025 17:55:46 -0700 Subject: [PATCH] migrate library facing code to one interface `from skyvern import Skyvern` (#2368) --- skyvern/__init__.py | 4 +- skyvern/agent/__init__.py | 3 - skyvern/agent/client.py | 88 ------------- skyvern/cli/commands.py | 6 +- skyvern/client/agent/client.py | 2 +- skyvern/forge/sdk/routes/browser_sessions.py | 8 +- skyvern/library/__init__.py | 3 + skyvern/{agent => library}/constants.py | 0 .../{agent/agent.py => library/skyvern.py} | 120 +++++++++++------- skyvern/schemas/runs.py | 4 +- skyvern/services/run_service.py | 6 + 11 files changed, 92 insertions(+), 152 deletions(-) delete mode 100644 skyvern/agent/__init__.py delete mode 100644 skyvern/agent/client.py create mode 100644 skyvern/library/__init__.py rename skyvern/{agent => library}/constants.py (100%) rename skyvern/{agent/agent.py => library/skyvern.py} (82%) diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 81048db1..88f066c7 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -14,6 +14,6 @@ setup_logger() from skyvern.forge import app # noqa: E402, F401 -from skyvern.agent import SkyvernAgent # noqa: E402 +from skyvern.library import Skyvern # noqa: E402 -__all__ = ["SkyvernAgent"] +__all__ = ["Skyvern"] diff --git a/skyvern/agent/__init__.py b/skyvern/agent/__init__.py deleted file mode 100644 index 4e071704..00000000 --- a/skyvern/agent/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from skyvern.agent.agent import SkyvernAgent - -__all__ = ["SkyvernAgent"] diff --git a/skyvern/agent/client.py b/skyvern/agent/client.py deleted file mode 100644 index 39c9b9c2..00000000 --- a/skyvern/agent/client.py +++ /dev/null @@ -1,88 +0,0 @@ -from skyvern.client.client import AsyncSkyvern -from skyvern.config import settings -from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse - - -class SkyvernClient: - def __init__( - self, - base_url: str = settings.SKYVERN_BASE_URL, - api_key: str = settings.SKYVERN_API_KEY, - extra_headers: dict[str, str] | None = None, - ) -> None: - self.base_url = base_url - self.api_key = api_key - self.client = AsyncSkyvern(base_url=base_url, api_key=api_key) - self.extra_headers = extra_headers or {} - self.user_agent = None - if "X-User-Agent" in self.extra_headers: - self.user_agent = self.extra_headers["X-User-Agent"] - elif "x-user-agent" in self.extra_headers: - self.user_agent = self.extra_headers["x-user-agent"] - - async def run_task( - self, - prompt: str, - url: str | None = None, - title: str | None = None, - engine: RunEngine = RunEngine.skyvern_v2, - webhook_url: str | None = None, - totp_identifier: str | None = None, - totp_url: str | None = None, - error_code_mapping: dict[str, str] | None = None, - proxy_location: ProxyLocation | None = None, - max_steps: int | None = None, - browser_session_id: str | None = None, - publish_workflow: bool = False, - ) -> TaskRunResponse: - task_run_obj = await self.client.agent.run_task( - prompt=prompt, - url=url, - title=title, - engine=engine, - webhook_url=webhook_url, - totp_identifier=totp_identifier, - totp_url=totp_url, - error_code_mapping=error_code_mapping, - proxy_location=proxy_location, - max_steps=max_steps, - browser_session_id=browser_session_id, - publish_workflow=publish_workflow, - user_agent=self.user_agent, - ) - return TaskRunResponse.model_validate(task_run_obj.dict()) - - async def run_workflow( - self, - workflow_id: str, - workflow_input: dict | None = None, - webhook_url: str | None = None, - proxy_location: ProxyLocation | None = None, - totp_identifier: str | None = None, - totp_url: str | None = None, - browser_session_id: str | None = None, - template: bool = False, - ) -> WorkflowRunResponse: - workflow_run_obj = await self.client.agent.run_workflow( - workflow_id=workflow_id, - data=workflow_input, - webhook_callback_url=webhook_url, - proxy_location=proxy_location, - totp_identifier=totp_identifier, - totp_url=totp_url, - browser_session_id=browser_session_id, - template=template, - user_agent=self.user_agent, - ) - return WorkflowRunResponse.model_validate(workflow_run_obj.dict()) - - async def get_run( - self, - run_id: str, - ) -> RunResponse: - run_obj = await self.client.agent.get_run(run_id=run_id) - if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]: - return TaskRunResponse.model_validate(run_obj.dict()) - elif run_obj.run_type == RunType.workflow_run: - return WorkflowRunResponse.model_validate(run_obj.dict()) - raise ValueError(f"Invalid run type: {run_obj.run_type}") diff --git a/skyvern/cli/commands.py b/skyvern/cli/commands.py index 0d585f6b..1f7d1b0f 100644 --- a/skyvern/cli/commands.py +++ b/skyvern/cli/commands.py @@ -15,10 +15,10 @@ import uvicorn from dotenv import load_dotenv, set_key from mcp.server.fastmcp import FastMCP -from skyvern.agent import SkyvernAgent from skyvern.config import settings from skyvern.forge import app from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType +from skyvern.library import Skyvern from skyvern.utils import detect_os, get_windows_appdata_roaming, migrate_db load_dotenv() @@ -46,7 +46,7 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]: NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job") url: The starting URL of the website where the task should be performed """ - skyvern_agent = SkyvernAgent( + skyvern_agent = Skyvern( base_url=settings.SKYVERN_BASE_URL, api_key=settings.SKYVERN_API_KEY, ) @@ -562,7 +562,7 @@ async def _setup_local_organization() -> str: """ Returns the API key for the local organization generated """ - skyvern_agent = SkyvernAgent( + skyvern_agent = Skyvern( base_url=settings.SKYVERN_BASE_URL, api_key=settings.SKYVERN_API_KEY, ) diff --git a/skyvern/client/agent/client.py b/skyvern/client/agent/client.py index 626ab33b..45c7a316 100644 --- a/skyvern/client/agent/client.py +++ b/skyvern/client/agent/client.py @@ -346,7 +346,7 @@ class AgentClient: api_key="YOUR_API_KEY", authorization="YOUR_AUTHORIZATION", ) - client.agent.run_task( + await client.agent.run_task( prompt="prompt", ) """ diff --git a/skyvern/forge/sdk/routes/browser_sessions.py b/skyvern/forge/sdk/routes/browser_sessions.py index f5243896..bc8d8660 100644 --- a/skyvern/forge/sdk/routes/browser_sessions.py +++ b/skyvern/forge/sdk/routes/browser_sessions.py @@ -23,7 +23,7 @@ from skyvern.webeye.schemas import BrowserSessionResponse responses={ 200: {"description": "Successfully retrieved browser session details"}, 404: {"description": "Browser session not found"}, - 401: {"description": "Unauthorized - Invalid or missing authentication"}, + 403: {"description": "Unauthorized - Invalid or missing authentication"}, }, ) async def get_browser_session( @@ -52,7 +52,7 @@ async def get_browser_session( summary="Get all active browser sessions", responses={ 200: {"description": "Successfully retrieved all active browser sessions"}, - 401: {"description": "Unauthorized - Invalid or missing authentication"}, + 403: {"description": "Unauthorized - Invalid or missing authentication"}, }, ) async def get_browser_sessions( @@ -76,7 +76,7 @@ async def get_browser_sessions( summary="Create a new browser session", responses={ 200: {"description": "Successfully created browser session"}, - 401: {"description": "Unauthorized - Invalid or missing authentication"}, + 403: {"description": "Unauthorized - Invalid or missing authentication"}, }, ) async def create_browser_session( @@ -101,7 +101,7 @@ async def create_browser_session( summary="Close a browser session", responses={ 200: {"description": "Successfully closed browser session"}, - 401: {"description": "Unauthorized - Invalid or missing authentication"}, + 403: {"description": "Unauthorized - Invalid or missing authentication"}, }, ) async def close_browser_session( diff --git a/skyvern/library/__init__.py b/skyvern/library/__init__.py new file mode 100644 index 00000000..6f618c7c --- /dev/null +++ b/skyvern/library/__init__.py @@ -0,0 +1,3 @@ +from skyvern.library.skyvern import Skyvern + +__all__ = ["Skyvern"] diff --git a/skyvern/agent/constants.py b/skyvern/library/constants.py similarity index 100% rename from skyvern/agent/constants.py rename to skyvern/library/constants.py diff --git a/skyvern/agent/agent.py b/skyvern/library/skyvern.py similarity index 82% rename from skyvern/agent/agent.py rename to skyvern/library/skyvern.py index 834b34dd..11cc5fe5 100644 --- a/skyvern/agent/agent.py +++ b/skyvern/library/skyvern.py @@ -1,12 +1,17 @@ import asyncio import os import subprocess -from typing import Any, cast +import typing +from typing import Any +import httpx from dotenv import load_dotenv -from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT from skyvern.client import AsyncSkyvern +from skyvern.client.agent.types.agent_get_run_response import AgentGetRunResponse +from skyvern.client.core.pydantic_utilities import parse_obj_as +from skyvern.client.environment import SkyvernEnvironment +from skyvern.client.types.task_run_response import TaskRunResponse from skyvern.config import settings from skyvern.forge import app from skyvern.forge.sdk.core import security, skyvern_context @@ -18,21 +23,34 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus -from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse +from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT +from skyvern.schemas.runs import CUA_ENGINES, ProxyLocation, RunEngine, RunType from skyvern.services import run_service, task_v1_service, task_v2_service from skyvern.utils import migrate_db -class SkyvernAgent: +class Skyvern(AsyncSkyvern): def __init__( self, + *, base_url: str | None = None, api_key: str | None = None, cdp_url: str | None = None, browser_path: str | None = None, browser_type: str | None = None, + environment: SkyvernEnvironment = SkyvernEnvironment.PRODUCTION, + timeout: float | None = None, + follow_redirects: bool | None = True, + httpx_client: httpx.AsyncClient | None = None, ) -> None: - self.client: AsyncSkyvern | None = None + super().__init__( + base_url=base_url, + api_key=api_key, + environment=environment, + timeout=timeout, + follow_redirects=follow_redirects, + httpx_client=httpx_client, + ) if base_url is None and api_key is None: if not os.path.exists(".env"): raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.") @@ -40,7 +58,10 @@ class SkyvernAgent: load_dotenv(".env") migrate_db() - self.cdp_url = cdp_url + self._api_key = api_key + self._cdp_url = cdp_url + self._browser_path = browser_path + self._browser_type = browser_type if browser_path: # TODO validate browser_path # Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox @@ -51,9 +72,9 @@ class SkyvernAgent: if browser_process.poll() is not None: raise Exception(f"Failed to open browser. browser_path: {browser_path}") - self.cdp_url = "http://127.0.0.1:9222" + self._cdp_url = "http://127.0.0.1:9222" settings.BROWSER_TYPE = "cdp-connect" - settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url + settings.BROWSER_REMOTE_DEBUGGING_URL = self._cdp_url else: raise ValueError( f"Unsupported browser or invalid path: {browser_path}. " @@ -65,14 +86,12 @@ class SkyvernAgent: # raise Exception("browser type is missing") browser_type = "chromium-headful" + self._browser_type = browser_type settings.BROWSER_TYPE = browser_type - elif base_url and api_key: - self.client = AsyncSkyvern( - base_url=base_url, - api_key=api_key, - ) + elif api_key: + self._api_key = api_key else: - raise ValueError("base_url and api_key must be both provided") + raise ValueError("Initializing Skyvern failed: api_key must be provided") async def get_organization(self) -> Organization: organization = await app.DATABASE.get_organization_by_domain("skyvern.local") @@ -95,7 +114,13 @@ class SkyvernAgent: ) return organization - async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None: + async def _run_task( + self, + organization: Organization, + task: Task, + max_steps: int | None = None, + engine: RunEngine = RunEngine.skyvern_v1, + ) -> None: org_auth_token = await app.DATABASE.get_valid_org_auth_token( organization_id=organization.organization_id, token_type=OrganizationAuthTokenType.api, @@ -127,6 +152,7 @@ class SkyvernAgent: task=updated_task, step=step, api_key=org_auth_token.token if org_auth_token else None, + engine=engine, ) finally: skyvern_context.reset() @@ -246,17 +272,23 @@ class SkyvernAgent: await asyncio.sleep(1) ############### officially supported interfaces ############### - async def get_run(self, run_id: str) -> RunResponse | None: - if not self.client: + async def get_run(self, run_id: str) -> AgentGetRunResponse | None: + if not self._api_key: organization = await self.get_organization() - return await run_service.get_run_response(run_id, organization_id=organization.organization_id) + get_run_internal_resp = await run_service.get_run_response( + run_id, organization_id=organization.organization_id + ) + if not get_run_internal_resp: + return None + return typing.cast( + AgentGetRunResponse, + parse_obj_as( + type_=AgentGetRunResponse, # type: ignore + object_=get_run_internal_resp.model_dump(), + ), + ) - run_obj = await self.client.get_run(run_id) - if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]: - return TaskRunResponse.model_validate(run_obj.dict()) - elif run_obj.run_type == RunType.workflow_run: - return WorkflowRunResponse.model_validate(run_obj.dict()) - raise ValueError(f"Invalid run type: {run_obj.run_type}") + return await self.agent.get_run(run_id) async def run_task( self, @@ -276,8 +308,8 @@ class SkyvernAgent: browser_session_id: str | None = None, user_agent: str | None = None, ) -> TaskRunResponse: - if not self.client: - if engine == RunEngine.skyvern_v1: + if not self._api_key: + if engine == RunEngine.skyvern_v1 or engine in CUA_ENGINES: data_extraction_goal = None navigation_goal = prompt navigation_payload = None @@ -316,13 +348,15 @@ class SkyvernAgent: url_hash=url_hash, ) try: - await self._run_task(organization, created_task) + await self._run_task(organization, created_task, engine=engine) run_obj = await self.get_run(run_id=created_task.task_id) - return cast(TaskRunResponse, run_obj) except Exception: # TODO: better error handling and logging run_obj = await self.get_run(run_id=created_task.task_id) - return cast(TaskRunResponse, run_obj) + if not run_obj: + raise Exception("Failed to get the task run after creating the task.") + return from_run_to_task_run_response(run_obj) + elif engine == RunEngine.skyvern_v2: # initialize task v2 organization = await self.get_organization() @@ -343,11 +377,13 @@ class SkyvernAgent: await self._run_task_v2(organization, task_v2) run_obj = await self.get_run(run_id=task_v2.observer_cruise_id) - return cast(TaskRunResponse, run_obj) + if not run_obj: + raise Exception("Failed to get the task run after creating the task.") + return from_run_to_task_run_response(run_obj) else: raise ValueError("Local mode is not supported for this method") - task_run = await self.client.agent.run_task( + task_run = await self.agent.run_task( prompt=prompt, engine=engine, url=url, @@ -364,26 +400,12 @@ class SkyvernAgent: if wait_for_completion: async with asyncio.timeout(timeout): while True: - task_run = await self.client.agent.get_run(task_run.run_id) + task_run = await self.agent.get_run(task_run.run_id) if task_run.status.is_final(): break await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL) return TaskRunResponse.model_validate(task_run.dict()) - async def run_workflow( - self, - workflow_id: str, - parameters: dict[str, Any], - webhook_url: str | None = None, - totp_identifier: str | None = None, - totp_url: str | None = None, - title: str | None = None, - error_code_mapping: dict[str, str] | None = None, - proxy_location: ProxyLocation | None = None, - max_steps: int | None = None, - wait_for_completion: bool = True, - timeout: float = DEFAULT_AGENT_TIMEOUT, - browser_session_id: str | None = None, - user_agent: str | None = None, - ) -> None: - raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.") + +def from_run_to_task_run_response(run_obj: AgentGetRunResponse) -> TaskRunResponse: + return TaskRunResponse.model_validate(run_obj.model_dump()) diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index 3475eb96..785efc6c 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -255,9 +255,9 @@ class WorkflowRunRequest(BaseModel): workflow_id: str = Field( description="ID of the workflow to run. Workflow ID starts with `wpid_`.", examples=["wpid_123"] ) - title: str | None = Field(default=None, description="The title for this workflow run") parameters: dict[str, Any] = Field(default={}, description="Parameters to pass to the workflow") - proxy_location: ProxyLocation = Field( + title: str | None = Field(default=None, description="The title for this workflow run") + proxy_location: ProxyLocation | None = Field( default=ProxyLocation.RESIDENTIAL, description="Location of proxy to use for this workflow run" ) webhook_url: str | None = Field( diff --git a/skyvern/services/run_service.py b/skyvern/services/run_service.py index af1692bb..8a1053fb 100644 --- a/skyvern/services/run_service.py +++ b/skyvern/services/run_service.py @@ -11,6 +11,12 @@ from skyvern.services import task_v1_service, task_v2_service, workflow_service async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None: run = await app.DATABASE.get_run(run_id, organization_id=organization_id) + if not run: + # try to see if it's a workflow run id for task v2 + task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id) + if task_v2: + run = await app.DATABASE.get_run(task_v2.observer_cruise_id, organization_id=organization_id) + if not run: return None