migrate library facing code to one interface from skyvern import Skyvern (#2368)

This commit is contained in:
Shuchang Zheng
2025-05-16 17:55:46 -07:00
committed by GitHub
parent 1da95bee93
commit 89fd604022
11 changed files with 92 additions and 152 deletions

View File

@@ -14,6 +14,6 @@ setup_logger()
from skyvern.forge import app # noqa: E402, F401
from skyvern.agent import SkyvernAgent # noqa: E402
from skyvern.library import Skyvern # noqa: E402
__all__ = ["SkyvernAgent"]
__all__ = ["Skyvern"]

View File

@@ -1,3 +0,0 @@
from skyvern.agent.agent import SkyvernAgent
__all__ = ["SkyvernAgent"]

View File

@@ -1,88 +0,0 @@
from skyvern.client.client import AsyncSkyvern
from skyvern.config import settings
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
class SkyvernClient:
def __init__(
self,
base_url: str = settings.SKYVERN_BASE_URL,
api_key: str = settings.SKYVERN_API_KEY,
extra_headers: dict[str, str] | None = None,
) -> None:
self.base_url = base_url
self.api_key = api_key
self.client = AsyncSkyvern(base_url=base_url, api_key=api_key)
self.extra_headers = extra_headers or {}
self.user_agent = None
if "X-User-Agent" in self.extra_headers:
self.user_agent = self.extra_headers["X-User-Agent"]
elif "x-user-agent" in self.extra_headers:
self.user_agent = self.extra_headers["x-user-agent"]
async def run_task(
self,
prompt: str,
url: str | None = None,
title: str | None = None,
engine: RunEngine = RunEngine.skyvern_v2,
webhook_url: str | None = None,
totp_identifier: str | None = None,
totp_url: str | None = None,
error_code_mapping: dict[str, str] | None = None,
proxy_location: ProxyLocation | None = None,
max_steps: int | None = None,
browser_session_id: str | None = None,
publish_workflow: bool = False,
) -> TaskRunResponse:
task_run_obj = await self.client.agent.run_task(
prompt=prompt,
url=url,
title=title,
engine=engine,
webhook_url=webhook_url,
totp_identifier=totp_identifier,
totp_url=totp_url,
error_code_mapping=error_code_mapping,
proxy_location=proxy_location,
max_steps=max_steps,
browser_session_id=browser_session_id,
publish_workflow=publish_workflow,
user_agent=self.user_agent,
)
return TaskRunResponse.model_validate(task_run_obj.dict())
async def run_workflow(
self,
workflow_id: str,
workflow_input: dict | None = None,
webhook_url: str | None = None,
proxy_location: ProxyLocation | None = None,
totp_identifier: str | None = None,
totp_url: str | None = None,
browser_session_id: str | None = None,
template: bool = False,
) -> WorkflowRunResponse:
workflow_run_obj = await self.client.agent.run_workflow(
workflow_id=workflow_id,
data=workflow_input,
webhook_callback_url=webhook_url,
proxy_location=proxy_location,
totp_identifier=totp_identifier,
totp_url=totp_url,
browser_session_id=browser_session_id,
template=template,
user_agent=self.user_agent,
)
return WorkflowRunResponse.model_validate(workflow_run_obj.dict())
async def get_run(
self,
run_id: str,
) -> RunResponse:
run_obj = await self.client.agent.get_run(run_id=run_id)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
return TaskRunResponse.model_validate(run_obj.dict())
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj.dict())
raise ValueError(f"Invalid run type: {run_obj.run_type}")

View File

@@ -15,10 +15,10 @@ import uvicorn
from dotenv import load_dotenv, set_key
from mcp.server.fastmcp import FastMCP
from skyvern.agent import SkyvernAgent
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.library import Skyvern
from skyvern.utils import detect_os, get_windows_appdata_roaming, migrate_db
load_dotenv()
@@ -46,7 +46,7 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]:
NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job")
url: The starting URL of the website where the task should be performed
"""
skyvern_agent = SkyvernAgent(
skyvern_agent = Skyvern(
base_url=settings.SKYVERN_BASE_URL,
api_key=settings.SKYVERN_API_KEY,
)
@@ -562,7 +562,7 @@ async def _setup_local_organization() -> str:
"""
Returns the API key for the local organization generated
"""
skyvern_agent = SkyvernAgent(
skyvern_agent = Skyvern(
base_url=settings.SKYVERN_BASE_URL,
api_key=settings.SKYVERN_API_KEY,
)

View File

@@ -346,7 +346,7 @@ class AgentClient:
api_key="YOUR_API_KEY",
authorization="YOUR_AUTHORIZATION",
)
client.agent.run_task(
await client.agent.run_task(
prompt="prompt",
)
"""

