Parallelize SVG conversions (#2281)
This commit is contained in:
@@ -111,22 +111,23 @@ def _mark_element_as_dropped(element: dict) -> None:
|
||||
element["isDropped"] = True
|
||||
|
||||
|
||||
async def _convert_svg_to_string(
|
||||
async def _check_svg_eligibility(
|
||||
skyvern_frame: SkyvernFrame,
|
||||
element: Dict,
|
||||
task: Task | None = None,
|
||||
step: Step | None = None,
|
||||
always_drop: bool = False,
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""Check if an SVG element is eligible for conversion."""
|
||||
if element.get("tagName") != "svg":
|
||||
return
|
||||
return False
|
||||
|
||||
if element.get("isDropped", False):
|
||||
return
|
||||
return False
|
||||
|
||||
if always_drop:
|
||||
_mark_element_as_dropped(element)
|
||||
return
|
||||
return False
|
||||
|
||||
task_id = task.task_id if task else None
|
||||
step_id = step.step_id if step else None
|
||||
@@ -136,11 +137,11 @@ async def _convert_svg_to_string(
|
||||
locater = skyvern_frame.get_frame().locator(f'[{SKYVERN_ID_ATTR}="{element_id}"]')
|
||||
if await locater.count() == 0:
|
||||
_mark_element_as_dropped(element)
|
||||
return
|
||||
return False
|
||||
|
||||
if not await locater.is_visible(timeout=settings.BROWSER_ACTION_TIMEOUT_MS):
|
||||
_mark_element_as_dropped(element)
|
||||
return
|
||||
return False
|
||||
|
||||
skyvern_element = SkyvernElement(locator=locater, frame=skyvern_frame.get_frame(), static_element=element)
|
||||
|
||||
@@ -149,7 +150,7 @@ async def _convert_svg_to_string(
|
||||
)
|
||||
if not skyvern_element.is_interactable() and blocked:
|
||||
_mark_element_as_dropped(element)
|
||||
return
|
||||
return False
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to get the blocking element for the svg, going to continue parsing the svg",
|
||||
@@ -158,6 +159,19 @@ async def _convert_svg_to_string(
|
||||
step_id=step_id,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _convert_svg_to_string(
|
||||
element: Dict,
|
||||
task: Task | None = None,
|
||||
step: Step | None = None,
|
||||
) -> None:
|
||||
"""Convert an SVG element to a string description. Assumes element has already passed eligibility checks."""
|
||||
task_id = task.task_id if task else None
|
||||
step_id = step.step_id if step else None
|
||||
element_id = element.get("id", "")
|
||||
|
||||
svg_element = _remove_skyvern_attributes(element)
|
||||
svg_html = json_to_html(svg_element)
|
||||
hash_object = hashlib.sha256()
|
||||
@@ -499,8 +513,11 @@ class AgentFunction:
|
||||
|
||||
queue = []
|
||||
element_cnt = 0
|
||||
eligible_svgs = [] # List to store eligible SVGs and their frames
|
||||
|
||||
for element in element_tree:
|
||||
queue.append(element)
|
||||
|
||||
while queue:
|
||||
queue_ele = queue.pop(0)
|
||||
|
||||
@@ -519,7 +536,10 @@ class AgentFunction:
|
||||
current_frame_index = queue_ele.get("frame_index", 0)
|
||||
|
||||
_remove_rect(queue_ele)
|
||||
await _convert_svg_to_string(skyvern_frame, queue_ele, task, step, always_drop=element_exceeded)
|
||||
|
||||
# Check SVG eligibility and store for later conversion
|
||||
if await _check_svg_eligibility(skyvern_frame, queue_ele, task, step, always_drop=element_exceeded):
|
||||
eligible_svgs.append((queue_ele, skyvern_frame))
|
||||
|
||||
if not element_exceeded and _should_css_shape_convert(element=queue_ele):
|
||||
await _convert_css_shape_to_string(
|
||||
@@ -534,6 +554,11 @@ class AgentFunction:
|
||||
# _remove_unique_id(queue_ele)
|
||||
if "children" in queue_ele:
|
||||
queue.extend(queue_ele["children"])
|
||||
|
||||
# Convert all eligible SVGs in parallel
|
||||
if eligible_svgs:
|
||||
await asyncio.gather(*[_convert_svg_to_string(element, task, step) for element, frame in eligible_svgs])
|
||||
|
||||
return element_tree
|
||||
|
||||
return cleanup_element_tree_func
|
||||
|
||||
Reference in New Issue
Block a user