SDK: text prompt (#4214)
This commit is contained in:
committed by
GitHub
parent
0f495f458e
commit
b7d08fe906
@@ -289,6 +289,7 @@ if typing.TYPE_CHECKING:
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
RunSdkActionResponse,
|
||||
RunStatus,
|
||||
Script,
|
||||
@@ -780,6 +781,7 @@ _dynamic_imports: typing.Dict[str, str] = {
|
||||
"RunSdkActionRequestAction_AiUploadFile": ".types",
|
||||
"RunSdkActionRequestAction_Extract": ".types",
|
||||
"RunSdkActionRequestAction_LocateElement": ".types",
|
||||
"RunSdkActionRequestAction_Prompt": ".types",
|
||||
"RunSdkActionResponse": ".types",
|
||||
"RunStatus": ".types",
|
||||
"Script": ".types",
|
||||
@@ -1295,6 +1297,7 @@ __all__ = [
|
||||
"RunSdkActionRequestAction_AiUploadFile",
|
||||
"RunSdkActionRequestAction_Extract",
|
||||
"RunSdkActionRequestAction_LocateElement",
|
||||
"RunSdkActionRequestAction_Prompt",
|
||||
"RunSdkActionResponse",
|
||||
"RunStatus",
|
||||
"Script",
|
||||
|
||||
@@ -314,6 +314,7 @@ if typing.TYPE_CHECKING:
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
)
|
||||
from .run_sdk_action_response import RunSdkActionResponse
|
||||
from .run_status import RunStatus
|
||||
@@ -813,6 +814,7 @@ _dynamic_imports: typing.Dict[str, str] = {
|
||||
"RunSdkActionRequestAction_AiUploadFile": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_Extract": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_LocateElement": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_Prompt": ".run_sdk_action_request_action",
|
||||
"RunSdkActionResponse": ".run_sdk_action_response",
|
||||
"RunStatus": ".run_status",
|
||||
"Script": ".script",
|
||||
@@ -1317,6 +1319,7 @@ __all__ = [
|
||||
"RunSdkActionRequestAction_AiUploadFile",
|
||||
"RunSdkActionRequestAction_Extract",
|
||||
"RunSdkActionRequestAction_LocateElement",
|
||||
"RunSdkActionRequestAction_Prompt",
|
||||
"RunSdkActionResponse",
|
||||
"RunStatus",
|
||||
"Script",
|
||||
|
||||
@@ -163,6 +163,26 @@ class RunSdkActionRequestAction_LocateElement(UniversalBaseModel):
|
||||
extra = pydantic.Extra.allow
|
||||
|
||||
|
||||
class RunSdkActionRequestAction_Prompt(UniversalBaseModel):
|
||||
"""
|
||||
The action to execute with its specific parameters
|
||||
"""
|
||||
|
||||
type: typing.Literal["prompt"] = "prompt"
|
||||
prompt: str
|
||||
schema: typing.Optional[typing.Dict[str, typing.Any]] = None
|
||||
model: typing.Optional[typing.Dict[str, typing.Any]] = None
|
||||
|
||||
if IS_PYDANTIC_V2:
|
||||
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
||||
else:
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
smart_union = True
|
||||
extra = pydantic.Extra.allow
|
||||
|
||||
|
||||
RunSdkActionRequestAction = typing.Union[
|
||||
RunSdkActionRequestAction_AiAct,
|
||||
RunSdkActionRequestAction_AiClick,
|
||||
@@ -171,4 +191,5 @@ RunSdkActionRequestAction = typing.Union[
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
]
|
||||
|
||||
@@ -17,6 +17,7 @@ from skyvern.forge.sdk.api.files import validate_download_url
|
||||
from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||
from skyvern.services import script_service
|
||||
from skyvern.services.otp_service import poll_otp_value
|
||||
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||
from skyvern.webeye.actions import handler_utils
|
||||
@@ -644,6 +645,20 @@ class RealSkyvernPageAi(SkyvernPageAi):
|
||||
|
||||
return xpath
|
||||
|
||||
async def ai_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
schema: dict[str, Any] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
"""Send a prompt to the LLM and get a response based on the provided schema."""
|
||||
result = await script_service.prompt(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
model=model,
|
||||
)
|
||||
return result
|
||||
|
||||
async def ai_act(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -684,6 +684,56 @@ class SkyvernPage(Page):
|
||||
data = kwargs.pop("data", None)
|
||||
return await self._ai.ai_extract(prompt, schema, error_code_mapping, intention, data)
|
||||
|
||||
async def prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
schema: dict[str, Any] | None = None,
|
||||
model: dict[str, Any] | str | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
"""Send a prompt to the LLM and get a response based on the provided schema.
|
||||
|
||||
This method allows you to interact with the LLM directly without requiring page context.
|
||||
It's useful for making decisions, generating text, or processing information using AI.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to the LLM
|
||||
schema: Optional JSON schema to structure the response. If provided, the LLM response
|
||||
will be validated against this schema.
|
||||
model: Optional model configuration. Can be either:
|
||||
- A dict with model configuration (e.g., {"model_name": "gemini-2.5-flash-lite", "max_tokens": 2048})
|
||||
- A string with just the model name (e.g., "gemini-2.5-flash-lite")
|
||||
|
||||
Returns:
|
||||
LLM response structured according to the schema if provided, or unstructured response otherwise.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Simple unstructured prompt
|
||||
response = await page.prompt("What is 2 + 2?")
|
||||
# Returns: {'llm_response': '2 + 2 equals 4.'}
|
||||
|
||||
# Structured prompt with schema
|
||||
response = await page.prompt(
|
||||
"What is 2 + 2?",
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result_number": {"type": "int"},
|
||||
"confidence": {"type": "number", "minimum": 0, "maximum": 1}
|
||||
}
|
||||
}
|
||||
)
|
||||
# Returns: {'result_number': 4, 'confidence': 1}
|
||||
```
|
||||
"""
|
||||
normalized_model: dict[str, Any] | None = None
|
||||
if isinstance(model, str):
|
||||
normalized_model = {"model_name": model}
|
||||
elif model is not None:
|
||||
normalized_model = model
|
||||
|
||||
return await self._ai.ai_prompt(prompt=prompt, schema=schema, model=normalized_model)
|
||||
|
||||
@overload
|
||||
def locator(
|
||||
self,
|
||||
|
||||
@@ -78,3 +78,12 @@ class SkyvernPageAi(Protocol):
|
||||
) -> str | None:
|
||||
"""Locate an element on the page using AI and return its XPath selector."""
|
||||
...
|
||||
|
||||
async def ai_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
schema: dict[str, Any] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
"""Send a prompt to the LLM and get a response based on the provided schema."""
|
||||
...
|
||||
|
||||
@@ -205,6 +205,13 @@ async def run_sdk_action(
|
||||
prompt=action.prompt,
|
||||
)
|
||||
result = xpath_result
|
||||
elif action.type == "prompt":
|
||||
prompt_result = await page_ai.ai_prompt(
|
||||
prompt=action.prompt,
|
||||
schema=action.schema,
|
||||
model=action.model,
|
||||
)
|
||||
result = prompt_result
|
||||
await app.DATABASE.update_task(
|
||||
task_id=task.task_id,
|
||||
organization_id=organization_id,
|
||||
|
||||
@@ -16,6 +16,7 @@ class SdkActionType(str, Enum):
|
||||
AI_ACT = "ai_act"
|
||||
EXTRACT = "extract"
|
||||
LOCATE_ELEMENT = "locate_element"
|
||||
PROMPT = "prompt"
|
||||
|
||||
|
||||
# Base action class
|
||||
@@ -151,6 +152,21 @@ class LocateElementAction(SdkActionBase):
|
||||
return None
|
||||
|
||||
|
||||
class PromptAction(SdkActionBase):
|
||||
"""Prompt action parameters."""
|
||||
|
||||
type: Literal["prompt"] = "prompt"
|
||||
prompt: str = Field(..., description="The prompt to send to the LLM")
|
||||
schema: dict[str, Any] | None = Field(None, description="Optional JSON schema to structure the response")
|
||||
model: dict[str, Any] | None = Field(None, description="Optional model configuration")
|
||||
|
||||
def get_navigation_goal(self) -> str | None:
|
||||
return self.prompt
|
||||
|
||||
def get_navigation_payload(self) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
# Discriminated union of all action types
|
||||
SdkAction = Annotated[
|
||||
Union[
|
||||
@@ -161,6 +177,7 @@ SdkAction = Annotated[
|
||||
ActAction,
|
||||
ExtractAction,
|
||||
LocateElementAction,
|
||||
PromptAction,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
@@ -11,6 +11,7 @@ from skyvern.client import (
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
)
|
||||
from skyvern.config import settings
|
||||
from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
|
||||
@@ -225,3 +226,33 @@ class SdkSkyvernPageAi(SkyvernPageAi):
|
||||
return response.result
|
||||
|
||||
return None
|
||||
|
||||
async def ai_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
schema: dict[str, Any] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
"""Send a prompt to the LLM and get a response based on the provided schema via API call."""
|
||||
|
||||
LOG.info(
|
||||
"AI prompt",
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
workflow_run_id=self._browser.workflow_run_id,
|
||||
)
|
||||
|
||||
response = await self._browser.skyvern.run_sdk_action(
|
||||
url=self._page.url,
|
||||
action=RunSdkActionRequestAction_Prompt(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
model=model,
|
||||
),
|
||||
browser_session_id=self._browser.browser_session_id,
|
||||
browser_address=self._browser.browser_address,
|
||||
workflow_run_id=self._browser.workflow_run_id,
|
||||
)
|
||||
self._browser.workflow_run_id = response.workflow_run_id
|
||||
|
||||
return response.result if response.result is not None else None
|
||||
|
||||
Reference in New Issue
Block a user