SDK: Prompt-based locator (#4027)
This commit is contained in:
committed by
GitHub
parent
90f51bcacb
commit
8fb46ef1ca
123
skyvern/library/ai_locator.py
Normal file
123
skyvern/library/ai_locator.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from playwright.async_api import Locator, Page
|
||||
|
||||
from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
|
||||
|
||||
# Names of Locator members that produce another locator instead of performing
# an action. AILocator intercepts these and returns a new lazy AILocator, so
# building a chain never triggers resolution by itself.
LOCATOR_CHAIN_METHODS = {
    # positional narrowing
    "first",
    "last",
    "nth",
    # composition / sub-queries
    "and_",
    "or_",
    "filter",
    "locator",
    "frame_locator",
    # get_by_* lookups
    "get_by_alt_text",
    "get_by_label",
    "get_by_placeholder",
    "get_by_role",
    "get_by_test_id",
    "get_by_text",
    "get_by_title",
}
|
||||
|
||||
|
||||
class AILocator(Locator):
    """A lazy proxy that acts like a Playwright Locator but resolves XPath via AI on first use.

    The AI call is deferred until a public member is actually awaited, so the
    locator can be created synchronously while resolution happens
    asynchronously. The resolved locator is cached on the instance, so the
    selector probe / AI lookup runs at most once per successful resolution.

    Resolution order (see ``_resolve``):

    * With ``try_selector_first=True`` and a ``selector``: the selector is
      used if it matches at least one element; otherwise the AI is asked.
    * If the AI lookup fails and ``try_selector_first=False`` with a
      ``selector`` given, the selector is used as a fallback.
    * In every other failure case the AI error propagates to the caller.

    Every public (non-underscore) attribute is routed through
    ``__getattribute__``:

    * Names in ``LOCATOR_CHAIN_METHODS`` return a *synchronous* callable that
      builds a new lazy ``AILocator``. This also applies to members that are
      properties on the underlying locator, so they are invoked with
      parentheses on this proxy (e.g. ``ai_locator.first()``).
    * Any other name returns an *async* callable that resolves the locator
      and forwards to the member of the same name.
    """

    def __init__(
        self,
        page: Page,
        page_ai: SkyvernPageAi,
        prompt: str,
        selector: str | None = None,
        selector_kwargs: dict[str, Any] | None = None,
        try_selector_first: bool = True,
        parent_resolver: Callable[[], Any] | None = None,
    ):
        """Store the resolution inputs; no I/O or AI call happens here.

        Args:
            page: Page used to build concrete locators.
            page_ai: Object whose ``ai_locate_element(prompt=...)`` coroutine
                returns a selector string (or a falsy value when not found).
            prompt: Natural-language description of the target element.
            selector: Optional explicit selector, used first or as fallback
                depending on ``try_selector_first``.
            selector_kwargs: Extra keyword arguments forwarded to
                ``page.locator`` together with ``selector``.
            try_selector_first: If True, probe ``selector`` before asking the
                AI; if False, ``selector`` is only an AI-failure fallback.
            parent_resolver: Async callable returning the final locator. When
                set (chained locators), it replaces selector/AI resolution.
        """
        # NOTE(review): the base Locator is initialised with the page object;
        # this proxy never relies on base-class behaviour because every public
        # attribute access is intercepted below.
        super().__init__(page)
        self._page = page
        self._page_ai = page_ai
        self._prompt = prompt
        self._selector = selector
        self._selector_kwargs = selector_kwargs or {}
        self._resolved_locator: Locator | None = None
        self._try_selector_first = try_selector_first

        # For chaining: store a resolver function that returns the final Locator
        self._parent_resolver = parent_resolver

    async def _resolve(self) -> Locator:
        """Return the concrete locator, resolving and caching it on first use.

        Raises:
            ValueError: The AI returned no selector and no fallback applies.
            Exception: Any error raised by the AI lookup when no fallback
                applies, or by ``parent_resolver`` for chained locators.
        """
        if self._resolved_locator is not None:
            return self._resolved_locator

        # Chained locator: the parent chain knows how to build the result.
        if self._parent_resolver:
            self._resolved_locator = await self._parent_resolver()
            return self._resolved_locator

        if self._try_selector_first and self._selector:
            try:
                selector_locator = self._page.locator(self._selector, **self._selector_kwargs)
                if await selector_locator.count() > 0:
                    self._resolved_locator = selector_locator
                    return self._resolved_locator
            except Exception:
                # Deliberate best-effort: a broken or non-matching selector
                # must not prevent the AI lookup below.
                pass

        try:
            xpath = await self._page_ai.ai_locate_element(prompt=self._prompt)
            if not xpath:
                raise ValueError(f"AI failed to locate element with prompt: {self._prompt}")

            # Keep selectors that already carry an engine prefix; treat
            # anything else as a raw XPath expression.
            self._resolved_locator = self._page.locator(
                xpath if xpath.startswith(("xpath=", "css=", "text=", "role=", "id=")) else f"xpath={xpath}"
            )
        except Exception:
            if self._selector and not self._try_selector_first:
                # AI-first mode: fall back to the explicit selector.
                self._resolved_locator = self._page.locator(self._selector, **self._selector_kwargs)
            else:
                # Bare raise keeps the original traceback intact.
                raise

        return self._resolved_locator

    def __getattribute__(self, name: str) -> Any:
        """Route public attribute access through the lazily resolved locator.

        Private names (leading underscore) are served from the instance
        itself. See the class docstring for how public names are wrapped.
        """
        if name.startswith("_"):
            return object.__getattribute__(self, name)

        # Locator chaining method
        if name in LOCATOR_CHAIN_METHODS:

            def locator_chain_wrapper(*args: Any, **kwargs: Any) -> AILocator:
                async def resolver() -> Locator:
                    parent_locator = await self._resolve()
                    member = getattr(parent_locator, name)
                    # Some chain members (e.g. Playwright's ``first``/``last``)
                    # are properties, so the attribute is already the child
                    # locator and must not be called.
                    if not callable(member):
                        return member
                    return member(*args, **kwargs)

                return AILocator(
                    page=self._page,
                    page_ai=self._page_ai,
                    prompt=self._prompt,
                    selector=self._selector,
                    selector_kwargs=self._selector_kwargs,
                    try_selector_first=self._try_selector_first,
                    parent_resolver=resolver,
                )

            return locator_chain_wrapper

        # For all other methods (async actions like click, fill, etc.)
        async def async_method_wrapper(*args: Any, **kwargs: Any) -> Any:
            import inspect

            locator = await self._resolve()
            member = getattr(locator, name)
            # Property-style members are returned as-is instead of being
            # called, which would raise TypeError.
            if not callable(member):
                return member
            result = member(*args, **kwargs)
            # Only await when the call actually produced an awaitable.
            if inspect.isawaitable(result):
                return await result
            return result

        return async_method_wrapper
|
||||
@@ -10,6 +10,7 @@ from skyvern.client import (
|
||||
RunSdkActionRequestAction_AiSelectOption,
|
||||
RunSdkActionRequestAction_AiUploadFile,
|
||||
RunSdkActionRequestAction_Extract,
|
||||
RunSdkActionRequestAction_LocateElement,
|
||||
)
|
||||
from skyvern.config import settings
|
||||
from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
|
||||
@@ -192,3 +193,35 @@ class SdkSkyvernPageAi(SkyvernPageAi):
|
||||
workflow_run_id=self._browser.workflow_run_id,
|
||||
)
|
||||
self._browser.workflow_run_id = response.workflow_run_id
|
||||
|
||||
async def ai_locate_element(
    self,
    prompt: str,
) -> str | None:
    """Ask the Skyvern API to locate an element and return its selector.

    Sends a locate-element SDK action for the current page URL, then stores
    the workflow run id from the response back on the browser object.

    Args:
        prompt: Natural language description of the element to locate (e.g., 'find "download invoices" button')

    Returns:
        XPath selector string (e.g., 'xpath=//button[@id="download"]') or None if not found
    """
    browser = self._browser

    LOG.info("AI locate element", prompt=prompt, workflow_run_id=browser.workflow_run_id)

    response = await browser.skyvern.run_sdk_action(
        url=self._page.url,
        action=RunSdkActionRequestAction_LocateElement(
            prompt=prompt,
        ),
        browser_session_id=browser.browser_session_id,
        browser_address=browser.browser_address,
        workflow_run_id=browser.workflow_run_id,
    )
    # Remember the run id so later SDK actions reuse the same workflow run.
    browser.workflow_run_id = response.workflow_run_id

    # Only a non-empty string counts as a located element.
    located = response.result
    if not (located and isinstance(located, str)):
        return None
    return located
|
||||
|
||||
Reference in New Issue
Block a user