migrate library facing code to one interface from skyvern import Skyvern (#2368)
This commit is contained in:
@@ -14,6 +14,6 @@ setup_logger()
|
||||
|
||||
|
||||
from skyvern.forge import app # noqa: E402, F401
|
||||
from skyvern.agent import SkyvernAgent # noqa: E402
|
||||
from skyvern.library import Skyvern # noqa: E402
|
||||
|
||||
__all__ = ["SkyvernAgent"]
|
||||
__all__ = ["Skyvern"]
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from skyvern.agent.agent import SkyvernAgent
|
||||
|
||||
__all__ = ["SkyvernAgent"]
|
||||
@@ -1,88 +0,0 @@
|
||||
from skyvern.client.client import AsyncSkyvern
|
||||
from skyvern.config import settings
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
|
||||
|
||||
|
||||
class SkyvernClient:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = settings.SKYVERN_BASE_URL,
|
||||
api_key: str = settings.SKYVERN_API_KEY,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.client = AsyncSkyvern(base_url=base_url, api_key=api_key)
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.user_agent = None
|
||||
if "X-User-Agent" in self.extra_headers:
|
||||
self.user_agent = self.extra_headers["X-User-Agent"]
|
||||
elif "x-user-agent" in self.extra_headers:
|
||||
self.user_agent = self.extra_headers["x-user-agent"]
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
prompt: str,
|
||||
url: str | None = None,
|
||||
title: str | None = None,
|
||||
engine: RunEngine = RunEngine.skyvern_v2,
|
||||
webhook_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
max_steps: int | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
publish_workflow: bool = False,
|
||||
) -> TaskRunResponse:
|
||||
task_run_obj = await self.client.agent.run_task(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
title=title,
|
||||
engine=engine,
|
||||
webhook_url=webhook_url,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_url=totp_url,
|
||||
error_code_mapping=error_code_mapping,
|
||||
proxy_location=proxy_location,
|
||||
max_steps=max_steps,
|
||||
browser_session_id=browser_session_id,
|
||||
publish_workflow=publish_workflow,
|
||||
user_agent=self.user_agent,
|
||||
)
|
||||
return TaskRunResponse.model_validate(task_run_obj.dict())
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
workflow_input: dict | None = None,
|
||||
webhook_url: str | None = None,
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
template: bool = False,
|
||||
) -> WorkflowRunResponse:
|
||||
workflow_run_obj = await self.client.agent.run_workflow(
|
||||
workflow_id=workflow_id,
|
||||
data=workflow_input,
|
||||
webhook_callback_url=webhook_url,
|
||||
proxy_location=proxy_location,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_url=totp_url,
|
||||
browser_session_id=browser_session_id,
|
||||
template=template,
|
||||
user_agent=self.user_agent,
|
||||
)
|
||||
return WorkflowRunResponse.model_validate(workflow_run_obj.dict())
|
||||
|
||||
async def get_run(
|
||||
self,
|
||||
run_id: str,
|
||||
) -> RunResponse:
|
||||
run_obj = await self.client.agent.get_run(run_id=run_id)
|
||||
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
|
||||
return TaskRunResponse.model_validate(run_obj.dict())
|
||||
elif run_obj.run_type == RunType.workflow_run:
|
||||
return WorkflowRunResponse.model_validate(run_obj.dict())
|
||||
raise ValueError(f"Invalid run type: {run_obj.run_type}")
|
||||
@@ -15,10 +15,10 @@ import uvicorn
|
||||
from dotenv import load_dotenv, set_key
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from skyvern.agent import SkyvernAgent
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.library import Skyvern
|
||||
from skyvern.utils import detect_os, get_windows_appdata_roaming, migrate_db
|
||||
|
||||
load_dotenv()
|
||||
@@ -46,7 +46,7 @@ async def skyvern_run_task(prompt: str, url: str) -> dict[str, str]:
|
||||
NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job")
|
||||
url: The starting URL of the website where the task should be performed
|
||||
"""
|
||||
skyvern_agent = SkyvernAgent(
|
||||
skyvern_agent = Skyvern(
|
||||
base_url=settings.SKYVERN_BASE_URL,
|
||||
api_key=settings.SKYVERN_API_KEY,
|
||||
)
|
||||
@@ -562,7 +562,7 @@ async def _setup_local_organization() -> str:
|
||||
"""
|
||||
Returns the API key for the local organization generated
|
||||
"""
|
||||
skyvern_agent = SkyvernAgent(
|
||||
skyvern_agent = Skyvern(
|
||||
base_url=settings.SKYVERN_BASE_URL,
|
||||
api_key=settings.SKYVERN_API_KEY,
|
||||
)
|
||||
|
||||
@@ -346,7 +346,7 @@ class AgentClient:
|
||||
api_key="YOUR_API_KEY",
|
||||
authorization="YOUR_AUTHORIZATION",
|
||||
)
|
||||
client.agent.run_task(
|
||||
await client.agent.run_task(
|
||||
prompt="prompt",
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,7 @@ from skyvern.webeye.schemas import BrowserSessionResponse
|
||||
responses={
|
||||
200: {"description": "Successfully retrieved browser session details"},
|
||||
404: {"description": "Browser session not found"},
|
||||
401: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
403: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
},
|
||||
)
|
||||
async def get_browser_session(
|
||||
@@ -52,7 +52,7 @@ async def get_browser_session(
|
||||
summary="Get all active browser sessions",
|
||||
responses={
|
||||
200: {"description": "Successfully retrieved all active browser sessions"},
|
||||
401: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
403: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
},
|
||||
)
|
||||
async def get_browser_sessions(
|
||||
@@ -76,7 +76,7 @@ async def get_browser_sessions(
|
||||
summary="Create a new browser session",
|
||||
responses={
|
||||
200: {"description": "Successfully created browser session"},
|
||||
401: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
403: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
},
|
||||
)
|
||||
async def create_browser_session(
|
||||
@@ -101,7 +101,7 @@ async def create_browser_session(
|
||||
summary="Close a browser session",
|
||||
responses={
|
||||
200: {"description": "Successfully closed browser session"},
|
||||
401: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
403: {"description": "Unauthorized - Invalid or missing authentication"},
|
||||
},
|
||||
)
|
||||
async def close_browser_session(
|
||||
|
||||
3
skyvern/library/__init__.py
Normal file
3
skyvern/library/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from skyvern.library.skyvern import Skyvern
|
||||
|
||||
__all__ = ["Skyvern"]
|
||||
@@ -1,12 +1,17 @@
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any, cast
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
||||
from skyvern.client import AsyncSkyvern
|
||||
from skyvern.client.agent.types.agent_get_run_response import AgentGetRunResponse
|
||||
from skyvern.client.core.pydantic_utilities import parse_obj_as
|
||||
from skyvern.client.environment import SkyvernEnvironment
|
||||
from skyvern.client.types.task_run_response import TaskRunResponse
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import security, skyvern_context
|
||||
@@ -18,21 +23,34 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
|
||||
from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
||||
from skyvern.schemas.runs import CUA_ENGINES, ProxyLocation, RunEngine, RunType
|
||||
from skyvern.services import run_service, task_v1_service, task_v2_service
|
||||
from skyvern.utils import migrate_db
|
||||
|
||||
|
||||
class SkyvernAgent:
|
||||
class Skyvern(AsyncSkyvern):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
cdp_url: str | None = None,
|
||||
browser_path: str | None = None,
|
||||
browser_type: str | None = None,
|
||||
environment: SkyvernEnvironment = SkyvernEnvironment.PRODUCTION,
|
||||
timeout: float | None = None,
|
||||
follow_redirects: bool | None = True,
|
||||
httpx_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.client: AsyncSkyvern | None = None
|
||||
super().__init__(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
environment=environment,
|
||||
timeout=timeout,
|
||||
follow_redirects=follow_redirects,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
if base_url is None and api_key is None:
|
||||
if not os.path.exists(".env"):
|
||||
raise Exception("No .env file found. Please run 'skyvern init' first to set up your environment.")
|
||||
@@ -40,7 +58,10 @@ class SkyvernAgent:
|
||||
load_dotenv(".env")
|
||||
migrate_db()
|
||||
|
||||
self.cdp_url = cdp_url
|
||||
self._api_key = api_key
|
||||
self._cdp_url = cdp_url
|
||||
self._browser_path = browser_path
|
||||
self._browser_type = browser_type
|
||||
if browser_path:
|
||||
# TODO validate browser_path
|
||||
# Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox
|
||||
@@ -51,9 +72,9 @@ class SkyvernAgent:
|
||||
if browser_process.poll() is not None:
|
||||
raise Exception(f"Failed to open browser. browser_path: {browser_path}")
|
||||
|
||||
self.cdp_url = "http://127.0.0.1:9222"
|
||||
self._cdp_url = "http://127.0.0.1:9222"
|
||||
settings.BROWSER_TYPE = "cdp-connect"
|
||||
settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url
|
||||
settings.BROWSER_REMOTE_DEBUGGING_URL = self._cdp_url
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported browser or invalid path: {browser_path}. "
|
||||
@@ -65,14 +86,12 @@ class SkyvernAgent:
|
||||
# raise Exception("browser type is missing")
|
||||
browser_type = "chromium-headful"
|
||||
|
||||
self._browser_type = browser_type
|
||||
settings.BROWSER_TYPE = browser_type
|
||||
elif base_url and api_key:
|
||||
self.client = AsyncSkyvern(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif api_key:
|
||||
self._api_key = api_key
|
||||
else:
|
||||
raise ValueError("base_url and api_key must be both provided")
|
||||
raise ValueError("Initializing Skyvern failed: api_key must be provided")
|
||||
|
||||
async def get_organization(self) -> Organization:
|
||||
organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
|
||||
@@ -95,7 +114,13 @@ class SkyvernAgent:
|
||||
)
|
||||
return organization
|
||||
|
||||
async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None:
|
||||
async def _run_task(
|
||||
self,
|
||||
organization: Organization,
|
||||
task: Task,
|
||||
max_steps: int | None = None,
|
||||
engine: RunEngine = RunEngine.skyvern_v1,
|
||||
) -> None:
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
@@ -127,6 +152,7 @@ class SkyvernAgent:
|
||||
task=updated_task,
|
||||
step=step,
|
||||
api_key=org_auth_token.token if org_auth_token else None,
|
||||
engine=engine,
|
||||
)
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
@@ -246,17 +272,23 @@ class SkyvernAgent:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
############### officially supported interfaces ###############
|
||||
async def get_run(self, run_id: str) -> RunResponse | None:
|
||||
if not self.client:
|
||||
async def get_run(self, run_id: str) -> AgentGetRunResponse | None:
|
||||
if not self._api_key:
|
||||
organization = await self.get_organization()
|
||||
return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
|
||||
get_run_internal_resp = await run_service.get_run_response(
|
||||
run_id, organization_id=organization.organization_id
|
||||
)
|
||||
if not get_run_internal_resp:
|
||||
return None
|
||||
return typing.cast(
|
||||
AgentGetRunResponse,
|
||||
parse_obj_as(
|
||||
type_=AgentGetRunResponse, # type: ignore
|
||||
object_=get_run_internal_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
run_obj = await self.client.get_run(run_id)
|
||||
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
|
||||
return TaskRunResponse.model_validate(run_obj.dict())
|
||||
elif run_obj.run_type == RunType.workflow_run:
|
||||
return WorkflowRunResponse.model_validate(run_obj.dict())
|
||||
raise ValueError(f"Invalid run type: {run_obj.run_type}")
|
||||
return await self.agent.get_run(run_id)
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
@@ -276,8 +308,8 @@ class SkyvernAgent:
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> TaskRunResponse:
|
||||
if not self.client:
|
||||
if engine == RunEngine.skyvern_v1:
|
||||
if not self._api_key:
|
||||
if engine == RunEngine.skyvern_v1 or engine in CUA_ENGINES:
|
||||
data_extraction_goal = None
|
||||
navigation_goal = prompt
|
||||
navigation_payload = None
|
||||
@@ -316,13 +348,15 @@ class SkyvernAgent:
|
||||
url_hash=url_hash,
|
||||
)
|
||||
try:
|
||||
await self._run_task(organization, created_task)
|
||||
await self._run_task(organization, created_task, engine=engine)
|
||||
run_obj = await self.get_run(run_id=created_task.task_id)
|
||||
return cast(TaskRunResponse, run_obj)
|
||||
except Exception:
|
||||
# TODO: better error handling and logging
|
||||
run_obj = await self.get_run(run_id=created_task.task_id)
|
||||
return cast(TaskRunResponse, run_obj)
|
||||
if not run_obj:
|
||||
raise Exception("Failed to get the task run after creating the task.")
|
||||
return from_run_to_task_run_response(run_obj)
|
||||
|
||||
elif engine == RunEngine.skyvern_v2:
|
||||
# initialize task v2
|
||||
organization = await self.get_organization()
|
||||
@@ -343,11 +377,13 @@ class SkyvernAgent:
|
||||
|
||||
await self._run_task_v2(organization, task_v2)
|
||||
run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
|
||||
return cast(TaskRunResponse, run_obj)
|
||||
if not run_obj:
|
||||
raise Exception("Failed to get the task run after creating the task.")
|
||||
return from_run_to_task_run_response(run_obj)
|
||||
else:
|
||||
raise ValueError("Local mode is not supported for this method")
|
||||
|
||||
task_run = await self.client.agent.run_task(
|
||||
task_run = await self.agent.run_task(
|
||||
prompt=prompt,
|
||||
engine=engine,
|
||||
url=url,
|
||||
@@ -364,26 +400,12 @@ class SkyvernAgent:
|
||||
if wait_for_completion:
|
||||
async with asyncio.timeout(timeout):
|
||||
while True:
|
||||
task_run = await self.client.agent.get_run(task_run.run_id)
|
||||
task_run = await self.agent.get_run(task_run.run_id)
|
||||
if task_run.status.is_final():
|
||||
break
|
||||
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
|
||||
return TaskRunResponse.model_validate(task_run.dict())
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
parameters: dict[str, Any],
|
||||
webhook_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
title: str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
max_steps: int | None = None,
|
||||
wait_for_completion: bool = True,
|
||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
|
||||
|
||||
def from_run_to_task_run_response(run_obj: AgentGetRunResponse) -> TaskRunResponse:
|
||||
return TaskRunResponse.model_validate(run_obj.model_dump())
|
||||
@@ -255,9 +255,9 @@ class WorkflowRunRequest(BaseModel):
|
||||
workflow_id: str = Field(
|
||||
description="ID of the workflow to run. Workflow ID starts with `wpid_`.", examples=["wpid_123"]
|
||||
)
|
||||
title: str | None = Field(default=None, description="The title for this workflow run")
|
||||
parameters: dict[str, Any] = Field(default={}, description="Parameters to pass to the workflow")
|
||||
proxy_location: ProxyLocation = Field(
|
||||
title: str | None = Field(default=None, description="The title for this workflow run")
|
||||
proxy_location: ProxyLocation | None = Field(
|
||||
default=ProxyLocation.RESIDENTIAL, description="Location of proxy to use for this workflow run"
|
||||
)
|
||||
webhook_url: str | None = Field(
|
||||
|
||||
@@ -11,6 +11,12 @@ from skyvern.services import task_v1_service, task_v2_service, workflow_service
|
||||
|
||||
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
|
||||
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
|
||||
if not run:
|
||||
# try to see if it's a workflow run id for task v2
|
||||
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
|
||||
if task_v2:
|
||||
run = await app.DATABASE.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
|
||||
|
||||
if not run:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user