diff --git a/integrations/llama_index/pyproject.toml b/integrations/llama_index/pyproject.toml index f5398f9d..5aca9fc3 100644 --- a/integrations/llama_index/pyproject.toml +++ b/integrations/llama_index/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "skyvern-llamaindex" -version = "0.2.0" +version = "0.2.1" description = "Skyvern integration for LlamaIndex" authors = ["lawyzheng "] packages = [{ include = "skyvern_llamaindex" }] diff --git a/integrations/llama_index/skyvern_llamaindex/agent.py b/integrations/llama_index/skyvern_llamaindex/agent.py index fadbdabd..7f1491d9 100644 --- a/integrations/llama_index/skyvern_llamaindex/agent.py +++ b/integrations/llama_index/skyvern_llamaindex/agent.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List from llama_index.core.tools import FunctionTool from llama_index.core.tools.tool_spec.base import SPEC_FUNCTION_TYPE, BaseToolSpec @@ -11,7 +11,7 @@ from skyvern.schemas.runs import RunEngine class SkyvernTool: - def __init__(self, agent: Optional[Skyvern] = None): + def __init__(self, agent: Skyvern | None = None): if agent is None: agent = Skyvern(base_url=None, api_key=None) self.agent = agent @@ -49,7 +49,13 @@ class SkyvernTaskToolSpec(BaseToolSpec): self.engine = engine self.run_task_timeout_seconds = run_task_timeout_seconds - async def run_task(self, user_prompt: str, url: Optional[str] = None) -> TaskRunResponse: + async def run_task( + self, + user_prompt: str | None = None, + url: str | None = None, + *_: Any, + **kw: Any, + ) -> TaskRunResponse: """ Use Skyvern agent to run a task. This function won't return until the task is finished. @@ -57,6 +63,17 @@ class SkyvernTaskToolSpec(BaseToolSpec): user_prompt[str]: The user's prompt describing the task. url (Optional[str]): The URL of the target website for the task. """ + if user_prompt is None and kw.get("args"): + user_prompt = kw["args"][0] + + if url is None: + if kw.get("args") and len(kw["args"]) > 1: + url = kw["args"][1] + elif kw.get("kwargs"): + url = kw["kwargs"].get("url") + + assert user_prompt is not None, "user_prompt is required" + return await self.agent.run_task( prompt=user_prompt, url=url, @@ -65,7 +82,13 @@ class SkyvernTaskToolSpec(BaseToolSpec): wait_for_completion=True, ) - async def dispatch_task(self, user_prompt: str, url: Optional[str] = None) -> TaskRunResponse: + async def dispatch_task( + self, + user_prompt: str | None = None, + url: str | None = None, + *_: Any, + **kw: Any, + ) -> TaskRunResponse: """ Use Skyvern agent to dispatch a task. This function will return immediately and the task will be running in the background. @@ -73,6 +96,17 @@ class SkyvernTaskToolSpec(BaseToolSpec): user_prompt[str]: The user's prompt describing the task. url (Optional[str]): The URL of the target website for the task. """ + if user_prompt is None and kw.get("args"): + user_prompt = kw["args"][0] + + if url is None: + if kw.get("args") and len(kw["args"]) > 1: + url = kw["args"][1] + elif kw.get("kwargs"): + url = kw["kwargs"].get("url") + + assert user_prompt is not None, "user_prompt is required" + return await self.agent.run_task( prompt=user_prompt, url=url, @@ -81,11 +115,15 @@ class SkyvernTaskToolSpec(BaseToolSpec): wait_for_completion=False, ) - async def get_task(self, task_id: str) -> GetRunResponse | None: + async def get_task(self, task_id: str | None = None, *_: Any, **kwargs: Any) -> GetRunResponse | None: """ Use Skyvern agent to get a task. Args: task_id[str]: The id of the task. """ + if task_id is None and "args" in kwargs: + task_id = kwargs["args"][0] + + assert task_id is not None, "task_id is required" return await self.agent.get_run(run_id=task_id) diff --git a/integrations/llama_index/skyvern_llamaindex/client.py b/integrations/llama_index/skyvern_llamaindex/client.py index f05e338c..a7f74867 100644 --- a/integrations/llama_index/skyvern_llamaindex/client.py +++ b/integrations/llama_index/skyvern_llamaindex/client.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List from llama_index.core.tools import FunctionTool from llama_index.core.tools.tool_spec.base import SPEC_FUNCTION_TYPE, BaseToolSpec @@ -59,7 +59,13 @@ class SkyvernTaskToolSpec(BaseToolSpec): self.run_task_timeout_seconds = run_task_timeout_seconds self.client = Skyvern(base_url=base_url, api_key=api_key) - async def run_task(self, user_prompt: str, url: Optional[str] = None) -> TaskRunResponse: + async def run_task( + self, + user_prompt: str | None = None, + url: str | None = None, + *_: Any, + **kw: Any, + ) -> TaskRunResponse: """ Use Skyvern client to run a task. This function won't return until the task is finished. @@ -67,6 +73,16 @@ class SkyvernTaskToolSpec(BaseToolSpec): user_prompt[str]: The user's prompt describing the task. url (Optional[str]): The URL of the target website for the task. """ + if user_prompt is None and kw.get("args"): + user_prompt = kw["args"][0] + + if url is None: + if kw.get("args") and len(kw["args"]) > 1: + url = kw["args"][1] + elif kw.get("kwargs"): + url = kw["kwargs"].get("url") + + assert user_prompt is not None, "user_prompt is required" return await self.client.run_task( prompt=user_prompt, @@ -76,7 +92,13 @@ class SkyvernTaskToolSpec(BaseToolSpec): wait_for_completion=True, ) - async def dispatch_task(self, user_prompt: str, url: Optional[str] = None) -> TaskRunResponse: + async def dispatch_task( + self, + user_prompt: str | None = None, + url: str | None = None, + *_: Any, + **kw: Any, + ) -> TaskRunResponse: """ Use Skyvern client to dispatch a task. This function will return immediately and the task will be running in the background. @@ -84,7 +106,16 @@ class SkyvernTaskToolSpec(BaseToolSpec): user_prompt[str]: The user's prompt describing the task. url (Optional[str]): The URL of the target website for the task. """ + if user_prompt is None and kw.get("args"): + user_prompt = kw["args"][0] + if url is None: + if kw.get("args") and len(kw["args"]) > 1: + url = kw["args"][1] + elif kw.get("kwargs"): + url = kw["kwargs"].get("url") + + assert user_prompt is not None, "user_prompt is required" return await self.client.run_task( prompt=user_prompt, url=url, @@ -93,12 +124,15 @@ class SkyvernTaskToolSpec(BaseToolSpec): wait_for_completion=False, ) - async def get_task(self, task_id: str) -> GetRunResponse | None: + async def get_task(self, task_id: str | None = None, *_: Any, **kwargs: Any) -> GetRunResponse | None: """ Use Skyvern client to get a task. Args: task_id[str]: The id of the task. """ + if task_id is None and "args" in kwargs: + task_id = kwargs["args"][0] + assert task_id is not None, "task_id is required" return await self.client.get_run(run_id=task_id)