Files
Dorod-Sky/integrations/langchain/skyvern_langchain/client.py
2025-03-03 23:05:35 +08:00

133 lines
5.8 KiB
Python

from typing import Any, Dict, Literal, Type
from httpx import AsyncClient
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from skyvern_langchain.schema import CreateTaskInput, GetTaskInput
from skyvern_langchain.settings import settings
from skyvern.client import AsyncSkyvern
from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse
class SkyvernTaskBaseTool(BaseTool):
    """Common configuration and client construction for the Skyvern LangChain tools.

    Every field defaults to the matching attribute of ``skyvern_langchain.settings.settings``.
    The tools are async-only: the synchronous ``_run`` entry point always raises.
    """

    # API key sent to Skyvern in the ``x-api-key`` request header.
    api_key: str = Field(default=settings.api_key)
    # Base URL handed to the ``AsyncSkyvern`` client.
    base_url: str = Field(default=settings.base_url)
    # Which task API subclasses use: "TaskV1" or "TaskV2".
    engine: Literal["TaskV1", "TaskV2"] = Field(default=settings.engine)
    # Passed as ``timeout_seconds`` by tools that block until a task completes.
    run_task_timeout_seconds: int = Field(default=settings.run_task_timeout_seconds)

    def get_client(self) -> AsyncSkyvern:
        """Build a new ``AsyncSkyvern`` client bound to ``base_url`` and ``api_key``.

        A fresh ``httpx.AsyncClient`` carrying the JSON content type and the API-key
        header is created on every call.

        NOTE(review): nothing here closes the underlying ``AsyncClient``; confirm whether
        ``AsyncSkyvern`` or the caller is expected to release it.
        """
        default_headers = {
            "Content-Type": "application/json",
            "x-api-key": self.api_key,
        }
        return AsyncSkyvern(
            base_url=self.base_url,
            httpx_client=AsyncClient(headers=default_headers),
        )

    def _run(self, *args: Any, **kwargs: Any) -> None:
        """Reject synchronous use; only the async ``_arun`` path is implemented."""
        raise NotImplementedError("skyvern task tool does not support sync")
class RunTask(SkyvernTaskBaseTool):
    """LangChain tool that runs a Skyvern task and waits for it to finish.

    With ``engine == "TaskV1"``, the prompt is first turned into a task definition by
    ``agent.generate_task``, then executed with ``agent.run_task``. Any other engine
    value sends the prompt straight to ``agent.run_observer_task_v_2``.
    """

    name: str = "run-skyvern-client-task"
    description: str = """Use Skyvern client to run a task. This function won't return until the task is finished."""
    args_schema: Type[BaseModel] = CreateTaskInput

    async def _arun(self, user_prompt: str, url: str | None = None) -> TaskResponse | Dict[str, Any | None]:
        """Run the task with the engine selected by ``self.engine``.

        Args:
            user_prompt: Natural-language description of the task.
            url: Optional starting URL. In V1 it overrides the generated task's URL.

        Returns:
            The V1 task response, or the V2 response mapping.
        """
        if self.engine == "TaskV1":
            return await self._arun_task_v1(user_prompt=user_prompt, url=url)
        return await self._arun_task_v2(user_prompt=user_prompt, url=url)

    async def _arun_task_v1(self, user_prompt: str, url: str | None = None) -> TaskResponse:
        """Generate a V1 task from the prompt, then run it and wait for the result."""
        # Use one client for both requests instead of building a new httpx client per request.
        client = self.get_client()

        task_generation = await client.agent.generate_task(prompt=user_prompt)
        task_request = TaskRequest.model_validate(task_generation, from_attributes=True)
        if url is not None:
            task_request.url = url

        return await client.agent.run_task(
            timeout_seconds=self.run_task_timeout_seconds,
            url=task_request.url,
            title=task_request.title,
            navigation_goal=task_request.navigation_goal,
            data_extraction_goal=task_request.data_extraction_goal,
            # The payload field must carry the payload, not the navigation goal string.
            navigation_payload=task_request.navigation_payload,
            error_code_mapping=task_request.error_code_mapping,
            extracted_information_schema=task_request.extracted_information_schema,
            complete_criterion=task_request.complete_criterion,
            terminate_criterion=task_request.terminate_criterion,
        )

    async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> Dict[str, Any | None]:
        """Run a V2 (observer) task from the prompt and wait for the result."""
        task_request = ObserverTaskRequest(url=url, user_prompt=user_prompt)
        return await self.get_client().agent.run_observer_task_v_2(
            timeout_seconds=self.run_task_timeout_seconds,
            user_prompt=task_request.user_prompt,
            url=task_request.url,
            browser_session_id=task_request.browser_session_id,
        )
