From 3c612968ce2ac2f8da1e29bac93a6f5541378edf Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Fri, 4 Apr 2025 22:33:52 -0400 Subject: [PATCH] trim svg elements when prompt exceeds context window (#2106) --- poetry.lock | 2 +- pyproject.toml | 1 + skyvern/forge/agent.py | 22 ++++++------ skyvern/forge/agent_functions.py | 4 ++- skyvern/services/task_v2_service.py | 6 ++-- skyvern/utils/prompt_engine.py | 47 +++++++++++++++++++++++++ skyvern/utils/token_counter.py | 5 +++ skyvern/webeye/scraper/scraper.py | 53 +++++++++++++++++++++++++++++ 8 files changed, 126 insertions(+), 14 deletions(-) create mode 100644 skyvern/utils/prompt_engine.py create mode 100644 skyvern/utils/token_counter.py diff --git a/poetry.lock b/poetry.lock index 3abf8e54..866e84f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6521,4 +6521,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11,<3.12" -content-hash = "b43cb55e0c18ac83f0e32444132fd7618ef5b8355b0a90dbed55599d068c2892" +content-hash = "84b211a2b313b852996823fc4105d809b990e34cecd400c61d541561c010afdf" diff --git a/pyproject.toml b/pyproject.toml index 32269113..60ebc34b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ json-repair = "^0.34.0" pypdf = "^5.1.0" fastmcp = "^0.4.1" psutil = ">=7.0.0" +tiktoken = ">=0.9.0" [tool.poetry.group.dev.dependencies] isort = "^5.13.2" diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 8f9bcf4c..aa537019 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -68,6 +68,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus +from skyvern.utils.prompt_engine import load_prompt_with_elements from skyvern.webeye.actions.actions 
import ( Action, ActionStatus, @@ -1196,11 +1197,12 @@ class ForgeAgent: ) scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False) - verification_prompt = prompt_engine.load_prompt( - "check-user-goal", + verification_prompt = load_prompt_with_elements( + scraped_page=scraped_page_refreshed, + prompt_engine=prompt_engine, + template_name="check-user-goal", navigation_goal=task.navigation_goal, navigation_payload=task.navigation_payload, - elements=scraped_page_refreshed.build_element_tree(ElementTreeFormat.HTML), complete_criterion=task.complete_criterion, ) @@ -1432,7 +1434,7 @@ class ForgeAgent: task, step, browser_state, - element_tree_in_prompt, + scraped_page, verification_code_check=bool(task.totp_verification_url or task.totp_identifier), expire_verification_code=True, ) @@ -1470,7 +1472,7 @@ class ForgeAgent: task: Task, step: Step, browser_state: BrowserState, - element_tree_in_prompt: str, + scraped_page: ScrapedPage, verification_code_check: bool = False, expire_verification_code: bool = False, ) -> str: @@ -1525,13 +1527,14 @@ class ForgeAgent: raise UnsupportedTaskType(task_type=task_type) context = skyvern_context.ensure_context() - return prompt_engine.load_prompt( - template=template, + return load_prompt_with_elements( + scraped_page=scraped_page, + prompt_engine=prompt_engine, + template_name=template, navigation_goal=navigation_goal, navigation_payload_str=json.dumps(final_navigation_payload), starting_url=starting_url, current_url=current_url, - elements=element_tree_in_prompt, data_extraction_goal=task.data_extraction_goal, action_history=actions_and_results_str, error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None), @@ -2300,12 +2303,11 @@ class ForgeAgent: current_context = skyvern_context.ensure_context() current_context.totp_codes[task.task_id] = verification_code - element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML) extract_action_prompt = await 
self._build_extract_action_prompt( task, step, browser_state, - element_tree_in_prompt, + scraped_page, verification_code_check=False, expire_verification_code=True, ) diff --git a/skyvern/forge/agent_functions.py b/skyvern/forge/agent_functions.py index 5314ad6c..58bbc471 100644 --- a/skyvern/forge/agent_functions.py +++ b/skyvern/forge/agent_functions.py @@ -139,7 +139,9 @@ async def _convert_svg_to_string( skyvern_element = SkyvernElement(locator=locater, frame=skyvern_frame.get_frame(), static_element=element) - _, blocked = await skyvern_frame.get_blocking_element_id(await skyvern_element.get_element_handler()) + _, blocked = await skyvern_frame.get_blocking_element_id( + await skyvern_element.get_element_handler(timeout=1000) + ) if not skyvern_element.is_interactable() and blocked: _mark_element_as_dropped(element) return diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index 5090e7c2..d8f21106 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -53,6 +53,7 @@ from skyvern.forge.sdk.workflow.models.yaml import ( WorkflowDefinitionYAML, ) from skyvern.schemas.runs import ProxyLocation, RunType +from skyvern.utils.prompt_engine import load_prompt_with_elements from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website from skyvern.webeye.utils.page import SkyvernFrame @@ -462,10 +463,11 @@ async def run_task_v2_helper( continue current_url = current_url if current_url else str(await SkyvernFrame.get_url(frame=page) if page else url) - task_v2_prompt = prompt_engine.load_prompt( + task_v2_prompt = load_prompt_with_elements( + scraped_page, + prompt_engine, "task_v2", current_url=current_url, - elements=element_tree_in_prompt, user_goal=user_prompt, task_history=task_history, local_datetime=datetime.now(context.tz_info).isoformat(), diff --git a/skyvern/utils/prompt_engine.py 
def load_prompt_with_elements(
    scraped_page: ScrapedPage,
    prompt_engine: PromptEngine,
    template_name: str,
    **kwargs: Any,
) -> str:
    """
    Render a prompt template with the page's element tree, degrading gracefully
    when the rendered prompt would exceed DEFAULT_MAX_TOKENS.

    Fallback ladder:
      1. full element tree;
      2. "economy" tree (secondary elements such as SVG stripped);
      3. economy tree truncated to its first 2/3 of root elements.

    Args:
        scraped_page: source of the element trees embedded into the prompt.
        prompt_engine: engine used to render ``template_name``.
        template_name: name of the prompt template to render.
        **kwargs: extra template variables, forwarded to ``load_prompt``.

    Returns:
        The rendered prompt, possibly built from a reduced element tree.
    """

    def render(elements: str) -> str:
        # Single place that binds the template + elements; each fallback step
        # only swaps the elements string.
        return prompt_engine.load_prompt(template_name, elements=elements, **kwargs)

    prompt = render(scraped_page.build_element_tree())
    token_count = count_tokens(prompt)
    if token_count <= DEFAULT_MAX_TOKENS:
        return prompt

    # Too long: retry without secondary elements (SVG, etc).
    prompt = render(scraped_page.build_economy_elements_tree())
    economy_token_count = count_tokens(prompt)
    LOG.warning(
        "Prompt is longer than the max tokens. Going to use the economy elements tree.",
        template_name=template_name,
        token_count=token_count,
        economy_token_count=economy_token_count,
        max_tokens=DEFAULT_MAX_TOKENS,
    )
    if economy_token_count <= DEFAULT_MAX_TOKENS:
        return prompt

    # !!! HACK alert
    # Still too long: dump the last 1/3 of the html context and keep the first 2/3.
    prompt = render(scraped_page.build_economy_elements_tree(percent_to_keep=2 / 3))
    token_count_after_dump = count_tokens(prompt)
    LOG.warning(
        "Prompt is still longer than the max tokens. Will only keep the first 2/3 of the html context.",
        template_name=template_name,
        token_count=token_count,
        economy_token_count=economy_token_count,
        token_count_after_dump=token_count_after_dump,
        max_tokens=DEFAULT_MAX_TOKENS,
    )
    return prompt
Will only keep the first 2/3 of the html context.", + template_name=template_name, + token_count=token_count, + economy_token_count=economy_token_count, + token_count_after_dump=token_count_after_dump, + max_tokens=DEFAULT_MAX_TOKENS, + ) + return prompt diff --git a/skyvern/utils/token_counter.py b/skyvern/utils/token_counter.py new file mode 100644 index 00000000..4e3aaa5d --- /dev/null +++ b/skyvern/utils/token_counter.py @@ -0,0 +1,5 @@ +import tiktoken + + +def count_tokens(text: str) -> int: + return len(tiktoken.encoding_for_model("gpt-4o").encode(text)) diff --git a/skyvern/webeye/scraper/scraper.py b/skyvern/webeye/scraper/scraper.py index eed420e2..47ce3267 100644 --- a/skyvern/webeye/scraper/scraper.py +++ b/skyvern/webeye/scraper/scraper.py @@ -229,6 +229,7 @@ class ScrapedPage(BaseModel): hash_to_element_ids: dict[str, list[str]] element_tree: list[dict] element_tree_trimmed: list[dict] + economy_element_tree: list[dict] | None = None screenshots: list[bytes] url: str html: str @@ -268,6 +269,58 @@ class ScrapedPage(BaseModel): raise UnknownElementTreeFormat(fmt=fmt) + def build_economy_elements_tree( + self, + fmt: ElementTreeFormat = ElementTreeFormat.HTML, + html_need_skyvern_attrs: bool = True, + percent_to_keep: float = 1, + ) -> str: + """ + Economy elements tree doesn't include secondary elements like SVG, etc + """ + if not self.economy_element_tree: + economy_elements = [] + copied_element_tree_trimmed = copy.deepcopy(self.element_tree_trimmed) + + # Process each root element + for root_element in copied_element_tree_trimmed: + processed_element = self._process_element_for_economy_tree(root_element) + if processed_element: + economy_elements.append(processed_element) + + self.economy_element_tree = economy_elements + + final_element_tree = self.economy_element_tree[: int(len(self.economy_element_tree) * percent_to_keep)] + + if fmt == ElementTreeFormat.JSON: + return json.dumps(final_element_tree) + + if fmt == ElementTreeFormat.HTML: + return 
"".join( + json_to_html(element, need_skyvern_attrs=html_need_skyvern_attrs) for element in final_element_tree + ) + + raise UnknownElementTreeFormat(fmt=fmt) + + def _process_element_for_economy_tree(self, element: dict) -> dict | None: + """ + Helper method to process an element for the economy tree using BFS. + Removes SVG elements and their children. + """ + # Skip SVG elements entirely + if element.get("tagName", "").lower() == "svg": + return None + + # Process children using BFS + if "children" in element: + new_children = [] + for child in element["children"]: + processed_child = self._process_element_for_economy_tree(child) + if processed_child: + new_children.append(processed_child) + element["children"] = new_children + return element + async def refresh(self, draw_boxes: bool = True) -> Self: refreshed_page = await scrape_website( browser_state=self._browser_state,