move every interface to top level and get rid of sdk client grouping (#2490)
This commit is contained in:
@@ -8,9 +8,9 @@ import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.client import AsyncSkyvern
|
||||
from skyvern.client.agent.types.agent_get_run_response import AgentGetRunResponse
|
||||
from skyvern.client.core.pydantic_utilities import parse_obj_as
|
||||
from skyvern.client.environment import SkyvernEnvironment
|
||||
from skyvern.client.types.get_run_response import GetRunResponse
|
||||
from skyvern.client.types.task_run_response import TaskRunResponse
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
@@ -30,73 +30,6 @@ from skyvern.utils import migrate_db
|
||||
|
||||
|
||||
class Skyvern(AsyncSkyvern):
|
||||
class local:
|
||||
"""Internal namespace for local mode operations."""
|
||||
|
||||
@staticmethod
|
||||
async def run_task(
|
||||
prompt: str,
|
||||
engine: RunEngine = RunEngine.skyvern_v2,
|
||||
url: str | None = None,
|
||||
webhook_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
title: str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
data_extraction_schema: dict[str, Any] | str | None = None,
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
max_steps: int | None = None,
|
||||
wait_for_completion: bool = True,
|
||||
timeout: float = DEFAULT_AGENT_TIMEOUT,
|
||||
browser_session_id: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> TaskRunResponse:
|
||||
"""
|
||||
Run a task using Skyvern in local mode.
|
||||
This is a wrapper around Skyvern.run_task that ensures it's used in local mode.
|
||||
|
||||
Args:
|
||||
prompt: The prompt describing the task to run
|
||||
engine: The engine to use for running the task
|
||||
url: Optional URL to navigate to
|
||||
webhook_url: Optional webhook URL for callbacks
|
||||
totp_identifier: Optional TOTP identifier
|
||||
totp_url: Optional TOTP verification URL
|
||||
title: Optional title for the task
|
||||
error_code_mapping: Optional mapping of error codes to messages
|
||||
data_extraction_schema: Optional schema for data extraction
|
||||
proxy_location: Optional proxy location
|
||||
max_steps: Optional maximum number of steps
|
||||
wait_for_completion: Whether to wait for task completion
|
||||
timeout: Timeout in seconds
|
||||
browser_session_id: Optional browser session ID
|
||||
user_agent: Optional user agent string
|
||||
|
||||
Returns:
|
||||
TaskRunResponse: The response from running the task
|
||||
|
||||
Raises:
|
||||
ValueError: If an API key is provided (this function is for local mode only)
|
||||
"""
|
||||
skyvern = Skyvern() # Initialize in local mode (no API key)
|
||||
return await skyvern.run_task(
|
||||
prompt=prompt,
|
||||
engine=engine,
|
||||
url=url,
|
||||
webhook_url=webhook_url,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_url=totp_url,
|
||||
title=title,
|
||||
error_code_mapping=error_code_mapping,
|
||||
data_extraction_schema=data_extraction_schema,
|
||||
proxy_location=proxy_location,
|
||||
max_steps=max_steps,
|
||||
wait_for_completion=wait_for_completion,
|
||||
timeout=timeout,
|
||||
browser_session_id=browser_session_id,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -339,7 +272,7 @@ class Skyvern(AsyncSkyvern):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
############### officially supported interfaces ###############
|
||||
async def get_run(self, run_id: str) -> AgentGetRunResponse | None:
|
||||
async def get_run(self, run_id: str) -> GetRunResponse | None:
|
||||
if not self._api_key:
|
||||
organization = await self.get_organization()
|
||||
get_run_internal_resp = await run_service.get_run_response(
|
||||
@@ -348,14 +281,14 @@ class Skyvern(AsyncSkyvern):
|
||||
if not get_run_internal_resp:
|
||||
return None
|
||||
return typing.cast(
|
||||
AgentGetRunResponse,
|
||||
GetRunResponse,
|
||||
parse_obj_as(
|
||||
type_=AgentGetRunResponse, # type: ignore
|
||||
type_=GetRunResponse, # type: ignore
|
||||
object_=get_run_internal_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
return await self.agent.get_run(run_id)
|
||||
return await super().get_run(run_id)
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
@@ -450,7 +383,7 @@ class Skyvern(AsyncSkyvern):
|
||||
else:
|
||||
raise ValueError("Local mode is not supported for this method")
|
||||
|
||||
task_run = await self.agent.run_task(
|
||||
task_run = await super().run_task(
|
||||
prompt=prompt,
|
||||
engine=engine,
|
||||
url=url,
|
||||
@@ -467,12 +400,12 @@ class Skyvern(AsyncSkyvern):
|
||||
if wait_for_completion:
|
||||
async with asyncio.timeout(timeout):
|
||||
while True:
|
||||
task_run = await self.agent.get_run(task_run.run_id)
|
||||
task_run = await super().get_run(task_run.run_id)
|
||||
if RunStatus(task_run.status).is_final():
|
||||
break
|
||||
await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
|
||||
return TaskRunResponse.model_validate(task_run.dict())
|
||||
|
||||
|
||||
def from_run_to_task_run_response(run_obj: AgentGetRunResponse) -> TaskRunResponse:
|
||||
def from_run_to_task_run_response(run_obj: GetRunResponse) -> TaskRunResponse:
|
||||
return TaskRunResponse.model_validate(run_obj.model_dump())
|
||||
|
||||
Reference in New Issue
Block a user