From 654ba03e097cee47acfd59d37ade876a52e1188c Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 23 Apr 2025 01:44:14 +0800 Subject: [PATCH] optimize and speed up custom selection (#2215) --- skyvern/forge/agent_functions.py | 20 +++++++- skyvern/webeye/scraper/domUtils.js | 81 ++++++++++++++++++------------ skyvern/webeye/utils/dom.py | 19 ++++--- 3 files changed, 81 insertions(+), 39 deletions(-) diff --git a/skyvern/forge/agent_functions.py b/skyvern/forge/agent_functions.py index 160c089e..20cd4a13 100644 --- a/skyvern/forge/agent_functions.py +++ b/skyvern/forge/agent_functions.py @@ -116,6 +116,7 @@ async def _convert_svg_to_string( element: Dict, task: Task | None = None, step: Step | None = None, + always_drop: bool = False, ) -> None: if element.get("tagName") != "svg": return @@ -123,6 +124,10 @@ async def _convert_svg_to_string( if element.get("isDropped", False): return + if always_drop: + _mark_element_as_dropped(element) + return + task_id = task.task_id if task else None step_id = step.step_id if step else None element_id = element.get("id", "") @@ -476,6 +481,8 @@ class AgentFunction: task: Task | None = None, step: Step | None = None, ) -> CleanupElementTreeFunc: + MAX_ELEMENT_CNT = 3000 + async def cleanup_element_tree_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]: """ Remove rect and attribute.unique_id from the elements. @@ -491,10 +498,19 @@ class AgentFunction: current_frame_index = context.frame_index_map.get(frame, 0) queue = [] + element_cnt = 0 for element in element_tree: queue.append(element) while queue: queue_ele = queue.pop(0) + + element_cnt += 1 + if element_cnt == MAX_ELEMENT_CNT: + LOG.warning( + f"Element reached max count {MAX_ELEMENT_CNT}, will stop converting svg and css element." + ) + element_exceeded = element_cnt > MAX_ELEMENT_CNT + if queue_ele.get("frame_index") != current_frame_index: new_frame = next( (k for k, v in context.frame_index_map.items() if v == queue_ele.get("frame_index")), frame @@ -503,9 +519,9 @@ class AgentFunction: current_frame_index = queue_ele.get("frame_index", 0) _remove_rect(queue_ele) - await _convert_svg_to_string(skyvern_frame, queue_ele, task, step) + await _convert_svg_to_string(skyvern_frame, queue_ele, task, step, always_drop=element_exceeded) - if _should_css_shape_convert(element=queue_ele): + if not element_exceeded and _should_css_shape_convert(element=queue_ele): await _convert_css_shape_to_string( skyvern_frame=skyvern_frame, element=queue_ele, diff --git a/skyvern/webeye/scraper/domUtils.js b/skyvern/webeye/scraper/domUtils.js index 8fce6652..83d7d584 100644 --- a/skyvern/webeye/scraper/domUtils.js +++ b/skyvern/webeye/scraper/domUtils.js @@ -1383,9 +1383,13 @@ async function buildElementTree( starter = document.body, frame, full_tree = false, + needContext = true, + hoverStylesMap = undefined, ) { // Generate hover styles map at the start - const hoverStylesMap = getHoverStylesMap(); + if (hoverStylesMap === undefined) { + hoverStylesMap = getHoverStylesMap(); + } var elements = []; var resultArray = []; @@ -1725,35 +1729,36 @@ async function buildElementTree( } let ctxList = []; - try { - ctxList = getContextByLinked(element, ctxList); - } catch (e) { - console.error("failed to get context by linked: ", e); - } + if (needContext) { + try { + ctxList = getContextByLinked(element, ctxList); + } catch (e) { + console.error("failed to get context by linked: ", e); + } - try { - ctxList = getContextByParent(element, ctxList); - } catch (e) { - console.error("failed to get context by parent: ", e); - } + try { + ctxList = getContextByParent(element, ctxList); + } catch (e) { + console.error("failed to get context by parent: ", e); + } - try { - ctxList = getContextByTable(element, ctxList); - } catch (e) { - console.error("failed to get context by table: ", e); - } - const context = ctxList.join(";"); - if (context && context.length <= 5000) { - element.context = context; - } - - // FIXME: skip for now to prevent navigating to other page by mistake - if (element.tagName !== "a" && checkStringIncludeRequire(context)) { - if ( - !element.attributes["required"] && - !element.attributes["aria-required"] - ) { - element.attributes["required"] = true; + try { + ctxList = getContextByTable(element, ctxList); + } catch (e) { + console.error("failed to get context by table: ", e); + } + const context = ctxList.join(";"); + if (context && context.length <= 5000) { + element.context = context; + } + // FIXME: skip for now to prevent navigating to other page by mistake + if (element.tagName !== "a" && checkStringIncludeRequire(context)) { + if ( + !element.attributes["required"] && + !element.attributes["aria-required"] + ) { + element.attributes["required"] = true; + } } } } @@ -1761,7 +1766,9 @@ async function buildElementTree( resultArray = removeOrphanNode(resultArray); resultArray.forEach((root) => { trimDuplicatedText(root); - trimDuplicatedContext(root); + if (needContext) { + trimDuplicatedContext(root); + } }); return [elements, resultArray]; @@ -2211,6 +2218,7 @@ function asyncSleepFor(ms) { async function addIncrementalNodeToMap(parentNode, childrenNode) { const maxParsedElement = 3000; + const maxElementToWait = 100; if ((await window.globalParsedElementCounter.get()) > maxParsedElement) { console.warn( "Too many elements parsed, stopping the observer to parse the elements", @@ -2232,9 +2240,19 @@ async function addIncrementalNodeToMap(parentNode, childrenNode) { try { for (const child of childrenNode) { // sleep for a while until animation ends - await asyncSleepFor(300); + if ( + (await window.globalParsedElementCounter.get()) < maxElementToWait + ) { + await asyncSleepFor(300); + } // Pass -1 as frame_index to indicate the frame number is not sensitive in this case - const [_, newNodeTree] = await buildElementTree(child, "", true); + const [_, newNodeTree] = await buildElementTree( + child, + "", + true, + false, + window.globalHoverStylesMap, + ); if (newNodeTree.length > 0) { newNodesTreeList.push(...newNodeTree); } @@ -2352,6 +2370,7 @@ function startGlobalIncrementalObserver(element = null) { window.globalListnerFlag = true; window.globalDomDepthMap = new Map(); window.globalOneTimeIncrementElements = []; + window.globalHoverStylesMap = getHoverStylesMap(); window.globalParsedElementCounter = new SafeCounter(); window.globalObserverForDOMIncrement.takeRecords(); // cleanup the older data window.globalObserverForDOMIncrement.observe(document.body, { diff --git a/skyvern/webeye/utils/dom.py b/skyvern/webeye/utils/dom.py index 5e6282d8..2cf2309d 100644 --- a/skyvern/webeye/utils/dom.py +++ b/skyvern/webeye/utils/dom.py @@ -163,7 +163,7 @@ class SkyvernElement: return False async def is_custom_option(self) -> bool: - return self.get_tag_name() == "li" or await self.get_attr("role") == "option" + return self.get_tag_name() == "li" or await self.get_attr("role", mode="static") == "option" async def is_checkbox(self) -> bool: tag_name = self.get_tag_name() @@ -243,9 +243,10 @@ class SkyvernElement: aria_disabled_attr: bool | str | None = None style_disabled: bool = False + mode: typing.Literal["auto", "dynamic"] = "dynamic" if dynamic else "auto" try: - disabled_attr = await self.get_attr("disabled", dynamic=dynamic) - aria_disabled_attr = await self.get_attr("aria-disabled", dynamic=dynamic) + disabled_attr = await self.get_attr("disabled", mode=mode) + aria_disabled_attr = await self.get_attr("aria-disabled", mode=mode) skyvern_frame = await SkyvernFrame.create_instance(self.get_frame()) style_disabled = await skyvern_frame.get_disabled_from_style(await self.get_element_handler()) @@ -511,12 +512,18 @@ class SkyvernElement: async def get_attr( self, attr_name: str, - dynamic: bool = False, + mode: typing.Literal["auto", "dynamic", "static"] = "auto", timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, ) -> typing.Any: - if not dynamic: + """ + mode: + auto: use value from the self.get_attributes() first. if empty, then try to get the value from the locator.get_attribute() + dynamic: always use locator.get_attribute() + static: always use self.get_attributes() + """ + if mode != "dynamic": attr = self.get_attributes().get(attr_name) - if attr is not None: + if attr is not None or mode == "static": return attr return await self.locator.get_attribute(attr_name, timeout=timeout)