convert css shape to string (#1092)
This commit is contained in:
@@ -4,7 +4,7 @@ import hashlib
|
||||
from typing import Dict, List
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Page
|
||||
from playwright.async_api import Frame, Page
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import SKYVERN_ID_ATTR
|
||||
@@ -19,8 +19,8 @@ from skyvern.webeye.scraper.scraper import ELEMENT_NODE_ATTRIBUTES, CleanupEleme
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
USELESS_SVG_ATTRIBUTE = [SKYVERN_ID_ATTR, "id", "aria-describedby"]
|
||||
SVG_RETRY_ATTEMPT = 3
|
||||
USELESS_SHAPE_ATTRIBUTE = [SKYVERN_ID_ATTR, "id", "aria-describedby"]
|
||||
SHAPE_CONVERTION_RETRY_ATTEMPT = 3
|
||||
|
||||
|
||||
def _remove_rect(element: dict) -> None:
|
||||
@@ -28,10 +28,45 @@ def _remove_rect(element: dict) -> None:
|
||||
del element["rect"]
|
||||
|
||||
|
||||
def _should_css_shape_convert(element: Dict) -> bool:
|
||||
if "id" not in element:
|
||||
return False
|
||||
|
||||
tag_name = element.get("tagName")
|
||||
if tag_name not in ["a", "span", "i"]:
|
||||
return False
|
||||
|
||||
# if <span> and <i> without any text in the element, we try to convert the shape
|
||||
if tag_name in ["span", "i"] and not element.get("text"):
|
||||
return True
|
||||
|
||||
# if <a>, it should be no text, no children, no href/target attribute
|
||||
if tag_name == "a":
|
||||
attributes = element.get("attributes", {})
|
||||
if element.get("text"):
|
||||
return False
|
||||
|
||||
if len(element.get("children", [])) > 0:
|
||||
return False
|
||||
|
||||
if "href" in attributes:
|
||||
return False
|
||||
|
||||
if "target" in attributes:
|
||||
return False
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_svg_cache_key(hash: str) -> str:
|
||||
return f"skyvern:svg:{hash}"
|
||||
|
||||
|
||||
def _get_shape_cache_key(hash: str) -> str:
|
||||
return f"skyvern:shape:{hash}"
|
||||
|
||||
|
||||
def _remove_skyvern_attributes(element: Dict) -> Dict:
|
||||
"""
|
||||
To get the original HTML element without skyvern attributes
|
||||
@@ -44,7 +79,7 @@ def _remove_skyvern_attributes(element: Dict) -> Dict:
|
||||
if "attributes" in element_copied:
|
||||
attributes: dict = copy.deepcopy(element_copied.get("attributes", {}))
|
||||
for key in attributes.keys():
|
||||
if key in USELESS_SVG_ATTRIBUTE:
|
||||
if key in USELESS_SHAPE_ATTRIBUTE:
|
||||
del element_copied["attributes"][key]
|
||||
|
||||
children: List[Dict] | None = element_copied.get("children", None)
|
||||
@@ -80,6 +115,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to loaded SVG cache",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
exc_info=True,
|
||||
key=svg_key,
|
||||
)
|
||||
@@ -92,6 +129,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
|
||||
LOG.warning(
|
||||
"SVG element is too large to convert, going to drop the svg element.",
|
||||
element_id=element_id,
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
length=len(svg_html),
|
||||
)
|
||||
del element["children"]
|
||||
@@ -101,7 +140,7 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
|
||||
LOG.debug("call LLM to convert SVG to string shape", element_id=element_id)
|
||||
svg_convert_prompt = prompt_engine.load_prompt("svg-convert", svg_element=svg_html)
|
||||
|
||||
for retry in range(SVG_RETRY_ATTEMPT):
|
||||
for retry in range(SHAPE_CONVERTION_RETRY_ATTEMPT):
|
||||
try:
|
||||
json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step)
|
||||
svg_shape = json_response.get("shape", "")
|
||||
@@ -113,6 +152,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to convert SVG to string shape by secondary llm. Will retry if haven't met the max try attempt after 3s.",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
element_id=element_id,
|
||||
retry=retry,
|
||||
)
|
||||
@@ -126,6 +167,101 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
|
||||
return
|
||||
|
||||
|
||||
async def _convert_css_shape_to_string(
    task: Task, step: Step, organization: Organization | None, frame: Page | Frame, element: Dict
) -> None:
    """
    Attach a "shape-description" attribute to an element whose look comes from CSS.

    The element (with the skyvern-specific attributes removed) is rendered to HTML
    and hashed to build a cache key. On a cache hit the cached description is used.
    Otherwise the element is located in `frame` through its skyvern id attribute,
    a screenshot of it is taken, and the secondary LLM is asked (prompt
    "css-shape-convert") to describe the shape, retrying up to
    SHAPE_CONVERTION_RETRY_ATTEMPT times.

    On success `element["attributes"]["shape-description"]` is set in place.
    When the element cannot be located uniquely, or the screenshot / LLM conversion
    fails, the problem is logged and `element` is left untouched.

    `organization` is accepted for signature parity but is not used here.
    """
    element_id: str = element.get("id", "")

    # the cache key is derived from the element HTML without skyvern attributes,
    # so identical markup shares a single cached description
    shape_element = _remove_skyvern_attributes(element)
    shape_html = json_to_html(shape_element)
    shape_hash = hashlib.sha256(shape_html.encode("utf-8")).hexdigest()
    shape_key = _get_shape_cache_key(shape_hash)

    css_shape: str | None = None
    try:
        css_shape = await app.CACHE.get(shape_key)
    except Exception:
        # a cache read failure is not fatal; fall through to the LLM conversion
        LOG.warning(
            "Failed to loaded CSS shape cache",
            task_id=task.task_id,
            step_id=step.step_id,
            exc_info=True,
            key=shape_key,
        )

    if css_shape:
        LOG.debug("CSS shape loaded from cache", element_id=element_id, shape=css_shape)
    else:
        # FIXME: support element in iframe
        locater = frame.locator(f'[{SKYVERN_ID_ATTR}="{element_id}"]')

        # query the match count once so both checks look at the same snapshot
        locater_count = await locater.count()
        if locater_count == 0:
            LOG.info(
                "No locater found to convert css shape",
                task_id=task.task_id,
                step_id=step.step_id,
                element_id=element_id,
            )
            return None

        if locater_count > 1:
            LOG.info(
                "multiple locaters found to convert css shape",
                task_id=task.task_id,
                step_id=step.step_id,
                element_id=element_id,
            )
            return None

        try:
            LOG.debug("call LLM to convert css shape to string shape", element_id=element_id)
            screenshot = await locater.screenshot(timeout=settings.BROWSER_SCREENSHOT_TIMEOUT_MS)
            prompt = prompt_engine.load_prompt("css-shape-convert")

            for retry in range(SHAPE_CONVERTION_RETRY_ATTEMPT):
                try:
                    json_response = await app.SECONDARY_LLM_API_HANDLER(
                        prompt=prompt, screenshots=[screenshot], step=step
                    )
                    css_shape = json_response.get("shape", "")
                    if not css_shape:
                        raise Exception("Empty css shape replied by secondary llm")
                    LOG.info("CSS Shape converted by LLM", element_id=element_id, shape=css_shape)
                    break
                except Exception:
                    LOG.exception(
                        "Failed to convert css shape to string shape by secondary llm. Will retry if haven't met the max try attempt after 3s.",
                        task_id=task.task_id,
                        step_id=step.step_id,
                        element_id=element_id,
                        retry=retry,
                    )
                    # only wait when another attempt will actually follow
                    if retry < SHAPE_CONVERTION_RETRY_ATTEMPT - 1:
                        await asyncio.sleep(3)
            else:
                # the loop finished without `break`: every attempt failed
                LOG.info(
                    "Max css shape convertion retry, going to abort the convertion.",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    element_id=element_id,
                )
                return None
        except Exception:
            LOG.exception(
                "Failed to convert css shape to string shape by LLM",
                task_id=task.task_id,
                step_id=step.step_id,
                element_id=element_id,
            )
            return None

        # The cache write is kept out of the retry loop: a cache failure must not
        # trigger another LLM call or discard a shape that was already obtained.
        try:
            await app.CACHE.set(shape_key, css_shape)
        except Exception:
            LOG.warning(
                "Failed to save CSS shape cache",
                task_id=task.task_id,
                step_id=step.step_id,
                exc_info=True,
                key=shape_key,
            )

    if "attributes" not in element:
        element["attributes"] = dict()
    element["attributes"]["shape-description"] = css_shape
    return None
|
||||
|
||||
|
||||
class AgentFunction:
|
||||
async def validate_step_execution(
|
||||
self,
|
||||
@@ -181,7 +317,7 @@ class AgentFunction:
|
||||
step: Step,
|
||||
organization: Organization | None = None,
|
||||
) -> CleanupElementTreeFunc:
|
||||
async def cleanup_element_tree_func(url: str, element_tree: list[dict]) -> list[dict]:
|
||||
async def cleanup_element_tree_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Remove rect and attribute.unique_id from the elements.
|
||||
The reason we're doing it is to
|
||||
@@ -197,6 +333,16 @@ class AgentFunction:
|
||||
queue_ele = queue.pop(0)
|
||||
_remove_rect(queue_ele)
|
||||
await _convert_svg_to_string(task, step, organization, queue_ele)
|
||||
|
||||
if _should_css_shape_convert(element=queue_ele):
|
||||
await _convert_css_shape_to_string(
|
||||
task=task,
|
||||
step=step,
|
||||
organization=organization,
|
||||
frame=frame,
|
||||
element=queue_ele,
|
||||
)
|
||||
|
||||
# TODO: we can come back to test removing the unique_id
|
||||
# from element attributes to make sure this won't increase hallucination
|
||||
# _remove_unique_id(queue_ele)
|
||||
|
||||
8
skyvern/forge/prompts/skyvern/css-shape-convert.j2
Normal file
8
skyvern/forge/prompts/skyvern/css-shape-convert.j2
Normal file
@@ -0,0 +1,8 @@
|
||||
You are given a screenshot of an HTML element. You need to figure out what its shape means.
|
||||
|
||||
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
|
||||
Reply in JSON format with the following keys:
|
||||
{
|
||||
"confidence_float": float, // The confidence of the shape description. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
|
||||
"shape": string, // A short description of the shape of element and its meaning
|
||||
}
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Awaitable, Callable, List
|
||||
|
||||
import structlog
|
||||
from deprecation import deprecated
|
||||
from playwright.async_api import FileChooser, Locator, Page, TimeoutError
|
||||
from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.constants import REPO_ROOT_DIR, SKYVERN_ID_ATTR
|
||||
@@ -165,8 +165,10 @@ def remove_exist_elements(element_tree: list[dict], check_exist: CheckExistIDFun
|
||||
def clean_and_remove_element_tree_factory(
|
||||
task: Task, step: Step, check_exist_funcs: list[CheckExistIDFunc]
|
||||
) -> CleanupElementTreeFunc:
|
||||
async def helper_func(url: str, element_tree: list[dict]) -> list[dict]:
|
||||
element_tree = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(url, element_tree)
|
||||
async def helper_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
|
||||
element_tree = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(
|
||||
frame, url, element_tree
|
||||
)
|
||||
for check_exist in check_exist_funcs:
|
||||
element_tree = remove_exist_elements(element_tree=element_tree, check_exist=check_exist)
|
||||
return element_tree
|
||||
@@ -1270,7 +1272,7 @@ async def choose_auto_completion_dropdown(
|
||||
|
||||
if len(confirmed_preserved_list) > 0:
|
||||
confirmed_preserved_list = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(
|
||||
skyvern_frame.get_frame().url, copy.deepcopy(confirmed_preserved_list)
|
||||
skyvern_frame.get_frame(), skyvern_frame.get_frame().url, copy.deepcopy(confirmed_preserved_list)
|
||||
)
|
||||
confirmed_preserved_list = trim_element_tree(copy.deepcopy(confirmed_preserved_list))
|
||||
|
||||
|
||||
@@ -686,6 +686,14 @@ const checkRequiredFromStyle = (element) => {
|
||||
return element.className.toLowerCase().includes("require");
|
||||
};
|
||||
|
||||
// Reports whether the element's class names mark it as disabled.
// Currently only the react-datepicker disabled-day class is recognized;
// the match is a case-insensitive substring check on the class string.
function checkDisabledFromStyle(element) {
  const DISABLED_DAY_CLASS = "react-datepicker__day--disabled";
  const classText = element.className.toString().toLowerCase();
  return classText.includes(DISABLED_DAY_CLASS);
}
|
||||
|
||||
function getElementContext(element) {
|
||||
// dfs to collect the non unique_id context
|
||||
let fullContext = new Array();
|
||||
@@ -872,6 +880,14 @@ function buildElementObject(frame, element, interactable, purgeable = false) {
|
||||
attrs[attr.name] = attrValue;
|
||||
}
|
||||
|
||||
if (
|
||||
checkDisabledFromStyle(element) &&
|
||||
!attrs["disabled"] &&
|
||||
!attrs["aria-disabled"]
|
||||
) {
|
||||
attrs["disabled"] = true;
|
||||
}
|
||||
|
||||
if (
|
||||
checkRequiredFromStyle(element) &&
|
||||
!attrs["required"] &&
|
||||
|
||||
@@ -17,11 +17,12 @@ from skyvern.webeye.browser_factory import BrowserState
|
||||
from skyvern.webeye.utils.page import SkyvernFrame
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]]
|
||||
CleanupElementTreeFunc = Callable[[Page | Frame, str, list[dict]], Awaitable[list[dict]]]
|
||||
|
||||
RESERVED_ATTRIBUTES = {
|
||||
"accept", # for input file
|
||||
"alt",
|
||||
"shape-description", # for css shape
|
||||
"aria-checked", # for option tag
|
||||
"aria-current",
|
||||
"aria-label",
|
||||
@@ -122,8 +123,8 @@ def json_to_html(element: dict, need_skyvern_attrs: bool = True) -> str:
|
||||
if element.get("purgeable", False):
|
||||
return children_html + option_html
|
||||
|
||||
before_pseudo_text = element.get("beforePseudoText", "")
|
||||
after_pseudo_text = element.get("afterPseudoText", "")
|
||||
before_pseudo_text = element.get("beforePseudoText") or ""
|
||||
after_pseudo_text = element.get("afterPseudoText") or ""
|
||||
|
||||
# Check if the element is self-closing
|
||||
if (
|
||||
@@ -347,7 +348,7 @@ async def scrape_web_unsafe(
|
||||
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True)
|
||||
|
||||
elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
|
||||
element_tree = await cleanup_element_tree(url, copy.deepcopy(element_tree))
|
||||
element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))
|
||||
|
||||
id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids = build_element_dict(
|
||||
elements
|
||||
@@ -486,7 +487,7 @@ class IncrementalScrapePage:
|
||||
|
||||
self.elements = incremental_elements
|
||||
|
||||
incremental_tree = await cleanup_element_tree(frame.url, copy.deepcopy(incremental_tree))
|
||||
incremental_tree = await cleanup_element_tree(frame, frame.url, copy.deepcopy(incremental_tree))
|
||||
trimmed_element_tree = trim_element_tree(copy.deepcopy(incremental_tree))
|
||||
|
||||
self.element_tree = incremental_tree
|
||||
|
||||
@@ -23,6 +23,7 @@ from skyvern.exceptions import (
|
||||
)
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.webeye.scraper.scraper import IncrementalScrapePage, ScrapedPage, json_to_html, trim_element
|
||||
from skyvern.webeye.utils.page import SkyvernFrame
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -224,10 +225,14 @@ class SkyvernElement:
|
||||
|
||||
disabled_attr: bool | str | None = None
|
||||
aria_disabled_attr: bool | str | None = None
|
||||
style_disabled: bool = False
|
||||
|
||||
try:
|
||||
disabled_attr = await self.get_attr("disabled", dynamic=dynamic)
|
||||
aria_disabled_attr = await self.get_attr("aria-disabled", dynamic=dynamic)
|
||||
skyvern_frame = await SkyvernFrame.create_instance(self.get_frame())
|
||||
style_disabled = await skyvern_frame.get_disabled_from_style(await self.get_element_handler())
|
||||
|
||||
except Exception:
|
||||
# FIXME: maybe it should be considered as "disabled" element if failed to get the attributes?
|
||||
LOG.exception(
|
||||
@@ -250,7 +255,7 @@ class SkyvernElement:
|
||||
if isinstance(aria_disabled_attr, str):
|
||||
aria_disabled = aria_disabled_attr.lower() != "false"
|
||||
|
||||
return disabled or aria_disabled
|
||||
return disabled or aria_disabled or style_disabled
|
||||
|
||||
async def is_selectable(self) -> bool:
    """
    Report whether the element can be treated as selectable: either its own
    selectable flag is truthy, or its tag name is listed in SELECTABLE_ELEMENT.
    """
    selectable = self.get_selectable()
    if selectable:
        return selectable
    return self.get_tag_name() in SELECTABLE_ELEMENT
|
||||
|
||||
@@ -162,6 +162,10 @@ class SkyvernFrame:
|
||||
js_script = "(element) => isElementVisible(element) && !isHidden(element)"
|
||||
return await self.frame.evaluate(js_script, element)
|
||||
|
||||
async def get_disabled_from_style(self, element: ElementHandle) -> bool:
    """
    Evaluate the in-page `checkDisabledFromStyle` helper against the given
    element handle in this frame and return its result.
    """
    return await self.frame.evaluate("(element) => checkDisabledFromStyle(element)", element)
|
||||
|
||||
async def scroll_to_top(self, draw_boxes: bool) -> float:
|
||||
"""
|
||||
Scroll to the top of the page and take a screenshot.
|
||||
|
||||
Reference in New Issue
Block a user