SDK: validation action (#4203)
This commit is contained in:
committed by
GitHub
parent
7ef48c32e0
commit
4b99cd3f45
@@ -289,6 +289,7 @@ if typing.TYPE_CHECKING:
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Validate,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
RunSdkActionResponse,
|
||||
RunStatus,
|
||||
@@ -781,6 +782,7 @@ _dynamic_imports: typing.Dict[str, str] = {
|
||||
"RunSdkActionRequestAction_AiUploadFile": ".types",
|
||||
"RunSdkActionRequestAction_Extract": ".types",
|
||||
"RunSdkActionRequestAction_LocateElement": ".types",
|
||||
"RunSdkActionRequestAction_Validate": ".types",
|
||||
"RunSdkActionRequestAction_Prompt": ".types",
|
||||
"RunSdkActionResponse": ".types",
|
||||
"RunStatus": ".types",
|
||||
@@ -1297,6 +1299,7 @@ __all__ = [
|
||||
"RunSdkActionRequestAction_AiUploadFile",
|
||||
"RunSdkActionRequestAction_Extract",
|
||||
"RunSdkActionRequestAction_LocateElement",
|
||||
"RunSdkActionRequestAction_Validate",
|
||||
"RunSdkActionRequestAction_Prompt",
|
||||
"RunSdkActionResponse",
|
||||
"RunStatus",
|
||||
|
||||
@@ -314,6 +314,7 @@ if typing.TYPE_CHECKING:
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Validate,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
)
|
||||
from .run_sdk_action_response import RunSdkActionResponse
|
||||
@@ -814,6 +815,7 @@ _dynamic_imports: typing.Dict[str, str] = {
|
||||
"RunSdkActionRequestAction_AiUploadFile": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_Extract": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_LocateElement": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_Validate": ".run_sdk_action_request_action",
|
||||
"RunSdkActionRequestAction_Prompt": ".run_sdk_action_request_action",
|
||||
"RunSdkActionResponse": ".run_sdk_action_response",
|
||||
"RunStatus": ".run_status",
|
||||
@@ -1319,6 +1321,7 @@ __all__ = [
|
||||
"RunSdkActionRequestAction_AiUploadFile",
|
||||
"RunSdkActionRequestAction_Extract",
|
||||
"RunSdkActionRequestAction_LocateElement",
|
||||
"RunSdkActionRequestAction_Validate",
|
||||
"RunSdkActionRequestAction_Prompt",
|
||||
"RunSdkActionResponse",
|
||||
"RunStatus",
|
||||
|
||||
@@ -163,6 +163,25 @@ class RunSdkActionRequestAction_LocateElement(UniversalBaseModel):
|
||||
extra = pydantic.Extra.allow
|
||||
|
||||
|
||||
class RunSdkActionRequestAction_Validate(UniversalBaseModel):
    """
    The action to execute with its specific parameters
    """

    # Discriminator value identifying this member of the action union.
    type: typing.Literal["validate"] = "validate"
    # Validation prompt (required).
    prompt: str
    # Optional free-form model configuration mapping.
    model: typing.Optional[typing.Dict[str, typing.Any]] = None

    # Same settings (frozen instances, extra fields allowed) expressed for
    # whichever pydantic major version is in use.
    if not IS_PYDANTIC_V2:

        class Config:
            frozen = True
            smart_union = True
            extra = pydantic.Extra.allow

    else:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
|
||||
|
||||
|
||||
class RunSdkActionRequestAction_Prompt(UniversalBaseModel):
|
||||
"""
|
||||
The action to execute with its specific parameters
|
||||
@@ -191,5 +210,6 @@ RunSdkActionRequestAction = typing.Union[
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Validate,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
]
|
||||
|
||||
@@ -17,6 +17,7 @@ from skyvern.forge.sdk.api.files import validate_download_url
|
||||
from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||
from skyvern.schemas.workflows import BlockStatus
|
||||
from skyvern.services import script_service
|
||||
from skyvern.services.otp_service import poll_otp_value
|
||||
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||
@@ -564,6 +565,19 @@ class RealSkyvernPageAi(SkyvernPageAi):
|
||||
print(f"{'-' * 50}\n")
|
||||
return result
|
||||
|
||||
async def ai_validate(
    self,
    prompt: str,
    model: dict[str, Any] | None = None,
) -> bool:
    """Run an AI validation in which ``prompt`` is the complete criterion.

    Delegates to ``script_service.execute_validation`` with no terminate
    criterion and no error-code mapping.

    Args:
        prompt: Criterion passed as ``complete_criterion``.
        model: Optional model configuration forwarded unchanged.

    Returns:
        True when the returned result's status equals
        ``BlockStatus.completed``; False for any other status.
    """
    validation_outcome = await script_service.execute_validation(
        complete_criterion=prompt,
        terminate_criterion=None,
        error_code_mapping=None,
        model=model,
    )
    is_completed = validation_outcome.status == BlockStatus.completed
    return is_completed
|
||||
|
||||
async def ai_locate_element(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -684,6 +684,48 @@ class SkyvernPage(Page):
|
||||
data = kwargs.pop("data", None)
|
||||
return await self._ai.ai_extract(prompt, schema, error_code_mapping, intention, data)
|
||||
|
||||
async def validate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: dict[str, Any] | str | None = None,
|
||||
) -> bool:
|
||||
"""Validate the current page state using AI.
|
||||
|
||||
Args:
|
||||
prompt: Validation criteria or condition to check
|
||||
model: Optional model configuration. Can be either:
|
||||
- A dict with model configuration (e.g., {"model_name": "gemini-2.5-flash-lite", "max_tokens": 2048})
|
||||
- A string with just the model name (e.g., "gpt-4")
|
||||
|
||||
Returns:
|
||||
bool: True if validation passes, False otherwise
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Simple validation
|
||||
is_valid = await page.validate("Check if the login was successful")
|
||||
|
||||
# Validation with specific model (as string)
|
||||
is_valid = await page.validate(
|
||||
"Check if the order was placed",
|
||||
model="gemini-2.5-flash-lite"
|
||||
)
|
||||
|
||||
# Validation with model config (as dict)
|
||||
is_valid = await page.validate(
|
||||
"Check if the payment completed",
|
||||
model={"model_name": "gemini-2.5-flash-lite", "max_tokens": 1024}
|
||||
)
|
||||
```
|
||||
"""
|
||||
normalized_model: dict[str, Any] | None = None
|
||||
if isinstance(model, str):
|
||||
normalized_model = {"model_name": model}
|
||||
elif model is not None:
|
||||
normalized_model = model
|
||||
|
||||
return await self._ai.ai_validate(prompt=prompt, model=normalized_model)
|
||||
|
||||
async def prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -65,6 +65,14 @@ class SkyvernPageAi(Protocol):
|
||||
"""Extract information from the page using AI."""
|
||||
...
|
||||
|
||||
async def ai_validate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Validate the current page state using AI based on the given criteria."""
|
||||
...
|
||||
|
||||
async def ai_act(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
Your are here to help the user determine if the current page has met the complete/terminte criterion. Use the criterions of complete/terminate, the content of the elements parsed from the page, the screenshots of the page, and user details to determine whether the criterions has been met.
|
||||
You are here to help the user determine if the current page has met the complete/terminate criterion. Use the complete/terminate criteria, the content of the elements parsed from the page, the screenshots of the page, and user details to determine whether the criteria have been met.
|
||||
|
||||
|
||||
Reply in JSON format with the following keys:
|
||||
@@ -9,7 +9,7 @@ Reply in JSON format with the following keys:
|
||||
[{
|
||||
"reasoning": str, // The reasoning behind the action. This reasoning must be user information agnostic. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
|
||||
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
|
||||
"action_type": str, // It's a string enum: "COMPLETE", "TERMINATE". "COMPLETE" is used when the current page info has met the complete criterion. If there is no complete criterion, use "COMPLETE" as long as the page info hasn't met the terminate criterion. "TERMINATE" is used to terminate with a failure when the current page info has met the terminate criterion. It there is no terminate criterion, use "TERMINATE" as long as the page info hasn't met the complete criterion.
|
||||
"action_type": str, // It's a string enum: "COMPLETE", "TERMINATE". Use "COMPLETE" when the complete criterion is met (if provided). Use "TERMINATE" when the terminate criterion is met (if provided), or when a complete criterion is provided but not met.
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
@@ -205,6 +205,12 @@ async def run_sdk_action(
|
||||
prompt=action.prompt,
|
||||
)
|
||||
result = xpath_result
|
||||
elif action.type == "validate":
|
||||
validation_result = await page_ai.ai_validate(
|
||||
prompt=action.prompt,
|
||||
model=action.model,
|
||||
)
|
||||
result = validation_result
|
||||
elif action.type == "prompt":
|
||||
prompt_result = await page_ai.ai_prompt(
|
||||
prompt=action.prompt,
|
||||
|
||||
@@ -16,6 +16,7 @@ class SdkActionType(str, Enum):
|
||||
AI_ACT = "ai_act"
|
||||
EXTRACT = "extract"
|
||||
LOCATE_ELEMENT = "locate_element"
|
||||
VALIDATE = "validate"
|
||||
PROMPT = "prompt"
|
||||
|
||||
|
||||
@@ -152,6 +153,20 @@ class LocateElementAction(SdkActionBase):
|
||||
return None
|
||||
|
||||
|
||||
class ValidateAction(SdkActionBase):
    """Validate action parameters."""

    # Discriminator value used to select this action within the SdkAction union.
    type: Literal["validate"] = "validate"
    prompt: str = Field(..., description="Validation criteria or condition to check")
    model: dict[str, Any] | None = Field(default=None, description="Optional model configuration")

    def get_navigation_goal(self) -> str | None:
        """Return the validation prompt as this action's navigation goal."""
        return self.prompt

    def get_navigation_payload(self) -> dict[str, Any] | None:
        """Return None: this action supplies no navigation payload."""
        return None
|
||||
|
||||
|
||||
class PromptAction(SdkActionBase):
|
||||
"""Prompt action parameters."""
|
||||
|
||||
@@ -177,6 +192,7 @@ SdkAction = Annotated[
|
||||
ActAction,
|
||||
ExtractAction,
|
||||
LocateElementAction,
|
||||
ValidateAction,
|
||||
PromptAction,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
|
||||
@@ -12,6 +12,7 @@ from skyvern.client import (
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
RunSdkActionRequestAction_Prompt,
|
||||
RunSdkActionRequestAction_Validate,
|
||||
)
|
||||
from skyvern.config import settings
|
||||
from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
|
||||
@@ -176,6 +177,34 @@ class SdkSkyvernPageAi(SkyvernPageAi):
|
||||
self._browser.workflow_run_id = response.workflow_run_id
|
||||
return response.result if response.result else None
|
||||
|
||||
async def ai_validate(
    self,
    prompt: str,
    model: dict[str, Any] | None = None,
) -> bool:
    """Validate the current page state by sending a "validate" SDK action.

    Args:
        prompt: Validation criteria forwarded in the action payload.
        model: Optional model configuration forwarded in the action payload.

    Returns:
        The truthiness of the response's ``result``; False when the
        response carries no result.
    """
    browser = self._browser

    LOG.info(
        "AI validate",
        prompt=prompt,
        model=model,
        workflow_run_id=browser.workflow_run_id,
    )

    validate_action = RunSdkActionRequestAction_Validate(
        prompt=prompt,
        model=model,
    )
    response = await browser.skyvern.run_sdk_action(
        url=self._page.url,
        action=validate_action,
        browser_session_id=browser.browser_session_id,
        browser_address=browser.browser_address,
        workflow_run_id=browser.workflow_run_id,
    )
    # Remember the workflow run id reported by the response so later calls
    # on this browser pass it along.
    browser.workflow_run_id = response.workflow_run_id

    # bool(None) is False, so a missing result maps to False.
    return bool(response.result)
|
||||
|
||||
async def ai_act(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -60,7 +60,7 @@ from skyvern.schemas.scripts import (
|
||||
ScriptFileCreate,
|
||||
ScriptStatus,
|
||||
)
|
||||
from skyvern.schemas.workflows import BlockStatus, BlockType, FileStorageType, FileType
|
||||
from skyvern.schemas.workflows import BlockResult, BlockStatus, BlockType, FileStorageType, FileType
|
||||
from skyvern.webeye.scraper.scraped_page import ElementTreeFormat
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -1748,6 +1748,18 @@ async def validate(
|
||||
if not complete_criterion and not terminate_criterion:
|
||||
raise Exception("Both complete criterion and terminate criterion are empty")
|
||||
|
||||
result = await execute_validation(complete_criterion, terminate_criterion, error_code_mapping, label, model)
|
||||
if result.status == BlockStatus.terminated:
|
||||
raise ScriptTerminationException(result.failure_reason)
|
||||
|
||||
|
||||
async def execute_validation(
|
||||
complete_criterion: str | None,
|
||||
terminate_criterion: str | None,
|
||||
error_code_mapping: dict[str, str] | None,
|
||||
label: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> BlockResult:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label)
|
||||
validation_block = ValidationBlock(
|
||||
label=block_validation_output.label,
|
||||
@@ -1765,8 +1777,7 @@ async def validate(
|
||||
organization_id=block_validation_output.organization_id,
|
||||
browser_session_id=block_validation_output.browser_session_id,
|
||||
)
|
||||
if result.status == BlockStatus.terminated:
|
||||
raise ScriptTerminationException(result.failure_reason)
|
||||
return result
|
||||
|
||||
|
||||
async def wait(seconds: int, label: str | None = None) -> None:
|
||||
|
||||
Reference in New Issue
Block a user