optimize and speed up custom selection (#2215)

This commit is contained in:
Shuchang Zheng
2025-04-23 01:44:14 +08:00
committed by GitHub
parent ce6d6c51e0
commit 654ba03e09
3 changed files with 81 additions and 39 deletions

View File

@@ -116,6 +116,7 @@ async def _convert_svg_to_string(
element: Dict,
task: Task | None = None,
step: Step | None = None,
always_drop: bool = False,
) -> None:
if element.get("tagName") != "svg":
return
@@ -123,6 +124,10 @@ async def _convert_svg_to_string(
if element.get("isDropped", False):
return
if always_drop:
_mark_element_as_dropped(element)
return
task_id = task.task_id if task else None
step_id = step.step_id if step else None
element_id = element.get("id", "")
@@ -476,6 +481,8 @@ class AgentFunction:
task: Task | None = None,
step: Step | None = None,
) -> CleanupElementTreeFunc:
MAX_ELEMENT_CNT = 3000
async def cleanup_element_tree_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
"""
Remove rect and attribute.unique_id from the elements.
@@ -491,10 +498,19 @@ class AgentFunction:
current_frame_index = context.frame_index_map.get(frame, 0)
queue = []
element_cnt = 0
for element in element_tree:
queue.append(element)
while queue:
queue_ele = queue.pop(0)
element_cnt += 1
if element_cnt == MAX_ELEMENT_CNT:
LOG.warning(
f"Element reached max count {MAX_ELEMENT_CNT}, will stop converting svg and css element."
)
element_exceeded = element_cnt > MAX_ELEMENT_CNT
if queue_ele.get("frame_index") != current_frame_index:
new_frame = next(
(k for k, v in context.frame_index_map.items() if v == queue_ele.get("frame_index")), frame
@@ -503,9 +519,9 @@ class AgentFunction:
current_frame_index = queue_ele.get("frame_index", 0)
_remove_rect(queue_ele)
await _convert_svg_to_string(skyvern_frame, queue_ele, task, step)
await _convert_svg_to_string(skyvern_frame, queue_ele, task, step, always_drop=element_exceeded)
if _should_css_shape_convert(element=queue_ele):
if not element_exceeded and _should_css_shape_convert(element=queue_ele):
await _convert_css_shape_to_string(
skyvern_frame=skyvern_frame,
element=queue_ele,

View File

@@ -1383,9 +1383,13 @@ async function buildElementTree(
starter = document.body,
frame,
full_tree = false,
needContext = true,
hoverStylesMap = undefined,
) {
// Generate hover styles map at the start
const hoverStylesMap = getHoverStylesMap();
if (hoverStylesMap === undefined) {
hoverStylesMap = getHoverStylesMap();
}
var elements = [];
var resultArray = [];
@@ -1725,35 +1729,36 @@ async function buildElementTree(
}
let ctxList = [];
try {
ctxList = getContextByLinked(element, ctxList);
} catch (e) {
console.error("failed to get context by linked: ", e);
}
if (needContext) {
try {
ctxList = getContextByLinked(element, ctxList);
} catch (e) {
console.error("failed to get context by linked: ", e);
}
try {
ctxList = getContextByParent(element, ctxList);
} catch (e) {
console.error("failed to get context by parent: ", e);
}
try {
ctxList = getContextByParent(element, ctxList);
} catch (e) {
console.error("failed to get context by parent: ", e);
}
try {
ctxList = getContextByTable(element, ctxList);
} catch (e) {
console.error("failed to get context by table: ", e);
}
const context = ctxList.join(";");
if (context && context.length <= 5000) {
element.context = context;
}
// FIXME: skip <a> for now to prevent navigating to other page by mistake
if (element.tagName !== "a" && checkStringIncludeRequire(context)) {
if (
!element.attributes["required"] &&
!element.attributes["aria-required"]
) {
element.attributes["required"] = true;
try {
ctxList = getContextByTable(element, ctxList);
} catch (e) {
console.error("failed to get context by table: ", e);
}
const context = ctxList.join(";");
if (context && context.length <= 5000) {
element.context = context;
}
// FIXME: skip <a> for now to prevent navigating to other page by mistake
if (element.tagName !== "a" && checkStringIncludeRequire(context)) {
if (
!element.attributes["required"] &&
!element.attributes["aria-required"]
) {
element.attributes["required"] = true;
}
}
}
}
@@ -1761,7 +1766,9 @@ async function buildElementTree(
resultArray = removeOrphanNode(resultArray);
resultArray.forEach((root) => {
trimDuplicatedText(root);
trimDuplicatedContext(root);
if (needContext) {
trimDuplicatedContext(root);
}
});
return [elements, resultArray];
@@ -2211,6 +2218,7 @@ function asyncSleepFor(ms) {
async function addIncrementalNodeToMap(parentNode, childrenNode) {
const maxParsedElement = 3000;
const maxElementToWait = 100;
if ((await window.globalParsedElementCounter.get()) > maxParsedElement) {
console.warn(
"Too many elements parsed, stopping the observer to parse the elements",
@@ -2232,9 +2240,19 @@ async function addIncrementalNodeToMap(parentNode, childrenNode) {
try {
for (const child of childrenNode) {
// sleep for a while until animation ends
await asyncSleepFor(300);
if (
(await window.globalParsedElementCounter.get()) < maxElementToWait
) {
await asyncSleepFor(300);
}
// Pass an empty string as the frame argument: the frame is not significant for incrementally added nodes
const [_, newNodeTree] = await buildElementTree(child, "", true);
const [_, newNodeTree] = await buildElementTree(
child,
"",
true,
false,
window.globalHoverStylesMap,
);
if (newNodeTree.length > 0) {
newNodesTreeList.push(...newNodeTree);
}
@@ -2352,6 +2370,7 @@ function startGlobalIncrementalObserver(element = null) {
window.globalListnerFlag = true;
window.globalDomDepthMap = new Map();
window.globalOneTimeIncrementElements = [];
window.globalHoverStylesMap = getHoverStylesMap();
window.globalParsedElementCounter = new SafeCounter();
window.globalObserverForDOMIncrement.takeRecords(); // cleanup the older data
window.globalObserverForDOMIncrement.observe(document.body, {

View File

@@ -163,7 +163,7 @@ class SkyvernElement:
return False
async def is_custom_option(self) -> bool:
return self.get_tag_name() == "li" or await self.get_attr("role") == "option"
return self.get_tag_name() == "li" or await self.get_attr("role", mode="static") == "option"
async def is_checkbox(self) -> bool:
tag_name = self.get_tag_name()
@@ -243,9 +243,10 @@ class SkyvernElement:
aria_disabled_attr: bool | str | None = None
style_disabled: bool = False
mode: typing.Literal["auto", "dynamic"] = "dynamic" if dynamic else "auto"
try:
disabled_attr = await self.get_attr("disabled", dynamic=dynamic)
aria_disabled_attr = await self.get_attr("aria-disabled", dynamic=dynamic)
disabled_attr = await self.get_attr("disabled", mode=mode)
aria_disabled_attr = await self.get_attr("aria-disabled", mode=mode)
skyvern_frame = await SkyvernFrame.create_instance(self.get_frame())
style_disabled = await skyvern_frame.get_disabled_from_style(await self.get_element_handler())
@@ -511,12 +512,18 @@ class SkyvernElement:
async def get_attr(
self,
attr_name: str,
dynamic: bool = False,
mode: typing.Literal["auto", "dynamic", "static"] = "auto",
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
) -> typing.Any:
if not dynamic:
"""
mode:
auto: use the value from self.get_attributes() first; if it is None, fall back to locator.get_attribute()
dynamic: always use locator.get_attribute()
static: always use self.get_attributes(); returns None when the attribute is not present there
"""
if mode != "dynamic":
attr = self.get_attributes().get(attr_name)
if attr is not None:
if attr is not None or mode == "static":
return attr
return await self.locator.get_attribute(attr_name, timeout=timeout)