class DispatchTask(SkyvernTaskBaseTool):
    """LangChain tool that starts a Skyvern task and returns without waiting for it.

    With ``engine == "TaskV1"``, the prompt is turned into a task definition by
    ``agent.generate_task`` and submitted with ``agent.create_task``. Any other engine
    value submits the prompt via ``agent.observer_task_v_2``.
    """

    name: str = "dispatch-skyvern-client-task"
    description: str = """Use Skyvern client to dispatch a task. This function will return immediately and the task will be running in the background."""
    args_schema: Type[BaseModel] = CreateTaskInput

    async def _arun(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse | Dict[str, Any | None]:
        """Submit the task with the engine selected by ``self.engine``.

        Args:
            user_prompt: Natural-language description of the task.
            url: Optional starting URL. In V1 it overrides the generated task's URL.

        Returns:
            The V1 creation response, or the V2 response mapping.
        """
        if self.engine == "TaskV1":
            return await self._arun_task_v1(user_prompt=user_prompt, url=url)
        return await self._arun_task_v2(user_prompt=user_prompt, url=url)

    async def _arun_task_v1(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse:
        """Generate a V1 task from the prompt and submit it without waiting."""
        # Use one client for both requests instead of building a new httpx client per request.
        client = self.get_client()

        task_generation = await client.agent.generate_task(prompt=user_prompt)
        task_request = TaskRequest.model_validate(task_generation, from_attributes=True)
        if url is not None:
            task_request.url = url

        return await client.agent.create_task(
            url=task_request.url,
            title=task_request.title,
            navigation_goal=task_request.navigation_goal,
            data_extraction_goal=task_request.data_extraction_goal,
            # The payload field must carry the payload, not the navigation goal string.
            navigation_payload=task_request.navigation_payload,
            error_code_mapping=task_request.error_code_mapping,
            extracted_information_schema=task_request.extracted_information_schema,
            complete_criterion=task_request.complete_criterion,
            terminate_criterion=task_request.terminate_criterion,
        )

    async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> Dict[str, Any | None]:
        """Submit a V2 (observer) task from the prompt without waiting."""
        task_request = ObserverTaskRequest(url=url, user_prompt=user_prompt)
        return await self.get_client().agent.observer_task_v_2(
            user_prompt=task_request.user_prompt,
            url=task_request.url,
            browser_session_id=task_request.browser_session_id,
        )
class GetTask(SkyvernTaskBaseTool):
    """LangChain tool that fetches a Skyvern task by id.

    Uses ``agent.get_task`` when ``engine == "TaskV1"`` and
    ``agent.get_observer_task_v_2`` for any other engine value.
    """

    name: str = "get-skyvern-client-task"
    description: str = """Use Skyvern client to get a task."""
    args_schema: Type[BaseModel] = GetTaskInput

    async def _arun(self, task_id: str) -> TaskResponse | Dict[str, Any | None]:
        """Fetch the task with the engine selected by ``self.engine``.

        Args:
            task_id: Identifier of the task to look up.

        Returns:
            The V1 task response, or the V2 response mapping.
        """
        if self.engine == "TaskV1":
            return await self._arun_task_v1(task_id=task_id)
        return await self._arun_task_v2(task_id=task_id)

    async def _arun_task_v1(self, task_id: str) -> TaskResponse:
        """Fetch a V1 task by id."""
        return await self.get_client().agent.get_task(task_id=task_id)

    async def _arun_task_v2(self, task_id: str) -> Dict[str, Any | None]:
        """Fetch a V2 (observer) task by id."""
        return await self.get_client().agent.get_observer_task_v_2(task_id=task_id)