From 8a561c2fbb27d556d8d884ca2cf39c523b59e6da Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 16 Mar 2025 15:46:34 -0700 Subject: [PATCH] add SkyvernClient (#1943) --- .github/workflows/ci.yml | 10 +++ .../langchain/skyvern_langchain/agent.py | 24 +++--- .../langchain/skyvern_langchain/client.py | 6 +- .../llama_index/skyvern_llamaindex/agent.py | 30 ++++---- .../llama_index/skyvern_llamaindex/client.py | 6 +- skyvern/__init__.py | 3 + skyvern/agent/__init__.py | 5 +- skyvern/agent/{local.py => agent.py} | 2 +- skyvern/agent/client.py | 73 +++++++++++++++++++ skyvern/config.py | 6 +- skyvern/exceptions.py | 6 ++ skyvern/forge/sdk/api/llm/config_registry.py | 6 +- skyvern/forge/sdk/routes/agent_protocol.py | 4 +- 13 files changed, 137 insertions(+), 44 deletions(-) rename skyvern/agent/{local.py => agent.py} (99%) create mode 100644 skyvern/agent/client.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 80f310ca..1f8f650e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -93,6 +93,16 @@ jobs: # Finally, run pre-commit. - uses: pre-commit/action@v3.0.0 + env: + ENABLE_OPENAI: "true" + OPENAI_API_KEY: "sk-dummy" + ENABLE_AZURE_GPT4O_MINI: "true" + AZURE_GPT4O_MINI_DEPLOYMENT: "dummy" + AZURE_GPT4O_MINI_API_KEY: "dummy" + AZURE_GPT4O_MINI_API_BASE: "dummy" + AZURE_GPT4O_MINI_API_VERSION: "dummy" + AWS_REGION: "us-east-1" + ENABLE_BEDROCK: "true" fe-lint-build: runs-on: ubuntu-latest diff --git a/integrations/langchain/skyvern_langchain/agent.py b/integrations/langchain/skyvern_langchain/agent.py index f9d2cd31..0f9f2c3d 100644 --- a/integrations/langchain/skyvern_langchain/agent.py +++ b/integrations/langchain/skyvern_langchain/agent.py @@ -6,20 +6,18 @@ from pydantic import Field from skyvern_langchain.schema import CreateTaskInput, GetTaskInput from skyvern_langchain.settings import settings -from skyvern.agent import Agent +from skyvern.agent import SkyvernAgent from skyvern.forge import app from skyvern.forge.prompts import prompt_engine -from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskRequest from skyvern.forge.sdk.schemas.task_generations import TaskGenerationBase +from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse -agent = Agent() - class SkyvernTaskBaseTool(BaseTool): engine: Literal["TaskV1", "TaskV2"] = Field(default=settings.engine) run_task_timeout_seconds: int = Field(default=settings.run_task_timeout_seconds) - agent: Agent = agent + agent: SkyvernAgent = SkyvernAgent() def _run(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("skyvern task tool does not support sync") @@ -36,7 +34,7 @@ class RunTask(SkyvernTaskBaseTool): description: str = """Use Skyvern agent to run a task. This function won't return until the task is finished.""" args_schema: Type[BaseModel] = CreateTaskInput - async def _arun(self, user_prompt: str, url: str | None = None) -> TaskResponse | ObserverTask: + async def _arun(self, user_prompt: str, url: str | None = None) -> TaskResponse | TaskV2: if self.engine == "TaskV1": return await self._arun_task_v1(user_prompt=user_prompt, url=url) else: @@ -50,8 +48,8 @@ class RunTask(SkyvernTaskBaseTool): return await self.agent.run_task(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds) - async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> ObserverTask: - task_request = ObserverTaskRequest(user_prompt=user_prompt, url=url) + async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2: + task_request = TaskV2Request(user_prompt=user_prompt, url=url) return await self.agent.run_observer_task_v_2( task_request=task_request, timeout_seconds=self.run_task_timeout_seconds ) @@ -62,7 +60,7 @@ class DispatchTask(SkyvernTaskBaseTool): description: str = """Use Skyvern agent to dispatch a task. This function will return immediately and the task will be running in the background.""" args_schema: Type[BaseModel] = CreateTaskInput - async def _arun(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse | ObserverTask: + async def _arun(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse | TaskV2: if self.engine == "TaskV1": return await self._arun_task_v1(user_prompt=user_prompt, url=url) else: @@ -76,8 +74,8 @@ class DispatchTask(SkyvernTaskBaseTool): return await self.agent.create_task(task_request=task_request) - async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> ObserverTask: - task_request = ObserverTaskRequest(user_prompt=user_prompt, url=url) + async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2: + task_request = TaskV2Request(user_prompt=user_prompt, url=url) return await self.agent.observer_task_v_2(task_request=task_request) @@ -86,7 +84,7 @@ class GetTask(SkyvernTaskBaseTool): description: str = """Use Skyvern agent to get a task.""" args_schema: Type[BaseModel] = GetTaskInput - async def _arun(self, task_id: str) -> TaskResponse | ObserverTask | None: + async def _arun(self, task_id: str) -> TaskResponse | TaskV2 | None: if self.engine == "TaskV1": return await self._arun_task_v1(task_id=task_id) else: @@ -95,5 +93,5 @@ class GetTask(SkyvernTaskBaseTool): async def _arun_task_v1(self, task_id: str) -> TaskResponse | None: return await self.agent.get_task(task_id=task_id) - async def _arun_task_v2(self, task_id: str) -> ObserverTask | None: + async def _arun_task_v2(self, task_id: str) -> TaskV2 | None: return await self.agent.get_observer_task_v_2(task_id=task_id) diff --git a/integrations/langchain/skyvern_langchain/client.py b/integrations/langchain/skyvern_langchain/client.py index 7ba6c18f..86406d09 100644 --- a/integrations/langchain/skyvern_langchain/client.py +++ b/integrations/langchain/skyvern_langchain/client.py @@ -7,7 +7,7 @@ from skyvern_langchain.schema import CreateTaskInput, GetTaskInput from skyvern_langchain.settings import settings from skyvern.client import AsyncSkyvern -from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest +from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse @@ -64,7 +64,7 @@ class RunTask(SkyvernTaskBaseTool): ) async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskResponse: - task_request = ObserverTaskRequest(url=url, user_prompt=user_prompt) + task_request = TaskV2Request(url=url, user_prompt=user_prompt) return await self.get_client().agent.run_observer_task_v_2( timeout_seconds=self.run_task_timeout_seconds, user_prompt=task_request.user_prompt, @@ -106,7 +106,7 @@ class DispatchTask(SkyvernTaskBaseTool): ) async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> Dict[str, Any | None]: - task_request = ObserverTaskRequest(url=url, user_prompt=user_prompt) + task_request = TaskV2Request(url=url, user_prompt=user_prompt) return await self.get_client().agent.observer_task_v_2( user_prompt=task_request.user_prompt, url=task_request.url, diff --git a/integrations/llama_index/skyvern_llamaindex/agent.py b/integrations/llama_index/skyvern_llamaindex/agent.py index b1612479..c79ec3f4 100644 --- a/integrations/llama_index/skyvern_llamaindex/agent.py +++ b/integrations/llama_index/skyvern_llamaindex/agent.py @@ -4,20 +4,18 @@ from llama_index.core.tools import FunctionTool from llama_index.core.tools.tool_spec.base import SPEC_FUNCTION_TYPE, BaseToolSpec from skyvern_llamaindex.settings import settings -from skyvern.agent import Agent +from skyvern.agent import SkyvernAgent from skyvern.forge import app from skyvern.forge.prompts import prompt_engine -from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskRequest from skyvern.forge.sdk.schemas.task_generations import TaskGenerationBase +from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse -default_agent = Agent() - class SkyvernTool: - def __init__(self, agent: Optional[Agent] = None): + def __init__(self, agent: Optional[SkyvernAgent] = None): if agent is None: - agent = default_agent + agent = SkyvernAgent() self.agent = agent def run_task(self) -> FunctionTool: @@ -43,12 +41,12 @@ class SkyvernTaskToolSpec(BaseToolSpec): def __init__( self, *, - agent: Optional[Agent] = None, + agent: SkyvernAgent | None = None, engine: Literal["TaskV1", "TaskV2"] = settings.engine, run_task_timeout_seconds: int = settings.run_task_timeout_seconds, ) -> None: if agent is None: - agent = Agent() + agent = SkyvernAgent() self.agent = agent self.engine = engine self.run_task_timeout_seconds = run_task_timeout_seconds @@ -59,7 +57,7 @@ class SkyvernTaskToolSpec(BaseToolSpec): llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task") return TaskGenerationBase.model_validate(llm_response) - async def run_task(self, user_prompt: str, url: Optional[str] = None) -> TaskResponse | ObserverTask: + async def run_task(self, user_prompt: str, url: Optional[str] = None) -> TaskResponse | TaskV2: """ Use Skyvern agent to run a task. This function won't return until the task is finished. @@ -73,7 +71,7 @@ class SkyvernTaskToolSpec(BaseToolSpec): else: return await self.run_task_v2(user_prompt=user_prompt, url=url) - async def dispatch_task(self, user_prompt: str, url: Optional[str] = None) -> CreateTaskResponse | ObserverTask: + async def dispatch_task(self, user_prompt: str, url: Optional[str] = None) -> CreateTaskResponse | TaskV2: """ Use Skyvern agent to dispatch a task. This function will return immediately and the task will be running in the background. @@ -87,7 +85,7 @@ class SkyvernTaskToolSpec(BaseToolSpec): else: return await self.dispatch_task_v2(user_prompt=user_prompt, url=url) - async def get_task(self, task_id: str) -> TaskResponse | ObserverTask | None: + async def get_task(self, task_id: str) -> TaskResponse | TaskV2 | None: """ Use Skyvern agent to get a task. @@ -119,15 +117,15 @@ class SkyvernTaskToolSpec(BaseToolSpec): async def get_task_v1(self, task_id: str) -> TaskResponse | None: return await self.agent.get_task(task_id=task_id) - async def run_task_v2(self, user_prompt: str, url: Optional[str] = None) -> ObserverTask: - task_request = ObserverTaskRequest(user_prompt=user_prompt, url=url) + async def run_task_v2(self, user_prompt: str, url: Optional[str] = None) -> TaskV2: + task_request = TaskV2Request(user_prompt=user_prompt, url=url) return await self.agent.run_observer_task_v_2( task_request=task_request, timeout_seconds=self.run_task_timeout_seconds ) - async def dispatch_task_v2(self, user_prompt: str, url: Optional[str] = None) -> ObserverTask: - task_request = ObserverTaskRequest(user_prompt=user_prompt, url=url) + async def dispatch_task_v2(self, user_prompt: str, url: Optional[str] = None) -> TaskV2: + task_request = TaskV2Request(user_prompt=user_prompt, url=url) return await self.agent.observer_task_v_2(task_request=task_request) - async def get_task_v2(self, task_id: str) -> ObserverTask | None: + async def get_task_v2(self, task_id: str) -> TaskV2 | None: return await self.agent.get_observer_task_v_2(task_id=task_id) diff --git a/integrations/llama_index/skyvern_llamaindex/client.py b/integrations/llama_index/skyvern_llamaindex/client.py index a5e1c44d..a07ffc26 100644 --- a/integrations/llama_index/skyvern_llamaindex/client.py +++ b/integrations/llama_index/skyvern_llamaindex/client.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from skyvern_llamaindex.settings import settings from skyvern.client import AsyncSkyvern -from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest +from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse @@ -153,7 +153,7 @@ class SkyvernTaskToolSpec(BaseToolSpec): return await self.client.agent.get_task(task_id=task_id) async def run_task_v2(self, user_prompt: str, url: Optional[str] = None) -> Dict[str, Any | None]: - task_request = ObserverTaskRequest(url=url, user_prompt=user_prompt) + task_request = TaskV2Request(url=url, user_prompt=user_prompt) return await self.client.agent.run_observer_task_v_2( timeout_seconds=self.run_task_timeout_seconds, user_prompt=task_request.user_prompt, @@ -162,7 +162,7 @@ class SkyvernTaskToolSpec(BaseToolSpec): ) async def dispatch_task_v2(self, user_prompt: str, url: Optional[str] = None) -> Dict[str, Any | None]: - task_request = ObserverTaskRequest(url=url, user_prompt=user_prompt) + task_request = TaskV2Request(url=url, user_prompt=user_prompt) return await self.client.agent.observer_task_v_2( user_prompt=task_request.user_prompt, url=task_request.url, diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 502cde74..2a7dbb85 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -1,6 +1,7 @@ from ddtrace import tracer from ddtrace.filters import FilterRequestsOnUrl +from skyvern.agent import SkyvernAgent, SkyvernClient from skyvern.forge.sdk.forge_log import setup_logger tracer.configure( @@ -11,3 +12,5 @@ tracer.configure( }, ) setup_logger() + +__all__ = ["SkyvernAgent", "SkyvernClient"] diff --git a/skyvern/agent/__init__.py b/skyvern/agent/__init__.py index 307f99dd..10085206 100644 --- a/skyvern/agent/__init__.py +++ b/skyvern/agent/__init__.py @@ -1,3 +1,4 @@ -from skyvern.agent.local import Agent +from skyvern.agent.agent import SkyvernAgent +from skyvern.agent.client import SkyvernClient -__all__ = ["Agent"] +__all__ = ["SkyvernAgent", "SkyvernClient"] diff --git a/skyvern/agent/local.py b/skyvern/agent/agent.py similarity index 99% rename from skyvern/agent/local.py rename to skyvern/agent/agent.py index c9e067c6..14fa36cb 100644 --- a/skyvern/agent/local.py +++ b/skyvern/agent/agent.py @@ -15,7 +15,7 @@ from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.utils import migrate_db -class Agent: +class SkyvernAgent: def __init__(self) -> None: load_dotenv(".env") migrate_db() diff --git a/skyvern/agent/client.py b/skyvern/agent/client.py new file mode 100644 index 00000000..89799d01 --- /dev/null +++ b/skyvern/agent/client.py @@ -0,0 +1,73 @@ +from enum import StrEnum + +import httpx + +from skyvern.config import settings +from skyvern.exceptions import SkyvernClientException +from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse +from skyvern.forge.sdk.schemas.tasks import ProxyLocation +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatusResponse + + +class RunEngine(StrEnum): + skyvern_v1 = "skyvern-1.0" + skyvern_v2 = "skyvern-2.0" + + +class SkyvernClient: + def __init__( + self, + base_url: str = settings.SKYVERN_BASE_URL, + api_key: str = settings.SKYVERN_API_KEY, + ) -> None: + self.base_url = base_url + self.api_key = api_key + + async def run_task( + self, + goal: str, + engine: RunEngine = RunEngine.skyvern_v1, + url: str | None = None, + webhook_url: str | None = None, + totp_identifier: str | None = None, + totp_url: str | None = None, + title: str | None = None, + error_code_mapping: dict[str, str] | None = None, + proxy_location: ProxyLocation | None = None, + max_steps: int | None = None, + ) -> TaskRunResponse: + if engine == RunEngine.skyvern_v1: + return TaskRunResponse() + elif engine == RunEngine.skyvern_v2: + return TaskRunResponse() + raise ValueError(f"Invalid engine: {engine}") + + async def run_workflow( + self, + workflow_id: str, + webhook_url: str | None = None, + proxy_location: ProxyLocation | None = None, + ) -> TaskRunResponse: + return TaskRunResponse() + + async def get_run( + self, + run_id: str, + ) -> TaskRunResponse: + return TaskRunResponse() + + async def get_workflow_run( + self, + workflow_run_id: str, + ) -> WorkflowRunStatusResponse: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/api/v1/workflows/runs/{workflow_run_id}", + headers={"x-api-key": self.api_key}, + ) + if response.status_code != 200: + raise SkyvernClientException( + f"Failed to get workflow run: {response.text}", + status_code=response.status_code, + ) + return WorkflowRunStatusResponse.model_validate(response.json()) diff --git a/skyvern/config.py b/skyvern/config.py index d89e7f19..0d47463d 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -48,8 +48,6 @@ class Settings(BaseSettings): SIGNATURE_ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7 # one week - SKYVERN_API_KEY: str = "PLACEHOLDER" - # Artifact storage settings ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts" GENERATE_PRESIGNED_URLS: bool = False @@ -166,6 +164,10 @@ class Settings(BaseSettings): ENABLE_LOG_ARTIFACTS: bool = False ENABLE_CODE_BLOCK: bool = False + # SkyvernClient Settings + SKYVERN_BASE_URL: str = "https://api.skyvern.com" + SKYVERN_API_KEY: str = "PLACEHOLDER" + def is_cloud_environment(self) -> bool: """ :return: True if env is not local, else False diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 0cf4a90b..90967b1e 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -9,6 +9,12 @@ class SkyvernException(Exception): super().__init__(message) +class SkyvernClientException(SkyvernException): + def __init__(self, message: str | None = None, status_code: int | None = None): + self.status_code = status_code + super().__init__(message) + + class SkyvernHTTPException(SkyvernException): def __init__(self, message: str | None = None, status_code: int = status.HTTP_400_BAD_REQUEST): self.status_code = status_code diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py index 33e01e59..1444db70 100644 --- a/skyvern/forge/sdk/api/llm/config_registry.py +++ b/skyvern/forge/sdk/api/llm/config_registry.py @@ -5,7 +5,6 @@ from skyvern.forge.sdk.api.llm.exceptions import ( DuplicateLLMConfigError, InvalidLLMConfigError, MissingLLMProviderEnvVarsError, - NoProviderEnabledError, ) from skyvern.forge.sdk.api.llm.models import LiteLLMParams, LLMConfig, LLMRouterConfig @@ -55,7 +54,10 @@ if not any( settings.ENABLE_NOVITA, ] ): - raise NoProviderEnabledError() + LOG.warning( + "At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or " + "update the .env file (check out .env.example to see the required environment variables)." + ) if settings.ENABLE_OPENAI: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 683f1230..c173aad2 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -674,7 +674,7 @@ async def get_workflow_runs_by_id( "/workflows/{workflow_id}/runs/{workflow_run_id}/", include_in_schema=False, ) -async def get_workflow_run( +async def get_workflow_run_with_workflow_id( workflow_id: str, workflow_run_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), @@ -721,7 +721,7 @@ async def get_workflow_run_timeline( response_model=WorkflowRunStatusResponse, include_in_schema=False, ) -async def get_workflow_run_by_run_id( +async def get_workflow_run( workflow_run_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> WorkflowRunStatusResponse: