update langchain integration (#2388)
This commit is contained in:
@@ -42,7 +42,7 @@ Go to [Langchain Tools](https://python.langchain.com/v0.1/docs/modules/tools/) t
|
||||
### Run a task(sync) locally in your local environment
|
||||
> sync task won't return until the task is finished.
|
||||
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init --openai-api-key <your_openai_api_key>` command in your terminal to set up skyvern first.
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init` command in your terminal to set up skyvern first.
|
||||
|
||||
|
||||
```python
|
||||
@@ -65,7 +65,7 @@ if __name__ == "__main__":
|
||||
|
||||
:warning: :warning: if you want to run the task in the background, you need to keep the script running until the task is finished, otherwise the task will be killed when the script is finished.
|
||||
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init --openai-api-key <your_openai_api_key>` command in your terminal to set up skyvern first.
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init` command in your terminal to set up skyvern first.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
@@ -88,7 +88,7 @@ if __name__ == "__main__":
|
||||
|
||||
### Get a task locally in your local environment
|
||||
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init --openai-api-key <your_openai_api_key>` command in your terminal to set up skyvern first.
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init` command in your terminal to set up skyvern first.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
@@ -191,7 +191,7 @@ The following two examples show how to build an agent that executes a specified
|
||||
|
||||
> async task will return immediately and the task will be running in the background. You can use `GetTask` tool to poll the task information until the task is finished.
|
||||
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init --openai-api-key <your_openai_api_key>` command in your terminal to set up skyvern first.
|
||||
:warning: :warning: if you want to run this code block, you need to run `skyvern init` command in your terminal to set up skyvern first.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
1125
integrations/langchain/poetry.lock
generated
1125
integrations/langchain/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "skyvern-langchain"
|
||||
version = "0.1.6"
|
||||
version = "0.1.7"
|
||||
description = ""
|
||||
authors = ["lawyzheng <lawy@skyvern.com>"]
|
||||
packages = [{ include = "skyvern_langchain" }]
|
||||
@@ -8,7 +8,7 @@ readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11,<3.12"
|
||||
skyvern = "^0.1.56"
|
||||
skyvern = ">=0.1.84"
|
||||
langchain = "^0.3.19"
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal, Type
|
||||
from typing import Any, Type
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from litellm import BaseModel
|
||||
@@ -6,52 +6,33 @@ from pydantic import Field
|
||||
from skyvern_langchain.schema import CreateTaskInput, GetTaskInput
|
||||
from skyvern_langchain.settings import settings
|
||||
|
||||
from skyvern.agent import SkyvernAgent
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGenerationBase
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse
|
||||
from skyvern import Skyvern
|
||||
from skyvern.client.agent.types.agent_get_run_response import AgentGetRunResponse
|
||||
from skyvern.client.types.task_run_response import TaskRunResponse
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
|
||||
|
||||
class SkyvernTaskBaseTool(BaseTool):
|
||||
engine: Literal["TaskV1", "TaskV2"] = Field(default=settings.engine)
|
||||
engine: RunEngine = Field(default=settings.engine)
|
||||
run_task_timeout_seconds: int = Field(default=settings.run_task_timeout_seconds)
|
||||
agent: SkyvernAgent = SkyvernAgent()
|
||||
agent: Skyvern = Skyvern(base_url=None, api_key=None)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("skyvern task tool does not support sync")
|
||||
|
||||
# TODO: agent haven't exposed the task v1 generate function, we can migrate to use agent interface when it's available
|
||||
async def _generate_v1_task_request(self, user_prompt: str) -> TaskGenerationBase:
|
||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt)
|
||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
|
||||
return TaskGenerationBase.model_validate(llm_response)
|
||||
|
||||
|
||||
class RunTask(SkyvernTaskBaseTool):
|
||||
name: str = "run-skyvern-agent-task"
|
||||
description: str = """Use Skyvern agent to run a task. This function won't return until the task is finished."""
|
||||
args_schema: Type[BaseModel] = CreateTaskInput
|
||||
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> TaskResponse | TaskV2:
|
||||
if self.engine == "TaskV1":
|
||||
return await self._arun_task_v1(user_prompt=user_prompt, url=url)
|
||||
else:
|
||||
return await self._arun_task_v2(user_prompt=user_prompt, url=url)
|
||||
|
||||
async def _arun_task_v1(self, user_prompt: str, url: str | None = None) -> TaskResponse:
|
||||
task_generation = await self._generate_v1_task_request(user_prompt=user_prompt)
|
||||
task_request = TaskRequest.model_validate(task_generation, from_attributes=True)
|
||||
if url is not None:
|
||||
task_request.url = url
|
||||
|
||||
return await self.agent.run_task_v1(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
|
||||
|
||||
async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2:
|
||||
task_request = TaskV2Request(user_prompt=user_prompt, url=url)
|
||||
return await self.agent.run_observer_task_v_2(
|
||||
task_request=task_request, timeout_seconds=self.run_task_timeout_seconds
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> TaskRunResponse:
|
||||
return await self.agent.run_task(
|
||||
prompt=user_prompt,
|
||||
url=url,
|
||||
engine=self.engine,
|
||||
timeout=self.run_task_timeout_seconds,
|
||||
wait_for_completion=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -60,23 +41,14 @@ class DispatchTask(SkyvernTaskBaseTool):
|
||||
description: str = """Use Skyvern agent to dispatch a task. This function will return immediately and the task will be running in the background."""
|
||||
args_schema: Type[BaseModel] = CreateTaskInput
|
||||
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse | TaskV2:
|
||||
if self.engine == "TaskV1":
|
||||
return await self._arun_task_v1(user_prompt=user_prompt, url=url)
|
||||
else:
|
||||
return await self._arun_task_v2(user_prompt=user_prompt, url=url)
|
||||
|
||||
async def _arun_task_v1(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse:
|
||||
task_generation = await self._generate_v1_task_request(user_prompt=user_prompt)
|
||||
task_request = TaskRequest.model_validate(task_generation, from_attributes=True)
|
||||
if url is not None:
|
||||
task_request.url = url
|
||||
|
||||
return await self.agent.create_task_v1(task_request=task_request)
|
||||
|
||||
async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2:
|
||||
task_request = TaskV2Request(user_prompt=user_prompt, url=url)
|
||||
return await self.agent.observer_task_v_2(task_request=task_request)
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> TaskRunResponse:
|
||||
return await self.agent.run_task(
|
||||
prompt=user_prompt,
|
||||
url=url,
|
||||
engine=self.engine,
|
||||
timeout=self.run_task_timeout_seconds,
|
||||
wait_for_completion=False,
|
||||
)
|
||||
|
||||
|
||||
class GetTask(SkyvernTaskBaseTool):
|
||||
@@ -84,14 +56,5 @@ class GetTask(SkyvernTaskBaseTool):
|
||||
description: str = """Use Skyvern agent to get a task."""
|
||||
args_schema: Type[BaseModel] = GetTaskInput
|
||||
|
||||
async def _arun(self, task_id: str) -> TaskResponse | TaskV2 | None:
|
||||
if self.engine == "TaskV1":
|
||||
return await self._arun_task_v1(task_id=task_id)
|
||||
else:
|
||||
return await self._arun_task_v2(task_id=task_id)
|
||||
|
||||
async def _arun_task_v1(self, task_id: str) -> TaskResponse | None:
|
||||
return await self.agent.get_task(task_id=task_id)
|
||||
|
||||
async def _arun_task_v2(self, task_id: str) -> TaskV2 | None:
|
||||
return await self.agent.get_observer_task_v_2(task_id=task_id)
|
||||
async def _arun(self, task_id: str) -> AgentGetRunResponse | None:
|
||||
return await self.agent.get_run(run_id=task_id)
|
||||
|
||||
@@ -1,30 +1,23 @@
|
||||
from typing import Any, Dict, Literal, Type
|
||||
from typing import Any, Type
|
||||
|
||||
from httpx import AsyncClient
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from skyvern_langchain.schema import CreateTaskInput, GetTaskInput
|
||||
from skyvern_langchain.settings import settings
|
||||
|
||||
from skyvern.client import AsyncSkyvern
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskRequest, TaskResponse
|
||||
from skyvern import Skyvern
|
||||
from skyvern.client.types.task_run_response import TaskRunResponse
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
|
||||
|
||||
class SkyvernTaskBaseTool(BaseTool):
|
||||
api_key: str = Field(default=settings.api_key)
|
||||
base_url: str = Field(default=settings.base_url)
|
||||
engine: Literal["TaskV1", "TaskV2"] = Field(default=settings.engine)
|
||||
engine: RunEngine = Field(default=settings.engine)
|
||||
run_task_timeout_seconds: int = Field(default=settings.run_task_timeout_seconds)
|
||||
|
||||
def get_client(self) -> AsyncSkyvern:
|
||||
httpx_client = AsyncClient(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": self.api_key,
|
||||
},
|
||||
)
|
||||
return AsyncSkyvern(base_url=self.base_url, httpx_client=httpx_client)
|
||||
def get_client(self) -> Skyvern:
|
||||
return Skyvern(base_url=self.base_url, api_key=self.api_key)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("skyvern task tool does not support sync")
|
||||
@@ -35,41 +28,13 @@ class RunTask(SkyvernTaskBaseTool):
|
||||
description: str = """Use Skyvern client to run a task. This function won't return until the task is finished."""
|
||||
args_schema: Type[BaseModel] = CreateTaskInput
|
||||
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> TaskResponse | Dict[str, Any | None]:
|
||||
if self.engine == "TaskV1":
|
||||
return await self._arun_task_v1(user_prompt=user_prompt, url=url)
|
||||
else:
|
||||
return await self._arun_task_v2(user_prompt=user_prompt, url=url)
|
||||
|
||||
async def _arun_task_v1(self, user_prompt: str, url: str | None = None) -> TaskResponse:
|
||||
task_generation = await self.get_client().agent.generate_task(
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> TaskRunResponse:
|
||||
return await self.get_client().run_task(
|
||||
timeout=self.run_task_timeout_seconds,
|
||||
url=url,
|
||||
prompt=user_prompt,
|
||||
)
|
||||
|
||||
task_request = TaskRequest.model_validate(task_generation, from_attributes=True)
|
||||
if url is not None:
|
||||
task_request.url = url
|
||||
|
||||
return await self.get_client().agent.run_task_v1(
|
||||
timeout_seconds=self.run_task_timeout_seconds,
|
||||
url=task_request.url,
|
||||
title=task_request.title,
|
||||
navigation_goal=task_request.navigation_goal,
|
||||
data_extraction_goal=task_request.data_extraction_goal,
|
||||
navigation_payload=task_request.navigation_goal,
|
||||
error_code_mapping=task_request.error_code_mapping,
|
||||
extracted_information_schema=task_request.extracted_information_schema,
|
||||
complete_criterion=task_request.complete_criterion,
|
||||
terminate_criterion=task_request.terminate_criterion,
|
||||
)
|
||||
|
||||
async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskResponse:
|
||||
task_request = TaskV2Request(url=url, user_prompt=user_prompt)
|
||||
return await self.get_client().agent.run_observer_task_v_2(
|
||||
timeout_seconds=self.run_task_timeout_seconds,
|
||||
user_prompt=task_request.user_prompt,
|
||||
url=task_request.url,
|
||||
browser_session_id=task_request.browser_session_id,
|
||||
engine=self.engine,
|
||||
wait_for_completion=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,39 +43,13 @@ class DispatchTask(SkyvernTaskBaseTool):
|
||||
description: str = """Use Skyvern client to dispatch a task. This function will return immediately and the task will be running in the background."""
|
||||
args_schema: Type[BaseModel] = CreateTaskInput
|
||||
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse | Dict[str, Any | None]:
|
||||
if self.engine == "TaskV1":
|
||||
return await self._arun_task_v1(user_prompt=user_prompt, url=url)
|
||||
else:
|
||||
return await self._arun_task_v2(user_prompt=user_prompt, url=url)
|
||||
|
||||
async def _arun_task_v1(self, user_prompt: str, url: str | None = None) -> CreateTaskResponse:
|
||||
task_generation = await self.get_client().agent.generate_task(
|
||||
async def _arun(self, user_prompt: str, url: str | None = None) -> TaskRunResponse:
|
||||
return await self.get_client().run_task(
|
||||
timeout=self.run_task_timeout_seconds,
|
||||
url=url,
|
||||
prompt=user_prompt,
|
||||
)
|
||||
|
||||
task_request = TaskRequest.model_validate(task_generation, from_attributes=True)
|
||||
if url is not None:
|
||||
task_request.url = url
|
||||
|
||||
return await self.get_client().agent.create_task(
|
||||
url=task_request.url,
|
||||
title=task_request.title,
|
||||
navigation_goal=task_request.navigation_goal,
|
||||
data_extraction_goal=task_request.data_extraction_goal,
|
||||
navigation_payload=task_request.navigation_goal,
|
||||
error_code_mapping=task_request.error_code_mapping,
|
||||
extracted_information_schema=task_request.extracted_information_schema,
|
||||
complete_criterion=task_request.complete_criterion,
|
||||
terminate_criterion=task_request.terminate_criterion,
|
||||
)
|
||||
|
||||
async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> Dict[str, Any | None]:
|
||||
task_request = TaskV2Request(url=url, user_prompt=user_prompt)
|
||||
return await self.get_client().agent.observer_task_v_2(
|
||||
user_prompt=task_request.user_prompt,
|
||||
url=task_request.url,
|
||||
browser_session_id=task_request.browser_session_id,
|
||||
engine=self.engine,
|
||||
wait_for_completion=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -119,14 +58,5 @@ class GetTask(SkyvernTaskBaseTool):
|
||||
description: str = """Use Skyvern client to get a task."""
|
||||
args_schema: Type[BaseModel] = GetTaskInput
|
||||
|
||||
async def _arun(self, task_id: str) -> Dict[str, Any | None]:
|
||||
if self.engine == "TaskV1":
|
||||
return await self._arun_task_v1(task_id=task_id)
|
||||
else:
|
||||
return await self._arun_task_v2(task_id=task_id)
|
||||
|
||||
async def _arun_task_v1(self, task_id: str) -> TaskResponse:
|
||||
return await self.get_client().agent.get_task(task_id=task_id)
|
||||
|
||||
async def _arun_task_v2(self, task_id: str) -> Dict[str, Any | None]:
|
||||
return await self.get_client().agent.get_observer_task_v_2(task_id=task_id)
|
||||
async def _arun(self, task_id: str) -> TaskRunResponse:
|
||||
return await self.get_client().get_run(run_id=task_id)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Literal
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
api_key: str = ""
|
||||
base_url: str = "https://api.skyvern.com"
|
||||
engine: Literal["TaskV1", "TaskV2"] = "TaskV2"
|
||||
engine: RunEngine = RunEngine.skyvern_v2
|
||||
run_task_timeout_seconds: int = 60 * 60
|
||||
|
||||
class Config:
|
||||
|
||||
Reference in New Issue
Block a user