View File

@@ -23,7 +23,7 @@ from skyvern.webeye.schemas import BrowserSessionResponse
responses={
200: {"description": "Successfully retrieved browser session details"},
404: {"description": "Browser session not found"},
401: {"description": "Unauthorized - Invalid or missing authentication"},
403: {"description": "Unauthorized - Invalid or missing authentication"},
},
)
async def get_browser_session(
@@ -52,7 +52,7 @@ async def get_browser_session(
summary="Get all active browser sessions",
responses={
200: {"description": "Successfully retrieved all active browser sessions"},
401: {"description": "Unauthorized - Invalid or missing authentication"},
403: {"description": "Unauthorized - Invalid or missing authentication"},
},
)
async def get_browser_sessions(
@@ -76,7 +76,7 @@ async def get_browser_sessions(
summary="Create a new browser session",
responses={
200: {"description": "Successfully created browser session"},
401: {"description": "Unauthorized - Invalid or missing authentication"},
403: {"description": "Unauthorized - Invalid or missing authentication"},
},
)
async def create_browser_session(
@@ -101,7 +101,7 @@ async def create_browser_session(
summary="Close a browser session",
responses={
200: {"description": "Successfully closed browser session"},
401: {"description": "Unauthorized - Invalid or missing authentication"},
403: {"description": "Unauthorized - Invalid or missing authentication"},
},
)
async def close_browser_session(

View File

@@ -0,0 +1,3 @@
from skyvern.library.skyvern import Skyvern
__all__ = ["Skyvern"]

View File

@@ -1,12 +1,17 @@
import asyncio
import os
import subprocess
from typing import Any, cast
import typing
from typing import Any
import httpx
from dotenv import load_dotenv
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.client import AsyncSkyvern
from skyvern.client.agent.types.agent_get_run_response import AgentGetRunResponse
from skyvern.client.core.pydantic_utilities import parse_obj_as
from skyvern.client.environment import SkyvernEnvironment
from skyvern.client.types.task_run_response import TaskRunResponse
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context
@@ -18,21 +23,34 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.schemas.runs import CUA_ENGINES, ProxyLocation, RunEngine, RunType
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.utils import migrate_db
class SkyvernAgent:
class Skyvern(AsyncSkyvern):
def __init__(
self,
*,
base_url: str | None = None,
api_key: str | None = None,
cdp_url: str | None = None,
browser_path: str | None = None,
browser_type: str | None = None,
environment: SkyvernEnvironment = SkyvernEnvironment.PRODUCTION,
timeout: float | None = None,
follow_redirects: bool | None = True,
httpx_client: httpx.AsyncClient | None = None,
) -> None:
self.client: AsyncSkyvern | None = None
super().__init__(
base_url=base_url,
api_key=api_key,
environment=environment,
timeout=timeout,
follow_redirects=follow_redirects,
httpx_client=httpx_client,
)
if base_url is None and api_key is None:
if not os.path.exists(".env"):
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
@@ -40,7 +58,10 @@ class SkyvernAgent:
load_dotenv(".env")
migrate_db()
self.cdp_url = cdp_url
self._api_key = api_key
self._cdp_url = cdp_url
self._browser_path = browser_path
self._browser_type = browser_type
if browser_path:
# TODO validate browser_path
# Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox
@@ -51,9 +72,9 @@ class SkyvernAgent:
if browser_process.poll() is not None:
raise Exception(f"Failed to open browser. browser_path: {browser_path}")
self.cdp_url = "http://127.0.0.1:9222"
self._cdp_url = "http://127.0.0.1:9222"
settings.BROWSER_TYPE = "cdp-connect"
settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url
settings.BROWSER_REMOTE_DEBUGGING_URL = self._cdp_url
else:
raise ValueError(
f"Unsupported browser or invalid path: {browser_path}. "
@@ -65,14 +86,12 @@ class SkyvernAgent:
# raise Exception("browser type is missing")
browser_type = "chromium-headful"
self._browser_type = browser_type
settings.BROWSER_TYPE = browser_type
elif base_url and api_key:
self.client = AsyncSkyvern(
base_url=base_url,
api_key=api_key,
)
elif api_key:
self._api_key = api_key
else:
raise ValueError("base_url and api_key must be both provided")
raise ValueError("Initializing Skyvern failed: api_key must be provided")
async def get_organization(self) -> Organization:
organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
@@ -95,7 +114,13 @@ class SkyvernAgent:
)
return organization
async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None:
async def _run_task(
self,
organization: Organization,
task: Task,
max_steps: int | None = None,
engine: RunEngine = RunEngine.skyvern_v1,
) -> None:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
@@ -127,6 +152,7 @@ class SkyvernAgent:
task=updated_task,
step=step,
api_key=org_auth_token.token if org_auth_token else None,
engine=engine,
)
finally:
skyvern_context.reset()
@@ -246,17 +272,23 @@ class SkyvernAgent:
await asyncio.sleep(1)
############### officially supported interfaces ###############
async def get_run(self, run_id: str) -> RunResponse | None:
if not self.client:
async def get_run(self, run_id: str) -> AgentGetRunResponse | None:
if not self._api_key:
organization = await self.get_organization()
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
get_run_internal_resp = await run_service.get_run_response(
run_id, organization_id=organization.organization_id
)
if not get_run_internal_resp:
return None
return typing.cast(
AgentGetRunResponse,
parse_obj_as(
type_=AgentGetRunResponse, # type: ignore
object_=get_run_internal_resp.model_dump(),
),
)
run_obj = await self.client.get_run(run_id)
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
return TaskRunResponse.model_validate(run_obj.dict())
elif run_obj.run_type == RunType.workflow_run:
return WorkflowRunResponse.model_validate(run_obj.dict())
raise ValueError(f"Invalid run type: {run_obj.run_type}")
return await self.agent.get_run(run_id)
async def run_task(
self,
@@ -276,8 +308,8 @@ class SkyvernAgent:
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> TaskRunResponse:
if not self.client:
if engine == RunEngine.skyvern_v1:
if not self._api_key:
if engine == RunEngine.skyvern_v1 or engine in CUA_ENGINES:
data_extraction_goal = None
navigation_goal = prompt
navigation_payload = None
@@ -316,13 +348,15 @@ class SkyvernAgent:
url_hash=url_hash,
)
try:
await self._run_task(organization, created_task)
await self._run_task(organization, created_task, engine=engine)
run_obj = await self.get_run(run_id=created_task.task_id)
return cast(TaskRunResponse, run_obj)
except Exception:
# TODO: better error handling and logging
run_obj = await self.get_run(run_id=created_task.task_id)
return cast(TaskRunResponse, run_obj)
if not run_obj:
raise Exception("Failed to get the task run after creating the task.")
return from_run_to_task_run_response(run_obj)
elif engine == RunEngine.skyvern_v2:
# initialize task v2
organization = await self.get_organization()
@@ -343,11 +377,13 @@ class SkyvernAgent:
await self._run_task_v2(organization, task_v2)
run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
return cast(TaskRunResponse, run_obj)
if not run_obj:
raise Exception("Failed to get the task run after creating the task.")
return from_run_to_task_run_response(run_obj)
else:
raise ValueError("Local mode is not supported for this method")
task_run = await self.client.agent.run_task(
task_run = await self.agent.run_task(
prompt=prompt,
engine=engine,
url=url,
@@ -364,26 +400,12 @@ class SkyvernAgent:
if wait_for_completion:
async with asyncio.timeout(timeout):
while True:
task_run = await self.client.agent.get_run(task_run.run_id)
task_run = await self.agent.get_run(task_run.run_id)
if task_run.status.is_final():
break
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
return TaskRunResponse.model_validate(task_run.dict())
async def run_workflow(
self,
workflow_id: str,
parameters: dict[str, Any],
webhook_url: str | None = None,
totp_identifier: str | None = None,
totp_url: str | None = None,
title: str | None = None,
error_code_mapping: dict[str, str] | None = None,
proxy_location: ProxyLocation | None = None,
max_steps: int | None = None,
wait_for_completion: bool = True,
timeout: float = DEFAULT_AGENT_TIMEOUT,
browser_session_id: str | None = None,
user_agent: str | None = None,
) -> None:
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
def from_run_to_task_run_response(run_obj: AgentGetRunResponse) -> TaskRunResponse:
return TaskRunResponse.model_validate(run_obj.model_dump())

View File

@@ -255,9 +255,9 @@ class WorkflowRunRequest(BaseModel):
workflow_id: str = Field(
description="ID of the workflow to run. Workflow ID starts with `wpid_`.", examples=["wpid_123"]
)
title: str | None = Field(default=None, description="The title for this workflow run")
parameters: dict[str, Any] = Field(default={}, description="Parameters to pass to the workflow")
proxy_location: ProxyLocation = Field(
title: str | None = Field(default=None, description="The title for this workflow run")
proxy_location: ProxyLocation | None = Field(
default=ProxyLocation.RESIDENTIAL, description="Location of proxy to use for this workflow run"
)
webhook_url: str | None = Field(

View File

@@ -11,6 +11,12 @@ from skyvern.services import task_v1_service, task_v2_service, workflow_service
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
if not run:
# try to see if it's a workflow run id for task v2
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
if task_v2:
run = await app.DATABASE.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
if not run:
return None