only take up to 1 screenshot if the html too big (#2108)

This commit is contained in:
Shuchang Zheng
2025-04-05 23:33:34 -04:00
committed by GitHub
parent 3c612968ce
commit a72fcadd9a
5 changed files with 64 additions and 38 deletions

View File

@@ -32,3 +32,4 @@ class ScrapeType(StrEnum):
SCRAPE_TYPE_ORDER = [ScrapeType.NORMAL, ScrapeType.NORMAL, ScrapeType.RELOAD]
DEFAULT_MAX_TOKENS = 100000

View File

@@ -55,7 +55,7 @@ from skyvern.forge.sdk.workflow.models.yaml import (
from skyvern.schemas.runs import ProxyLocation, RunType
from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
from skyvern.webeye.utils.page import SkyvernFrame
LOG = structlog.get_logger()
@@ -453,7 +453,6 @@ async def run_task_v2_helper(
app.AGENT_FUNCTION.cleanup_element_tree_factory(),
scrape_exclude=app.scrape_exclude,
)
element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
if page is None:
page = await browser_state.get_working_page()
except Exception:
@@ -545,7 +544,7 @@ async def run_task_v2_helper(
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,
current_url=current_url,
element_tree_in_prompt=element_tree_in_prompt,
scraped_page=scraped_page,
data_extraction_goal=plan,
task_history=task_history,
)
@@ -1084,20 +1083,22 @@ async def _generate_extraction_task(
workflow_permanent_id: str,
workflow_run_id: str,
current_url: str,
element_tree_in_prompt: str,
scraped_page: ScrapedPage,
data_extraction_goal: str,
task_history: list[dict] | None = None,
) -> tuple[ExtractionBlock, list[BLOCK_YAML_TYPES], list[PARAMETER_YAML_TYPES]]:
LOG.info("Generating extraction task", data_extraction_goal=data_extraction_goal, current_url=current_url)
# extract the data
context = skyvern_context.ensure_context()
generate_extraction_task_prompt = prompt_engine.load_prompt(
"task_v2_generate_extraction_task",
generate_extraction_task_prompt = load_prompt_with_elements(
scraped_page=scraped_page,
prompt_engine=prompt_engine,
template_name="task_v2_generate_extraction_task",
current_url=current_url,
elements=element_tree_in_prompt,
data_extraction_goal=data_extraction_goal,
local_datetime=datetime.now(context.tz_info).isoformat(),
)
generate_extraction_task_response = await app.LLM_API_HANDLER(
generate_extraction_task_prompt,
task_v2=task_v2,

View File

@@ -2,11 +2,11 @@ from typing import Any
import structlog
from skyvern.constants import DEFAULT_MAX_TOKENS
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.utils.token_counter import count_tokens
from skyvern.webeye.scraper.scraper import ScrapedPage
DEFAULT_MAX_TOKENS = 100000
LOG = structlog.get_logger()
@@ -14,13 +14,20 @@ def load_prompt_with_elements(
scraped_page: ScrapedPage,
prompt_engine: PromptEngine,
template_name: str,
html_need_skyvern_attrs: bool = True,
**kwargs: Any,
) -> str:
prompt = prompt_engine.load_prompt(template_name, elements=scraped_page.build_element_tree(), **kwargs)
prompt = prompt_engine.load_prompt(
template_name,
elements=scraped_page.build_element_tree(html_need_skyvern_attrs=html_need_skyvern_attrs),
**kwargs,
)
token_count = count_tokens(prompt)
if token_count > DEFAULT_MAX_TOKENS:
# get rid of all the secondary elements like SVG, etc
economy_elements_tree = scraped_page.build_economy_elements_tree()
economy_elements_tree = scraped_page.build_economy_elements_tree(
html_need_skyvern_attrs=html_need_skyvern_attrs
)
prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs)
economy_token_count = count_tokens(prompt)
LOG.warning(
@@ -33,7 +40,10 @@ def load_prompt_with_elements(
if economy_token_count > DEFAULT_MAX_TOKENS:
# !!! HACK alert
# dump the last 1/3 of the html context and keep the first 2/3 of the html context
economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(percent_to_keep=2 / 3)
economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(
html_need_skyvern_attrs=html_need_skyvern_attrs,
percent_to_keep=2 / 3,
)
prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree_dumped, **kwargs)
token_count_after_dump = count_tokens(prompt)
LOG.warning(

View File

@@ -67,6 +67,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants
from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.actions import actions
from skyvern.webeye.actions.actions import (
Action,
@@ -84,7 +85,6 @@ from skyvern.webeye.actions.actions import (
from skyvern.webeye.actions.responses import ActionAbort, ActionFailure, ActionResult, ActionSuccess
from skyvern.webeye.scraper.scraper import (
CleanupElementTreeFunc,
ElementTreeFormat,
IncrementalScrapePage,
ScrapedPage,
hash_element,
@@ -751,12 +751,12 @@ async def handle_input_text_action(
return [ActionSuccess()]
if not await skyvern_element.is_raw_input():
# parse the input context to help executing input action
prompt = prompt_engine.load_prompt(
"parse-input-or-select-context",
prompt = load_prompt_with_elements(
scraped_page=scraped_page,
prompt_engine=prompt_engine,
template_name="parse-input-or-select-context",
element_id=action.element_id,
action_reasoning=action.reasoning,
elements=dom.scraped_page.build_element_tree(ElementTreeFormat.HTML),
)
json_response = await app.SECONDARY_LLM_API_HANDLER(
@@ -1934,11 +1934,12 @@ async def sequentially_select_from_dropdown(
Only return the last value today
"""
prompt = prompt_engine.load_prompt(
"parse-input-or-select-context",
prompt = load_prompt_with_elements(
scraped_page=dom.scraped_page,
prompt_engine=prompt_engine,
template_name="parse-input-or-select-context",
action_reasoning=action.reasoning,
element_id=action.element_id,
elements=dom.scraped_page.build_element_tree(ElementTreeFormat.HTML),
)
json_response = await app.SECONDARY_LLM_API_HANDLER(
prompt=prompt, step=step, prompt_name="parse-input-or-select-context"
@@ -2617,11 +2618,12 @@ async def normal_select(
is_success = False
locator = skyvern_element.get_locator()
prompt = prompt_engine.load_prompt(
"parse-input-or-select-context",
prompt = load_prompt_with_elements(
scraped_page=dom.scraped_page,
prompt_engine=prompt_engine,
template_name="parse-input-or-select-context",
action_reasoning=action.reasoning,
element_id=action.element_id,
elements=dom.scraped_page.build_element_tree(ElementTreeFormat.HTML),
)
json_response = await app.SECONDARY_LLM_API_HANDLER(
prompt=prompt, step=step, prompt_name="parse-input-or-select-context"
@@ -2785,20 +2787,15 @@ async def extract_information_for_navigation_goal(
1. JSON representation of what the user is seeing
2. The scraped page
"""
prompt_template = "extract-information"
# TODO: we only use HTML element for now, introduce a way to switch in the future
element_tree_format = ElementTreeFormat.HTML
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format, html_need_skyvern_attrs=False)
scraped_page_refreshed = await scraped_page.refresh()
context = ensure_context()
extract_information_prompt = prompt_engine.load_prompt(
prompt_template,
extract_information_prompt = load_prompt_with_elements(
scraped_page=scraped_page_refreshed,
prompt_engine=prompt_engine,
template_name="extract-information",
html_need_skyvern_attrs=False,
navigation_goal=task.navigation_goal,
navigation_payload=task.navigation_payload,
elements=element_tree_in_prompt,
data_extraction_goal=task.data_extraction_goal,
extracted_information_schema=task.extracted_information_schema,
current_url=scraped_page_refreshed.url,

View File

@@ -10,10 +10,11 @@ from playwright.async_api import Frame, Locator, Page
from pydantic import BaseModel, PrivateAttr
from skyvern.config import settings
from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, SKYVERN_DIR, SKYVERN_ID_ATTR
from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, DEFAULT_MAX_TOKENS, SKYVERN_DIR, SKYVERN_ID_ATTR
from skyvern.exceptions import FailedToTakeScreenshot, ScrapingFailed, UnknownElementTreeFormat
from skyvern.forge.sdk.api.crypto import calculate_sha256
from skyvern.forge.sdk.core import skyvern_context
from skyvern.utils.token_counter import count_tokens
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.utils.page import SkyvernFrame
@@ -230,6 +231,7 @@ class ScrapedPage(BaseModel):
element_tree: list[dict]
element_tree_trimmed: list[dict]
economy_element_tree: list[dict] | None = None
last_used_element_tree: list[dict] | None = None
screenshots: list[bytes]
url: str
html: str
@@ -258,6 +260,7 @@ class ScrapedPage(BaseModel):
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
self.last_used_element_tree = self.element_tree_trimmed
if fmt == ElementTreeFormat.JSON:
return json.dumps(self.element_tree_trimmed)
@@ -291,6 +294,7 @@ class ScrapedPage(BaseModel):
self.economy_element_tree = economy_elements
final_element_tree = self.economy_element_tree[: int(len(self.economy_element_tree) * percent_to_keep)]
self.last_used_element_tree = final_element_tree
if fmt == ElementTreeFormat.JSON:
return json.dumps(final_element_tree)
@@ -488,13 +492,26 @@ async def scrape_web_unsafe(
LOG.info("Waiting for 5 seconds before scraping the website.")
await asyncio.sleep(5)
screenshots = []
if take_screenshots:
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=draw_boxes)
elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))
element_tree_trimmed = trim_element_tree(copy.deepcopy(element_tree))
screenshots = []
if take_screenshots:
element_tree_trimmed_html_str = "".join(
json_to_html(element, need_skyvern_attrs=False) for element in element_tree_trimmed
)
token_count = count_tokens(element_tree_trimmed_html_str)
max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
if token_count > DEFAULT_MAX_TOKENS:
max_screenshot_number = min(max_screenshot_number, 1)
screenshots = await SkyvernFrame.take_split_screenshots(
page=page,
url=url,
draw_boxes=draw_boxes,
max_number=max_screenshot_number,
)
id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids = build_element_dict(
elements
)
@@ -524,7 +541,7 @@ async def scrape_web_unsafe(
id_to_element_hash=id_to_element_hash,
hash_to_element_ids=hash_to_element_ids,
element_tree=element_tree,
element_tree_trimmed=trim_element_tree(copy.deepcopy(element_tree)),
element_tree_trimmed=element_tree_trimmed,
screenshots=screenshots,
url=page.url,
html=html,