SDK: validation action (#4203)

This commit is contained in:
Stanislav Novosad
2025-12-08 13:10:30 -07:00
committed by GitHub
parent 7ef48c32e0
commit 4b99cd3f45
11 changed files with 157 additions and 5 deletions

View File

@@ -289,6 +289,7 @@ if typing.TYPE_CHECKING:
RunSdkActionRequestAction_AiUploadFile,
RunSdkActionRequestAction_Extract,
RunSdkActionRequestAction_LocateElement,
RunSdkActionRequestAction_Validate,
RunSdkActionRequestAction_Prompt,
RunSdkActionResponse,
RunStatus,
@@ -781,6 +782,7 @@ _dynamic_imports: typing.Dict[str, str] = {
"RunSdkActionRequestAction_AiUploadFile": ".types",
"RunSdkActionRequestAction_Extract": ".types",
"RunSdkActionRequestAction_LocateElement": ".types",
"RunSdkActionRequestAction_Validate": ".types",
"RunSdkActionRequestAction_Prompt": ".types",
"RunSdkActionResponse": ".types",
"RunStatus": ".types",
@@ -1297,6 +1299,7 @@ __all__ = [
"RunSdkActionRequestAction_AiUploadFile",
"RunSdkActionRequestAction_Extract",
"RunSdkActionRequestAction_LocateElement",
"RunSdkActionRequestAction_Validate",
"RunSdkActionRequestAction_Prompt",
"RunSdkActionResponse",
"RunStatus",

View File

@@ -314,6 +314,7 @@ if typing.TYPE_CHECKING:
RunSdkActionRequestAction_AiUploadFile,
RunSdkActionRequestAction_Extract,
RunSdkActionRequestAction_LocateElement,
RunSdkActionRequestAction_Validate,
RunSdkActionRequestAction_Prompt,
)
from .run_sdk_action_response import RunSdkActionResponse
@@ -814,6 +815,7 @@ _dynamic_imports: typing.Dict[str, str] = {
"RunSdkActionRequestAction_AiUploadFile": ".run_sdk_action_request_action",
"RunSdkActionRequestAction_Extract": ".run_sdk_action_request_action",
"RunSdkActionRequestAction_LocateElement": ".run_sdk_action_request_action",
"RunSdkActionRequestAction_Validate": ".run_sdk_action_request_action",
"RunSdkActionRequestAction_Prompt": ".run_sdk_action_request_action",
"RunSdkActionResponse": ".run_sdk_action_response",
"RunStatus": ".run_status",
@@ -1319,6 +1321,7 @@ __all__ = [
"RunSdkActionRequestAction_AiUploadFile",
"RunSdkActionRequestAction_Extract",
"RunSdkActionRequestAction_LocateElement",
"RunSdkActionRequestAction_Validate",
"RunSdkActionRequestAction_Prompt",
"RunSdkActionResponse",
"RunStatus",

View File

@@ -163,6 +163,25 @@ class RunSdkActionRequestAction_LocateElement(UniversalBaseModel):
extra = pydantic.Extra.allow
class RunSdkActionRequestAction_Validate(UniversalBaseModel):
    """
    The action to execute with its specific parameters
    """

    # Fixed literal tag; the value is always "validate".
    type: typing.Literal["validate"] = "validate"
    # Required prompt string for this action.
    prompt: str
    # Optional free-form mapping; None when not supplied.
    model: typing.Optional[typing.Dict[str, typing.Any]] = None

    # Both branches request a frozen model that allows extra fields; only the
    # configuration mechanism differs between pydantic major versions.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(frozen=True, extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            extra = pydantic.Extra.allow
            smart_union = True
            frozen = True
class RunSdkActionRequestAction_Prompt(UniversalBaseModel):
"""
The action to execute with its specific parameters
@@ -191,5 +210,6 @@ RunSdkActionRequestAction = typing.Union[
RunSdkActionRequestAction_AiUploadFile,
RunSdkActionRequestAction_Extract,
RunSdkActionRequestAction_LocateElement,
RunSdkActionRequestAction_Validate,
RunSdkActionRequestAction_Prompt,
]

View File

@@ -17,6 +17,7 @@ from skyvern.forge.sdk.api.files import validate_download_url
from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.schemas.totp_codes import OTPType
from skyvern.schemas.workflows import BlockStatus
from skyvern.services import script_service
from skyvern.services.otp_service import poll_otp_value
from skyvern.utils.prompt_engine import load_prompt_with_elements
@@ -564,6 +565,19 @@ class RealSkyvernPageAi(SkyvernPageAi):
print(f"{'-' * 50}\n")
return result
async def ai_validate(
    self,
    prompt: str,
    model: dict[str, Any] | None = None,
) -> bool:
    """Run an AI validation of the current page and report whether it passed.

    The prompt is forwarded to ``script_service.execute_validation`` as the
    complete criterion; no terminate criterion and no error-code mapping are
    supplied.

    Args:
        prompt: Text used as the complete criterion.
        model: Optional model configuration passed through unchanged.

    Returns:
        True only when the returned block status equals
        ``BlockStatus.completed``; False for any other status.
    """
    outcome = await script_service.execute_validation(
        complete_criterion=prompt,
        terminate_criterion=None,
        error_code_mapping=None,
        model=model,
    )
    passed = outcome.status == BlockStatus.completed
    return passed
async def ai_locate_element(
self,
prompt: str,

View File

@@ -684,6 +684,48 @@ class SkyvernPage(Page):
data = kwargs.pop("data", None)
return await self._ai.ai_extract(prompt, schema, error_code_mapping, intention, data)
async def validate(
self,
prompt: str,
model: dict[str, Any] | str | None = None,
) -> bool:
"""Validate the current page state using AI.
Args:
prompt: Validation criteria or condition to check
model: Optional model configuration. Can be either:
- A dict with model configuration (e.g., {"model_name": "gemini-2.5-flash-lite", "max_tokens": 2048})
- A string with just the model name (e.g., "gpt-4")
Returns:
bool: True if validation passes, False otherwise
Examples:
```python
# Simple validation
is_valid = await page.validate("Check if the login was successful")
# Validation with specific model (as string)
is_valid = await page.validate(
"Check if the order was placed",
model="gemini-2.5-flash-lite"
)
# Validation with model config (as dict)
is_valid = await page.validate(
"Check if the payment completed",
model={"model_name": "gemini-2.5-flash-lite", "max_tokens": 1024}
)
```
"""
normalized_model: dict[str, Any] | None = None
if isinstance(model, str):
normalized_model = {"model_name": model}
elif model is not None:
normalized_model = model
return await self._ai.ai_validate(prompt=prompt, model=normalized_model)
async def prompt(
self,
prompt: str,

View File

@@ -65,6 +65,14 @@ class SkyvernPageAi(Protocol):
"""Extract information from the page using AI."""
...
async def ai_validate(
self,
prompt: str,
model: dict[str, Any] | None = None,
) -> bool:
"""Validate the current page state using AI based on the given criteria."""
...
async def ai_act(
self,
prompt: str,

View File

@@ -1,4 +1,4 @@
Your are here to help the user determine if the current page has met the complete/terminte criterion. Use the criterions of complete/terminate, the content of the elements parsed from the page, the screenshots of the page, and user details to determine whether the criterions has been met.
You are here to help the user determine if the current page has met the complete/terminate criterion. Use the complete/terminate criteria, the content of the elements parsed from the page, the screenshots of the page, and user details to determine whether the criteria have been met.
Reply in JSON format with the following keys:
@@ -9,7 +9,7 @@ Reply in JSON format with the following keys:
[{
"reasoning": str, // The reasoning behind the action. This reasoning must be user information agnostic. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"action_type": str, // It's a string enum: "COMPLETE", "TERMINATE". "COMPLETE" is used when the current page info has met the complete criterion. If there is no complete criterion, use "COMPLETE" as long as the page info hasn't met the terminate criterion. "TERMINATE" is used to terminate with a failure when the current page info has met the terminate criterion. It there is no terminate criterion, use "TERMINATE" as long as the page info hasn't met the complete criterion.
"action_type": str, // It's a string enum: "COMPLETE", "TERMINATE". Use "COMPLETE" when the complete criterion is met (if provided). Use "TERMINATE" when the terminate criterion is met (if provided), or when a complete criterion is provided but not met.
}]
}

View File

@@ -205,6 +205,12 @@ async def run_sdk_action(
prompt=action.prompt,
)
result = xpath_result
elif action.type == "validate":
validation_result = await page_ai.ai_validate(
prompt=action.prompt,
model=action.model,
)
result = validation_result
elif action.type == "prompt":
prompt_result = await page_ai.ai_prompt(
prompt=action.prompt,

View File

@@ -16,6 +16,7 @@ class SdkActionType(str, Enum):
AI_ACT = "ai_act"
EXTRACT = "extract"
LOCATE_ELEMENT = "locate_element"
VALIDATE = "validate"
PROMPT = "prompt"
@@ -152,6 +153,20 @@ class LocateElementAction(SdkActionBase):
return None
class ValidateAction(SdkActionBase):
    """Validate action parameters."""

    # Fixed literal tag; the value is always "validate".
    type: Literal["validate"] = "validate"
    prompt: str = Field(..., description="Validation criteria or condition to check")
    model: dict[str, Any] | None = Field(default=None, description="Optional model configuration")

    def get_navigation_goal(self) -> str | None:
        """Return the validation prompt as this action's navigation goal."""
        return self.prompt

    def get_navigation_payload(self) -> dict[str, Any] | None:
        """Return ``None``; this action supplies no navigation payload."""
        return None
class PromptAction(SdkActionBase):
"""Prompt action parameters."""
@@ -177,6 +192,7 @@ SdkAction = Annotated[
ActAction,
ExtractAction,
LocateElementAction,
ValidateAction,
PromptAction,
],
Field(discriminator="type"),

View File

@@ -12,6 +12,7 @@ from skyvern.client import (
RunSdkActionRequestAction_Extract,
RunSdkActionRequestAction_LocateElement,
RunSdkActionRequestAction_Prompt,
RunSdkActionRequestAction_Validate,
)
from skyvern.config import settings
from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
@@ -176,6 +177,34 @@ class SdkSkyvernPageAi(SkyvernPageAi):
self._browser.workflow_run_id = response.workflow_run_id
return response.result if response.result else None
async def ai_validate(
    self,
    prompt: str,
    model: dict[str, Any] | None = None,
) -> bool:
    """Validate the current page state using AI via API call.

    Sends a ``validate`` SDK action for the page's current URL, then stores
    the workflow run id from the response back on the browser object.

    Args:
        prompt: Validation criteria sent with the action.
        model: Optional model configuration sent with the action.

    Returns:
        The truthiness of ``response.result``; a missing (``None``) result
        yields False.
    """
    browser = self._browser

    LOG.info(
        "AI validate",
        prompt=prompt,
        model=model,
        workflow_run_id=browser.workflow_run_id,
    )

    validate_action = RunSdkActionRequestAction_Validate(
        prompt=prompt,
        model=model,
    )
    response = await browser.skyvern.run_sdk_action(
        url=self._page.url,
        action=validate_action,
        browser_session_id=browser.browser_session_id,
        browser_address=browser.browser_address,
        workflow_run_id=browser.workflow_run_id,
    )

    # Store the run id from the response so later calls on this browser use it.
    browser.workflow_run_id = response.workflow_run_id

    # bool(None) is False, so the absent-result case needs no separate branch.
    return bool(response.result)
async def ai_act(
self,
prompt: str,

View File

@@ -60,7 +60,7 @@ from skyvern.schemas.scripts import (
ScriptFileCreate,
ScriptStatus,
)
from skyvern.schemas.workflows import BlockStatus, BlockType, FileStorageType, FileType
from skyvern.schemas.workflows import BlockResult, BlockStatus, BlockType, FileStorageType, FileType
from skyvern.webeye.scraper.scraped_page import ElementTreeFormat
LOG = structlog.get_logger()
@@ -1748,6 +1748,18 @@ async def validate(
if not complete_criterion and not terminate_criterion:
raise Exception("Both complete criterion and terminate criterion are empty")
result = await execute_validation(complete_criterion, terminate_criterion, error_code_mapping, label, model)
if result.status == BlockStatus.terminated:
raise ScriptTerminationException(result.failure_reason)
async def execute_validation(
complete_criterion: str | None,
terminate_criterion: str | None,
error_code_mapping: dict[str, str] | None,
label: str | None = None,
model: dict[str, Any] | None = None,
) -> BlockResult:
block_validation_output = await _validate_and_get_output_parameter(label)
validation_block = ValidationBlock(
label=block_validation_output.label,
@@ -1765,8 +1777,7 @@ async def validate(
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
if result.status == BlockStatus.terminated:
raise ScriptTerminationException(result.failure_reason)
return result
async def wait(seconds: int, label: str | None = None) -> None: