convert element tree json -> html (#299)

This commit is contained in:
LawyZheng
2024-05-13 09:37:17 +08:00
committed by GitHub
parent 270642c60c
commit 25311dee86
3 changed files with 82 additions and 2 deletions

View File

@@ -217,3 +217,8 @@ class BitwardenTOTPError(BitwardenBaseError):
class BitwardenLogoutError(BitwardenBaseError):
    """Bitwarden error carrying a logout-failure message."""

    def __init__(self, message: str) -> None:
        # Prefix the caller-supplied detail with a fixed logout-failure header.
        detail = f"Error logging out of Bitwarden: {message}"
        super().__init__(detail)
class UnknownElementTreeFormat(SkyvernException):
    """Raised when an element tree is requested in an unrecognized format."""

    def __init__(self, fmt: str) -> None:
        # ``fmt`` is the offending format identifier; it is echoed verbatim
        # in the exception message.
        message = f"Unknown element tree format {fmt}"
        super().__init__(message)

View File

@@ -43,7 +43,7 @@ from skyvern.webeye.actions.handler import ActionHandler
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.responses import ActionResult from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -636,13 +636,28 @@ class ForgeAgent:
): ):
LOG.info("Using Claude3 Sonnet prompt template for action extraction") LOG.info("Using Claude3 Sonnet prompt template for action extraction")
prompt_template = "extract-action-claude3-sonnet" prompt_template = "extract-action-claude3-sonnet"
element_tree_format = ElementTreeFormat.JSON
if app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"USE_HTML_ELEMENT_TREE",
task.workflow_run_id or task.task_id,
properties={"organization_id": task.organization_id},
):
element_tree_format = ElementTreeFormat.HTML
LOG.info(
f"Building element tree",
task_id=task.task_id,
workflow_run_id=task.workflow_run_id,
format=element_tree_format,
)
extract_action_prompt = prompt_engine.load_prompt( extract_action_prompt = prompt_engine.load_prompt(
prompt_template, prompt_template,
navigation_goal=navigation_goal, navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(task.navigation_payload), navigation_payload_str=json.dumps(task.navigation_payload),
starting_url=starting_url, starting_url=starting_url,
current_url=current_url, current_url=current_url,
elements=scraped_page.element_tree_trimmed, elements=scraped_page.build_element_tree(element_tree_format),
data_extraction_goal=task.data_extraction_goal, data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str, action_history=actions_and_results_str,
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None, error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,

View File

@@ -1,12 +1,16 @@
import asyncio import asyncio
import copy import copy
import json
from collections import defaultdict from collections import defaultdict
from enum import StrEnum
from typing import Any
import structlog import structlog
from playwright.async_api import Page from playwright.async_api import Page
from pydantic import BaseModel from pydantic import BaseModel
from skyvern.constants import SKYVERN_DIR, SKYVERN_ID_ATTR from skyvern.constants import SKYVERN_DIR, SKYVERN_ID_ATTR
from skyvern.exceptions import UnknownElementTreeFormat
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.browser_factory import BrowserState
@@ -39,6 +43,11 @@ RESERVED_ATTRIBUTES = {
"value", "value",
} }
# Keys that live at the top level of an element node rather than inside its
# "attributes" dict. json_to_html copies each of these that is present (not
# None) into the attributes of the rendered tag.
ELEMENT_NODE_ATTRIBUTES = {
    "id",
    "interactable",
}
def load_js_script() -> str: def load_js_script() -> str:
# TODO: Handle file location better. This is a hacky way to find the file location. # TODO: Handle file location better. This is a hacky way to find the file location.
@@ -56,6 +65,48 @@ def load_js_script() -> str:
JS_FUNCTION_DEFS = load_js_script() JS_FUNCTION_DEFS = load_js_script()
# function to convert JSON element to HTML
def build_attribute(key: str, value: Any) -> str:
    """Render a single attribute as it should appear inside an HTML tag.

    Args:
        key: Attribute name; emitted as-is.
        value: Attribute value taken from the element tree.

    Returns:
        * ``key="true"`` / ``key="false"`` for booleans and ``key="<n>"`` for
          integers (so ``0`` and ``False`` are still rendered explicitly);
        * the bare ``key`` for any other falsy value (``None``, ``""``, ...),
          i.e. a valueless attribute such as ``disabled``;
        * ``key="<escaped value>"`` otherwise.
    """
    # Imported locally so the helper carries its own dependency.
    from html import escape

    # bool is a subclass of int, so one check covers both; lower-casing turns
    # Python's True/False into HTML-style true/false and leaves digits alone.
    if isinstance(value, (bool, int)):
        return f'{key}="{str(value).lower()}"'

    if not value:
        return key

    # Values come from scraped pages and may contain quotes, angle brackets or
    # ampersands; escape them so the value cannot terminate the attribute.
    return f'{key}="{escape(str(value), quote=True)}"'
def json_to_html(element: dict) -> str:
    """Convert one element-tree node (and its subtree) into an HTML string.

    The node's ``attributes`` dict is merged with the node-level keys listed
    in ``ELEMENT_NODE_ATTRIBUTES`` and rendered via ``build_attribute``.
    ``text``, recursively rendered ``children`` and ``options`` (each as
    ``<option index="...">text</option>``) form the tag content. Tags in the
    void-element list below are emitted as an opening tag only.

    Args:
        element: Node dict. ``tagName`` is required; ``attributes``, ``text``,
            ``children`` and ``options`` are optional and may be ``None``.

    Returns:
        The HTML markup for the node.

    Raises:
        KeyError: If ``element`` (or a rendered descendant) has no ``tagName``.
    """
    tag = element["tagName"]

    # Work on a copy: merging node-level keys into the caller's dict would
    # permanently alter the element tree that is also serialized elsewhere.
    attributes: dict[str, Any] = dict(element.get("attributes") or {})

    # ELEMENT_NODE_ATTRIBUTES is a set; sort it so the attribute order in the
    # output does not depend on per-process string hash randomization.
    for attr in sorted(ELEMENT_NODE_ATTRIBUTES):
        value = element.get(attr)
        if value is not None:
            attributes[attr] = value

    attributes_html = " ".join(build_attribute(key, value) for key, value in attributes.items())
    opening = f"{tag} {attributes_html}" if attributes_html else tag

    # Void elements have no content and no closing tag, so there is no need to
    # render children or options for them.
    if tag in ("img", "input", "br", "hr", "meta", "link"):
        return f"<{opening}>"

    # ``or`` fallbacks keep an explicit None from rendering as the string
    # "None" or from breaking iteration.
    text = element.get("text") or ""
    children_html = "".join(json_to_html(child) for child in element.get("children") or [])
    option_html = "".join(
        f'<option index="{option.get("optionIndex")}">{option.get("text")}</option>'
        for option in element.get("options") or []
    )

    return f"<{opening}>{text}{children_html}{option_html}</{tag}>"
class ElementTreeFormat(StrEnum):
JSON = "json"
HTML = "html"
class ScrapedPage(BaseModel): class ScrapedPage(BaseModel):
""" """
Scraped response from a webpage, including: Scraped response from a webpage, including:
@@ -78,6 +129,15 @@ class ScrapedPage(BaseModel):
html: str html: str
extracted_text: str | None = None extracted_text: str | None = None
def build_element_tree(self, fmt: ElementTreeFormat = ElementTreeFormat.JSON) -> str:
if fmt == ElementTreeFormat.JSON:
return json.dumps(self.element_tree_trimmed)
if fmt == ElementTreeFormat.HTML:
return "".join(json_to_html(element) for element in self.element_tree_trimmed)
raise UnknownElementTreeFormat(fmt=fmt)
async def scrape_website( async def scrape_website(
browser_state: BrowserState, browser_state: BrowserState,