convert css shape to string (#1092)

This commit is contained in:
LawyZheng
2024-10-31 00:12:13 +08:00
committed by GitHub
parent 01fbdeece4
commit 8762865a0b
7 changed files with 198 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ import hashlib
from typing import Dict, List
import structlog
from playwright.async_api import Page
from playwright.async_api import Frame, Page
from skyvern.config import settings
from skyvern.constants import SKYVERN_ID_ATTR
@@ -19,8 +19,8 @@ from skyvern.webeye.scraper.scraper import ELEMENT_NODE_ATTRIBUTES, CleanupEleme
LOG = structlog.get_logger()
USELESS_SVG_ATTRIBUTE = [SKYVERN_ID_ATTR, "id", "aria-describedby"]
SVG_RETRY_ATTEMPT = 3
USELESS_SHAPE_ATTRIBUTE = [SKYVERN_ID_ATTR, "id", "aria-describedby"]
SHAPE_CONVERTION_RETRY_ATTEMPT = 3
def _remove_rect(element: dict) -> None:
@@ -28,10 +28,45 @@ def _remove_rect(element: dict) -> None:
del element["rect"]
def _should_css_shape_convert(element: Dict) -> bool:
if "id" not in element:
return False
tag_name = element.get("tagName")
if tag_name not in ["a", "span", "i"]:
return False
# if <span> and <i> without any text in the element, we try to convert the shape
if tag_name in ["span", "i"] and not element.get("text"):
return True
# if <a>, it should be no text, no children, no href/target attribute
if tag_name == "a":
attributes = element.get("attributes", {})
if element.get("text"):
return False
if len(element.get("children", [])) > 0:
return False
if "href" in attributes:
return False
if "target" in attributes:
return False
return True
return False
def _get_svg_cache_key(hash: str) -> str:
return f"skyvern:svg:{hash}"
def _get_shape_cache_key(hash: str) -> str:
return f"skyvern:shape:{hash}"
def _remove_skyvern_attributes(element: Dict) -> Dict:
"""
To get the original HTML element without skyvern attributes
@@ -44,7 +79,7 @@ def _remove_skyvern_attributes(element: Dict) -> Dict:
if "attributes" in element_copied:
attributes: dict = copy.deepcopy(element_copied.get("attributes", {}))
for key in attributes.keys():
if key in USELESS_SVG_ATTRIBUTE:
if key in USELESS_SHAPE_ATTRIBUTE:
del element_copied["attributes"][key]
children: List[Dict] | None = element_copied.get("children", None)
@@ -80,6 +115,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
except Exception:
LOG.warning(
"Failed to loaded SVG cache",
task_id=task.task_id,
step_id=step.step_id,
exc_info=True,
key=svg_key,
)
@@ -92,6 +129,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
LOG.warning(
"SVG element is too large to convert, going to drop the svg element.",
element_id=element_id,
task_id=task.task_id,
step_id=step.step_id,
length=len(svg_html),
)
del element["children"]
@@ -101,7 +140,7 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
LOG.debug("call LLM to convert SVG to string shape", element_id=element_id)
svg_convert_prompt = prompt_engine.load_prompt("svg-convert", svg_element=svg_html)
for retry in range(SVG_RETRY_ATTEMPT):
for retry in range(SHAPE_CONVERTION_RETRY_ATTEMPT):
try:
json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step)
svg_shape = json_response.get("shape", "")
@@ -113,6 +152,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
except Exception:
LOG.exception(
"Failed to convert SVG to string shape by secondary llm. Will retry if haven't met the max try attempt after 3s.",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
retry=retry,
)
@@ -126,6 +167,101 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
return
async def _convert_css_shape_to_string(
    task: Task, step: Step, organization: Organization | None, frame: Page | Frame, element: Dict
) -> None:
    """
    Describe an icon-like element drawn with CSS and store the description on the element.

    The element (with skyvern-only attributes removed) is rendered to HTML and hashed with
    SHA-256; the digest forms the cache key. On a cache miss the element is located in
    ``frame`` by its skyvern id, screenshotted, and the screenshot is sent to the secondary
    LLM with the "css-shape-convert" prompt. A non-empty "shape" reply is cached and written
    to ``element["attributes"]["shape-description"]``.

    The conversion is best-effort: failures of the cache, the screenshot or the LLM are
    logged, and when no shape could be obtained the function returns without modifying
    ``element``. ``organization`` is accepted but not used by this function.
    """
    element_id: str = element.get("id", "")

    # Key the cache on the element's own markup (without skyvern attributes).
    shape_html = json_to_html(_remove_skyvern_attributes(element))
    shape_hash = hashlib.sha256(shape_html.encode("utf-8")).hexdigest()
    shape_key = _get_shape_cache_key(shape_hash)

    css_shape: str | None = None
    try:
        css_shape = await app.CACHE.get(shape_key)
    except Exception:
        # A broken cache must not block the conversion; fall through to the LLM path.
        LOG.warning(
            "Failed to loaded CSS shape cache",
            task_id=task.task_id,
            step_id=step.step_id,
            exc_info=True,
            key=shape_key,
        )

    if css_shape:
        LOG.debug("CSS shape loaded from cache", element_id=element_id, shape=css_shape)
    else:
        # FIXME: support element in iframe
        locator = frame.locator(f'[{SKYVERN_ID_ATTR}="{element_id}"]')

        # Query the match count once so both checks see the same DOM state
        # and only one browser round trip is made.
        match_count = await locator.count()
        if match_count == 0:
            LOG.info(
                "No locater found to convert css shape",
                task_id=task.task_id,
                step_id=step.step_id,
                element_id=element_id,
            )
            return None

        if match_count > 1:
            LOG.info(
                "multiple locaters found to convert css shape",
                task_id=task.task_id,
                step_id=step.step_id,
                element_id=element_id,
            )
            return None

        try:
            LOG.debug("call LLM to convert css shape to string shape", element_id=element_id)
            screenshot = await locator.screenshot(timeout=settings.BROWSER_SCREENSHOT_TIMEOUT_MS)
            prompt = prompt_engine.load_prompt("css-shape-convert")

            for retry in range(SHAPE_CONVERTION_RETRY_ATTEMPT):
                try:
                    json_response = await app.SECONDARY_LLM_API_HANDLER(
                        prompt=prompt, screenshots=[screenshot], step=step
                    )
                    css_shape = json_response.get("shape", "")
                    if not css_shape:
                        raise Exception("Empty css shape replied by secondary llm")
                    LOG.info("CSS Shape converted by LLM", element_id=element_id, shape=css_shape)
                    break
                except Exception:
                    LOG.exception(
                        "Failed to convert css shape to string shape by secondary llm. Will retry if haven't met the max try attempt after 3s.",
                        task_id=task.task_id,
                        step_id=step.step_id,
                        element_id=element_id,
                        retry=retry,
                    )
                    # Only wait when another attempt will actually follow.
                    if retry < SHAPE_CONVERTION_RETRY_ATTEMPT - 1:
                        await asyncio.sleep(3)
            else:
                # for/else: reached only when no attempt hit `break`.
                LOG.info(
                    "Max css shape convertion retry, going to abort the convertion.",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    element_id=element_id,
                )
                return None
        except Exception:
            LOG.exception(
                "Failed to convert css shape to string shape by LLM",
                task_id=task.task_id,
                step_id=step.step_id,
                element_id=element_id,
            )
            return None

        # Cache write is best-effort and kept outside the retry loop: a cache failure
        # must neither trigger another LLM call nor discard a shape we already have.
        try:
            await app.CACHE.set(shape_key, css_shape)
        except Exception:
            LOG.warning(
                "Failed to save CSS shape cache",
                task_id=task.task_id,
                step_id=step.step_id,
                exc_info=True,
                key=shape_key,
            )

    if "attributes" not in element:
        element["attributes"] = dict()
    element["attributes"]["shape-description"] = css_shape
    return None
class AgentFunction:
async def validate_step_execution(
self,
@@ -181,7 +317,7 @@ class AgentFunction:
step: Step,
organization: Organization | None = None,
) -> CleanupElementTreeFunc:
async def cleanup_element_tree_func(url: str, element_tree: list[dict]) -> list[dict]:
async def cleanup_element_tree_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
"""
Remove rect and attribute.unique_id from the elements.
The reason we're doing it is to
@@ -197,6 +333,16 @@ class AgentFunction:
queue_ele = queue.pop(0)
_remove_rect(queue_ele)
await _convert_svg_to_string(task, step, organization, queue_ele)
if _should_css_shape_convert(element=queue_ele):
await _convert_css_shape_to_string(
task=task,
step=step,
organization=organization,
frame=frame,
element=queue_ele,
)
# TODO: we can come back to test removing the unique_id
# from element attributes to make sure this won't increase hallucination
# _remove_unique_id(queue_ele)

View File

@@ -0,0 +1,8 @@
You are given a screenshot of an HTML element. You need to figure out what its shape means.
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
Reply in JSON format with the following keys:
{
"confidence_float": float, // The confidence of the shape description. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"shape": string, // A short description of the shape of the element and its meaning
}

View File

@@ -9,7 +9,7 @@ from typing import Any, Awaitable, Callable, List
import structlog
from deprecation import deprecated
from playwright.async_api import FileChooser, Locator, Page, TimeoutError
from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError
from pydantic import BaseModel
from skyvern.constants import REPO_ROOT_DIR, SKYVERN_ID_ATTR
@@ -165,8 +165,10 @@ def remove_exist_elements(element_tree: list[dict], check_exist: CheckExistIDFun
def clean_and_remove_element_tree_factory(
task: Task, step: Step, check_exist_funcs: list[CheckExistIDFunc]
) -> CleanupElementTreeFunc:
async def helper_func(url: str, element_tree: list[dict]) -> list[dict]:
element_tree = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(url, element_tree)
async def helper_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
element_tree = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(
frame, url, element_tree
)
for check_exist in check_exist_funcs:
element_tree = remove_exist_elements(element_tree=element_tree, check_exist=check_exist)
return element_tree
@@ -1270,7 +1272,7 @@ async def choose_auto_completion_dropdown(
if len(confirmed_preserved_list) > 0:
confirmed_preserved_list = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(
skyvern_frame.get_frame().url, copy.deepcopy(confirmed_preserved_list)
skyvern_frame.get_frame(), skyvern_frame.get_frame().url, copy.deepcopy(confirmed_preserved_list)
)
confirmed_preserved_list = trim_element_tree(copy.deepcopy(confirmed_preserved_list))

View File

@@ -686,6 +686,14 @@ const checkRequiredFromStyle = (element) => {
return element.className.toLowerCase().includes("require");
};
/**
 * Report whether the element's class name (compared case-insensitively)
 * contains the react-datepicker "disabled day" marker class.
 */
function checkDisabledFromStyle(element) {
  return element.className
    .toString()
    .toLowerCase()
    .includes("react-datepicker__day--disabled");
}
function getElementContext(element) {
// dfs to collect the non unique_id context
let fullContext = new Array();
@@ -872,6 +880,14 @@ function buildElementObject(frame, element, interactable, purgeable = false) {
attrs[attr.name] = attrValue;
}
if (
checkDisabledFromStyle(element) &&
!attrs["disabled"] &&
!attrs["aria-disabled"]
) {
attrs["disabled"] = true;
}
if (
checkRequiredFromStyle(element) &&
!attrs["required"] &&

View File

@@ -17,11 +17,12 @@ from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.utils.page import SkyvernFrame
LOG = structlog.get_logger()
CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]]
CleanupElementTreeFunc = Callable[[Page | Frame, str, list[dict]], Awaitable[list[dict]]]
RESERVED_ATTRIBUTES = {
"accept", # for input file
"alt",
"shape-description", # for css shape
"aria-checked", # for option tag
"aria-current",
"aria-label",
@@ -122,8 +123,8 @@ def json_to_html(element: dict, need_skyvern_attrs: bool = True) -> str:
if element.get("purgeable", False):
return children_html + option_html
before_pseudo_text = element.get("beforePseudoText", "")
after_pseudo_text = element.get("afterPseudoText", "")
before_pseudo_text = element.get("beforePseudoText") or ""
after_pseudo_text = element.get("afterPseudoText") or ""
# Check if the element is self-closing
if (
@@ -347,7 +348,7 @@ async def scrape_web_unsafe(
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True)
elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
element_tree = await cleanup_element_tree(url, copy.deepcopy(element_tree))
element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))
id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids = build_element_dict(
elements
@@ -486,7 +487,7 @@ class IncrementalScrapePage:
self.elements = incremental_elements
incremental_tree = await cleanup_element_tree(frame.url, copy.deepcopy(incremental_tree))
incremental_tree = await cleanup_element_tree(frame, frame.url, copy.deepcopy(incremental_tree))
trimmed_element_tree = trim_element_tree(copy.deepcopy(incremental_tree))
self.element_tree = incremental_tree

View File

@@ -23,6 +23,7 @@ from skyvern.exceptions import (
)
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.scraper.scraper import IncrementalScrapePage, ScrapedPage, json_to_html, trim_element
from skyvern.webeye.utils.page import SkyvernFrame
LOG = structlog.get_logger()
@@ -224,10 +225,14 @@ class SkyvernElement:
disabled_attr: bool | str | None = None
aria_disabled_attr: bool | str | None = None
style_disabled: bool = False
try:
disabled_attr = await self.get_attr("disabled", dynamic=dynamic)
aria_disabled_attr = await self.get_attr("aria-disabled", dynamic=dynamic)
skyvern_frame = await SkyvernFrame.create_instance(self.get_frame())
style_disabled = await skyvern_frame.get_disabled_from_style(await self.get_element_handler())
except Exception:
# FIXME: maybe it should be considered as "disabled" element if failed to get the attributes?
LOG.exception(
@@ -250,7 +255,7 @@ class SkyvernElement:
if isinstance(aria_disabled_attr, str):
aria_disabled = aria_disabled_attr.lower() != "false"
return disabled or aria_disabled
return disabled or aria_disabled or style_disabled
async def is_selectable(self) -> bool:
return self.get_selectable() or self.get_tag_name() in SELECTABLE_ELEMENT

View File

@@ -162,6 +162,10 @@ class SkyvernFrame:
js_script = "(element) => isElementVisible(element) && !isHidden(element)"
return await self.frame.evaluate(js_script, element)
async def get_disabled_from_style(self, element: ElementHandle) -> bool:
    """Evaluate the page-side ``checkDisabledFromStyle`` helper on ``element`` in this frame and return its result."""
    return await self.frame.evaluate("(element) => checkDisabledFromStyle(element)", element)
async def scroll_to_top(self, draw_boxes: bool) -> float:
"""
Scroll to the top of the page and take a screenshot.