Extract BrowserState.scrape_website (#4184)

Author: Stanislav Novosad
Date: 2025-12-03 15:08:32 -07:00
Committed by: GitHub
Parent: ce01f2cb35
Commit: f754272f9c
16 changed files with 375 additions and 313 deletions


@@ -1,15 +1,11 @@
import asyncio
import copy
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import StrEnum
from typing import Any, Awaitable, Callable, Self
import structlog
from playwright._impl._errors import TimeoutError
from playwright.async_api import ElementHandle, Frame, Locator, Page
from pydantic import BaseModel, PrivateAttr
from skyvern.config import settings
from skyvern.constants import DEFAULT_MAX_TOKENS, SKYVERN_DIR, SKYVERN_ID_ATTR
@@ -28,12 +24,17 @@ from skyvern.forge.sdk.trace import TraceManager
from skyvern.utils.image_resizer import Resolution
from skyvern.utils.token_counter import count_tokens
from skyvern.webeye.browser_state import BrowserState
from skyvern.webeye.scraper.scraped_page import (
CleanupElementTreeFunc,
ElementTreeBuilder,
ElementTreeFormat,
ScrapedPage,
ScrapeExcludeFunc,
json_to_html,
)
from skyvern.webeye.utils.page import SkyvernFrame
LOG = structlog.get_logger()
CleanupElementTreeFunc = Callable[[Page | Frame, str, list[dict]], Awaitable[list[dict]]]
ScrapeExcludeFunc = Callable[[Page, Frame], Awaitable[bool]]
RESERVED_ATTRIBUTES = {
"accept", # for input file
"alt",
@@ -75,11 +76,6 @@ BASE64_INCLUDE_ATTRIBUTES = {
}
ELEMENT_NODE_ATTRIBUTES = {
"id",
}
def load_js_script() -> str:
# TODO: Handle file location better. This is a hacky way to find the file location.
path = f"{SKYVERN_DIR}/webeye/scraper/domUtils.js"
@@ -96,86 +92,6 @@ def load_js_script() -> str:
JS_FUNCTION_DEFS = load_js_script()
# function to convert JSON element to HTML
def build_attribute(key: str, value: Any) -> str:
if isinstance(value, bool) or isinstance(value, int):
return f'{key}="{str(value).lower()}"'
return f'{key}="{str(value)}"' if value else key
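# A quick sketch of build_attribute's behavior on some hypothetical inputs:
#   build_attribute("disabled", True)  -> 'disabled="true"'
#   build_attribute("tabindex", 0)     -> 'tabindex="0"'
#   build_attribute("alt", "logo")     -> 'alt="logo"'
#   build_attribute("required", "")    -> 'required'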
def json_to_html(element: dict, need_skyvern_attrs: bool = True) -> str:
"""
If the element is flagged as dropped, the HTML output is empty
(unless the element is interactable, in which case its attributes are trimmed instead).
"""
tag = element["tagName"]
attributes: dict[str, Any] = copy.deepcopy(element.get("attributes", {}))
interactable = element.get("interactable", False)
if element.get("isDropped", False):
if not interactable:
return ""
else:
LOG.debug("Element is interactable. Trimmed all attributes instead of dropping it", element=element)
attributes = {}
context = skyvern_context.ensure_context()
# FIXME: In theory, every href longer than 69 (64 + 1 + 4) characters could be hashed,
# but for now only links longer than 150 characters are hashed, to confirm the approach works well
if "href" in attributes and len(attributes.get("href", "")) > 150:
href = attributes.get("href", "")
# Jinja variable names can't start with a digit,
# so prepend "_" to keep the placeholder name valid.
hashed_href = "_" + calculate_sha256(href)
context.hashed_href_map[hashed_href] = href
attributes["href"] = "{{" + hashed_href + "}}"
if need_skyvern_attrs:
# add the Skyvern node attributes (e.g. the element id) to the attribute dict
for attr in ELEMENT_NODE_ATTRIBUTES:
value = element.get(attr)
if value is None:
continue
attributes[attr] = value
attributes_html = " ".join(build_attribute(key, value) for key, value in attributes.items())
if element.get("isSelectable", False):
tag = "select"
text = element.get("text", "")
# build children HTML
children_html = "".join(
json_to_html(child, need_skyvern_attrs=need_skyvern_attrs) for child in element.get("children", [])
)
# build option HTML
option_html = "".join(
f'<option index="{option.get("optionIndex")}">{option.get("text")}</option>'
if option.get("text")
else f'<option index="{option.get("optionIndex")}" value="{option.get("value")}">{option.get("text")}</option>'
for option in element.get("options", [])
)
if element.get("purgeable", False):
return children_html + option_html
before_pseudo_text = element.get("beforePseudoText") or ""
after_pseudo_text = element.get("afterPseudoText") or ""
# Check if the element is self-closing
if (
tag in ["img", "input", "br", "hr", "meta", "link"]
and not option_html
and not children_html
and not before_pseudo_text
and not after_pseudo_text
):
return f"<{tag}{attributes_html if not attributes_html else ' ' + attributes_html}>"
else:
return f"<{tag}{attributes_html if not attributes_html else ' ' + attributes_html}>{before_pseudo_text}{text}{children_html + option_html}{after_pseudo_text}</{tag}>"
def clean_element_before_hashing(element: dict) -> dict:
def clean_nested(element: dict) -> dict:
element_cleaned = {key: value for key, value in element.items() if key not in {"id", "rect", "frame_index"}}
@@ -220,198 +136,6 @@ def build_element_dict(
return id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids
class ElementTreeFormat(StrEnum):
JSON = "json" # deprecate JSON format soon. please use HTML format
HTML = "html"
class ElementTreeBuilder(ABC):
@abstractmethod
def support_economy_elements_tree(self) -> bool:
pass
@abstractmethod
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
pass
@abstractmethod
def build_economy_elements_tree(
self,
fmt: ElementTreeFormat = ElementTreeFormat.HTML,
html_need_skyvern_attrs: bool = True,
percent_to_keep: float = 1,
) -> str:
pass
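# A minimal sketch of a concrete builder over a pre-built element tree
# (hypothetical class, for illustration only; not part of this commit):
class StaticTreeBuilder(ElementTreeBuilder):
    def __init__(self, tree: list[dict]) -> None:
        self.tree = tree

    def support_economy_elements_tree(self) -> bool:
        return False

    def build_element_tree(
        self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
    ) -> str:
        if fmt == ElementTreeFormat.JSON:
            return json.dumps(self.tree)
        return "".join(json_to_html(e, need_skyvern_attrs=html_need_skyvern_attrs) for e in self.tree)

    def build_economy_elements_tree(
        self,
        fmt: ElementTreeFormat = ElementTreeFormat.HTML,
        html_need_skyvern_attrs: bool = True,
        percent_to_keep: float = 1,
    ) -> str:
        raise NotImplementedError("economy tree is not supported by this builder")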
class ScrapedPage(BaseModel, ElementTreeBuilder):
"""
Scraped response from a webpage, including:
1. List of elements
2. ID to css map
3. The element tree of the page (list of dicts). Each element has children and attributes.
4. The screenshot (base64 encoded)
5. The URL of the page
6. The HTML of the page
7. The extracted text from the page
"""
elements: list[dict]
id_to_element_dict: dict[str, dict] = {}
id_to_frame_dict: dict[str, str] = {}
id_to_css_dict: dict[str, str] = {}
id_to_element_hash: dict[str, str] = {}
hash_to_element_ids: dict[str, list[str]] = {}
element_tree: list[dict]
element_tree_trimmed: list[dict]
economy_element_tree: list[dict] | None = None
last_used_element_tree: list[dict] | None = None
screenshots: list[bytes] = []
url: str = ""
html: str = ""
extracted_text: str | None = None
window_dimension: dict[str, int] | None = None
_browser_state: BrowserState = PrivateAttr()
_clean_up_func: CleanupElementTreeFunc = PrivateAttr()
_scrape_exclude: ScrapeExcludeFunc | None = PrivateAttr(default=None)
def __init__(self, **data: Any) -> None:
missing_attrs = [attr for attr in ["_browser_state", "_clean_up_func"] if attr not in data]
if len(missing_attrs) > 0:
raise ValueError(f"Missing required private attributes: {', '.join(missing_attrs)}")
# pop the private attributes off before pydantic validation
browser_state = data.pop("_browser_state")
clean_up_func = data.pop("_clean_up_func")
scrape_exclude = data.pop("_scrape_exclude", None)
super().__init__(**data)
self._browser_state = browser_state
self._clean_up_func = clean_up_func
self._scrape_exclude = scrape_exclude
def support_economy_elements_tree(self) -> bool:
return True
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
self.last_used_element_tree = self.element_tree_trimmed
if fmt == ElementTreeFormat.JSON:
return json.dumps(self.element_tree_trimmed)
if fmt == ElementTreeFormat.HTML:
return "".join(
json_to_html(element, need_skyvern_attrs=html_need_skyvern_attrs)
for element in self.element_tree_trimmed
)
raise UnknownElementTreeFormat(fmt=fmt)
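# Usage sketch (scraped_page is assumed to be an existing ScrapedPage instance):
#   html = scraped_page.build_element_tree(fmt=ElementTreeFormat.HTML)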
def build_economy_elements_tree(
self,
fmt: ElementTreeFormat = ElementTreeFormat.HTML,
html_need_skyvern_attrs: bool = True,
percent_to_keep: float = 1,
) -> str:
"""
The economy elements tree excludes secondary elements such as SVGs.
"""
if not self.economy_element_tree:
economy_elements = []
copied_element_tree_trimmed = copy.deepcopy(self.element_tree_trimmed)
# Process each root element
for root_element in copied_element_tree_trimmed:
processed_element = self._process_element_for_economy_tree(root_element)
if processed_element:
economy_elements.append(processed_element)
self.economy_element_tree = economy_elements
self.last_used_element_tree = self.economy_element_tree
if fmt == ElementTreeFormat.JSON:
element_str = json.dumps(self.economy_element_tree)
return element_str[: int(len(element_str) * percent_to_keep)]
if fmt == ElementTreeFormat.HTML:
element_str = "".join(
json_to_html(element, need_skyvern_attrs=html_need_skyvern_attrs)
for element in self.economy_element_tree
)
return element_str[: int(len(element_str) * percent_to_keep)]
raise UnknownElementTreeFormat(fmt=fmt)
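# E.g. with percent_to_keep=0.5, a 10,000-char serialization is cut to its
# first int(10000 * 0.5) = 5,000 chars; the tail may end mid-tag, which caps
# the string size at the cost of possibly truncating an element.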
def _process_element_for_economy_tree(self, element: dict) -> dict | None:
"""
Helper method that recursively processes an element for the economy tree.
Removes SVG elements and their entire subtrees.
"""
# Skip SVG elements entirely
if element.get("tagName", "").lower() == "svg":
return None
# Recursively process the children
if "children" in element:
new_children = []
for child in element["children"]:
processed_child = self._process_element_for_economy_tree(child)
if processed_child:
new_children.append(processed_child)
element["children"] = new_children
return element
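# Sketch of the effect on a hypothetical subtree: an <svg> child is dropped wholesale.
#   {"tagName": "button", "children": [{"tagName": "svg", "children": [...]}]}
#   -> {"tagName": "button", "children": []}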
async def refresh(self, draw_boxes: bool = True, scroll: bool = True, max_retries: int = 0) -> Self:
refreshed_page = await scrape_website(
browser_state=self._browser_state,
url=self.url,
cleanup_element_tree=self._clean_up_func,
max_retries=max_retries,
scrape_exclude=self._scrape_exclude,
draw_boxes=draw_boxes,
scroll=scroll,
)
self.elements = refreshed_page.elements
self.id_to_css_dict = refreshed_page.id_to_css_dict
self.id_to_element_dict = refreshed_page.id_to_element_dict
self.id_to_frame_dict = refreshed_page.id_to_frame_dict
self.id_to_element_hash = refreshed_page.id_to_element_hash
self.hash_to_element_ids = refreshed_page.hash_to_element_ids
self.element_tree = refreshed_page.element_tree
self.element_tree_trimmed = refreshed_page.element_tree_trimmed
self.screenshots = refreshed_page.screenshots or self.screenshots
self.html = refreshed_page.html
self.extracted_text = refreshed_page.extracted_text
self.url = refreshed_page.url
return self
async def generate_scraped_page(
self,
draw_boxes: bool = True,
scroll: bool = True,
take_screenshots: bool = True,
max_retries: int = 0,
) -> Self:
return await scrape_website(
browser_state=self._browser_state,
url=self.url,
cleanup_element_tree=self._clean_up_func,
max_retries=max_retries,
scrape_exclude=self._scrape_exclude,
take_screenshots=take_screenshots,
draw_boxes=draw_boxes,
scroll=scroll,
)
async def generate_scraped_page_without_screenshots(self, max_retries: int = 0) -> Self:
return await self.generate_scraped_page(take_screenshots=False, max_retries=max_retries)
@TraceManager.traced_async(ignore_input=True)
async def scrape_website(
browser_state: BrowserState,
@@ -557,6 +281,7 @@ async def scrape_web_unsafe(
:return: Tuple containing Page instance, base64 encoded screenshot, and page elements.
:note: This function does not handle exceptions. Ensure proper error handling in the calling context.
"""
# The browser state must have a working page instance; otherwise, scraping must not proceed
page = await browser_state.must_get_working_page()
# Take screenshots of the page with the bounding boxes. We will remove the bounding boxes later.