refactor context tree (#212)

This commit is contained in:
LawyZheng
2024-04-21 22:30:37 +08:00
committed by GitHub
parent 02db2a90e6
commit cc6ae8bae0
4 changed files with 217 additions and 63 deletions

View File

@@ -135,11 +135,13 @@ class BrowserState:
browser_context: BrowserContext | None = None,
page: Page | None = None,
browser_artifacts: BrowserArtifacts = BrowserArtifacts(),
new_context_tree: bool = False,
):
self.pw = pw
self.browser_context = browser_context
self.page = page
self.browser_artifacts = browser_artifacts
self.new_context_tree = new_context_tree
async def _close_all_other_pages(self) -> None:
if not self.browser_context or not self.page:

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import random
import structlog
from playwright.async_api import Browser, Playwright, async_playwright
@@ -23,13 +25,19 @@ class BrowserManager:
@staticmethod
async def _create_browser_state(
proxy_location: ProxyLocation | None = None, url: str | None = None
proxy_location: ProxyLocation | None = None, url: str | None = None, new_context_tree: bool = False
) -> BrowserState:
pw = await async_playwright().start()
browser_context, browser_artifacts = await BrowserContextFactory.create_browser_context(
pw, proxy_location=proxy_location, url=url
)
return BrowserState(pw=pw, browser_context=browser_context, page=None, browser_artifacts=browser_artifacts)
return BrowserState(
pw=pw,
browser_context=browser_context,
page=None,
browser_artifacts=browser_artifacts,
new_context_tree=new_context_tree,
)
async def get_or_create_for_task(self, task: Task) -> BrowserState:
if task.task_id in self.pages:
@@ -42,8 +50,11 @@ class BrowserManager:
)
self.pages[task.task_id] = self.pages[task.workflow_run_id]
return self.pages[task.task_id]
LOG.info("Creating browser state for task", task_id=task.task_id)
browser_state = await self._create_browser_state(task.proxy_location, task.url)
# TODO: percentage to use new context tree, starting from 20%
new_ctx = random.choices([False, True], weights=[0.8, 0.2], k=1)[0]
LOG.info("Creating browser state for task", task_id=task.task_id, new_ctx=new_ctx)
browser_state = await self._create_browser_state(task.proxy_location, task.url, new_ctx)
# The URL here is only used when creating a new page, and not when using an existing page.
# This will make sure browser_state.page is not None.

View File

@@ -395,6 +395,16 @@ const isComboboxDropdown = (element) => {
return role && haspopup && controls && readonly;
};
const checkParentClass = (className) => {
const targetParentClasses = ["field", "entry"];
for (let i = 0; i < targetParentClasses.length; i++) {
if (className.includes(targetParentClasses[i])) {
return true;
}
}
return false;
};
function removeMultipleSpaces(str) {
if (!str) {
return str;
@@ -408,15 +418,43 @@ function cleanupText(text) {
).trim();
}
function getElementContext(element, existingContext = "") {
// dfs to collect the non unique_id context
let fullContext = "";
if (element.childNodes.length === 0) {
return fullContext;
const checkStringIncludeRequire = (str) => {
return (
str.toLowerCase().includes("*") ||
str.toLowerCase().includes("✱") ||
str.toLowerCase().includes("require")
);
};
const checkRequiredFromStyle = (element) => {
const afterCustom = getElementComputedStyle(element, "::after")
.getPropertyValue("content")
.replace(/"/g, "");
if (checkStringIncludeRequire(afterCustom)) {
return true;
}
return element.className.toLowerCase().includes("require");
};
function getElementContext(element) {
// dfs to collect the non unique_id context
let fullContext = new Array();
// sometimes '*' shows as an after custom style
const afterCustom = getElementComputedStyle(element, "::after")
.getPropertyValue("content")
.replace(/"/g, "");
if (
afterCustom.toLowerCase().includes("*") ||
afterCustom.toLowerCase().includes("require")
) {
fullContext.push(afterCustom);
}
if (element.childNodes.length === 0) {
return fullContext.join(";");
}
let childContextList = new Array();
// if the element already has a context, then add it to the list first
if (existingContext.length > 0) childContextList.push(existingContext);
for (var child of element.childNodes) {
let childContext = "";
if (child.nodeType === Node.TEXT_NODE) {
@@ -429,19 +467,15 @@ function getElementContext(element, existingContext = "") {
}
}
if (childContext.length > 0) {
childContextList.push(childContext);
}
if (childContextList.length > 0) {
fullContext = childContextList.join(";");
fullContext.push(childContext);
}
const charLimit = 1000;
if (fullContext.length > charLimit) {
fullContext = "";
if (fullContext.join(";").length > charLimit) {
fullContext = new Array();
}
}
return fullContext;
return fullContext.join(";");
}
function getElementContent(element, skipped_element = null) {
@@ -516,7 +550,7 @@ function getListboxOptions(element) {
return selectOptions;
}
function buildTreeFromBody() {
function buildTreeFromBody(new_ctx = false) {
var elements = [];
var resultArray = [];
@@ -596,10 +630,24 @@ function buildTreeFromBody() {
attr.name === "readonly" ||
attr.name === "aria-readonly"
) {
attrValue = true;
if (attrValue && attrValue.toLowerCase() === "false") {
attrValue = false;
} else {
attrValue = true;
}
}
attrs[attr.name] = attrValue;
}
if (
new_ctx &&
checkRequiredFromStyle(element) &&
!attrs["required"] &&
!attrs["aria-required"]
) {
attrs["required"] = true;
}
if (elementTagNameLower === "input" || elementTagNameLower === "textarea") {
attrs["value"] = element.value;
}
@@ -669,6 +717,10 @@ function buildTreeFromBody() {
else {
elements[interactableParentId].children.push(elementObj);
}
// options already added to the select.options, no need to add options anymore
if (new_ctx && elementObj.options && elementObj.options.length > 0) {
return elementObj;
}
// Recursively process the children of the element
getChildElements(element).forEach((child) => {
processElement(child, elementObj.id);
@@ -684,7 +736,7 @@ function buildTreeFromBody() {
}
}
const getContextByParent = (element) => {
const getContextByParent = (element, ctx) => {
// for most elements, we're going 10 layers up to see if we can find "label" as a parent
// if found, most likely the context under label is relevant to this element
let targetParentElements = new Set(["label", "fieldset"]);
@@ -696,7 +748,10 @@ function buildTreeFromBody() {
for (var i = 0; i < 10; i++) {
parentEle = parentEle.parentElement;
if (parentEle) {
if (targetParentElements.has(parentEle.tagName.toLowerCase())) {
if (
targetParentElements.has(parentEle.tagName.toLowerCase()) ||
(new_ctx && checkParentClass(parentEle.className.toLowerCase()))
) {
targetContextualParent = parentEle;
}
} else {
@@ -704,24 +759,27 @@ function buildTreeFromBody() {
}
}
if (!targetContextualParent) {
return "";
return ctx;
}
let context = "";
var lowerCaseTagName = targetContextualParent.tagName.toLowerCase();
if (lowerCaseTagName === "label") {
context = getElementContext(targetContextualParent);
} else if (lowerCaseTagName === "fieldset") {
if (lowerCaseTagName === "fieldset") {
// fieldset is usually within a form or another element that contains the whole context
targetContextualParent = targetContextualParent.parentElement;
if (targetContextualParent) {
context = getElementContext(targetContextualParent);
}
} else {
context = getElementContext(targetContextualParent);
}
return context;
if (context.length > 0) {
ctx.push(context);
}
return ctx;
};
const getContextByLinked = (element) => {
const getContextByLinked = (element, ctx) => {
let currentEle = document.querySelector(`[unique_id="${element.id}"]`);
// check labels pointed to this element
// 1. element id -> labels pointed to this id
@@ -759,12 +817,102 @@ function buildTreeFromBody() {
}
const context = fullContext.join(";");
const charLimit = 1000;
if (context.length > charLimit) {
return "";
if (context.length > 0) {
ctx.push(context);
}
return ctx;
};
const getContextByTable = (element, ctx) => {
// pass element's parent's context to the element for listed tags
let tagsWithDirectParentContext = new Set(["a"]);
// if the element is a child of a td, th, or tr, then pass the grandparent's context to the element
let parentTagsThatDelegateParentContext = new Set(["td", "th", "tr"]);
if (tagsWithDirectParentContext.has(element.tagName)) {
let parentElement = document.querySelector(
`[unique_id="${element.id}"]`,
).parentElement;
if (!parentElement) {
return ctx;
}
if (
parentTagsThatDelegateParentContext.has(
parentElement.tagName.toLowerCase(),
)
) {
let grandParentElement = parentElement.parentElement;
if (grandParentElement) {
let context = getElementContext(grandParentElement, element.context);
if (context.length > 0) {
ctx.push(context);
}
}
}
let context = getElementContext(parentElement, element.context);
if (context.length > 0) {
ctx.push(context);
}
}
return ctx;
};
const trimDuplicatedText = (element) => {
if (element.children.length === 0 && !element.options) {
return;
}
return context;
// if the element has options, text will be duplicated with the option text
if (element.options) {
element.options.forEach((option) => {
element.text = element.text.replace(option.text, "");
});
}
// BFS to delete duplicated text
element.children.forEach((child) => {
// delete duplicated text in the tree
element.text = element.text.replace(child.text, "");
trimDuplicatedText(child);
});
// trim multiple ";"
element.text = element.text.replace(/;+/g, ";");
// trimleft and trimright ";"
element.text = element.text.replace(new RegExp(`^;+|;+$`, "g"), "");
};
const trimDuplicatedContext = (element) => {
if (element.children.length === 0) {
return;
}
// DFS to delete duplicated context
element.children.forEach((child) => {
trimDuplicatedContext(child);
if (element.context === child.context) {
delete child.context;
}
if (child.context) {
child.context = child.context.replace(element.text, "");
if (!child.context) {
delete child.context;
}
}
});
};
// some elements without children should be removed out, such as <label>
const removeOrphanNode = (results) => {
const trimmedResults = [];
for (let i = 0; i < results.length; i++) {
const element = results[i];
element.children = removeOrphanNode(element.children);
if (element.tagName === "label" && element.children.length === 0) {
continue;
}
trimmedResults.push(element);
}
return trimmedResults;
};
// TODO: Handle iframes
@@ -788,43 +936,36 @@ function buildTreeFromBody() {
);
}
const context = getContextByLinked(element) + getContextByParent(element);
let ctxList = [];
ctxList = getContextByLinked(element, ctxList);
ctxList = getContextByParent(element, ctxList);
ctxList = getContextByTable(element, ctxList);
const context = ctxList.join(";");
// const context = getContextByParent(element)
if (context && context.length <= 1000) {
element.context = context;
}
// pass element's parent's context to the element for listed tags
let tagsWithDirectParentContext = new Set(["a"]);
// if the element is a child of a td, th, or tr, then pass the grandparent's context to the element
let parentTagsThatDelegateParentContext = new Set(["td", "th", "tr"]);
if (tagsWithDirectParentContext.has(element.tagName)) {
let parentElement = document.querySelector(
`[unique_id="${element.id}"]`,
).parentElement;
if (!parentElement) {
continue;
}
if (new_ctx && checkStringIncludeRequire(context)) {
if (
parentTagsThatDelegateParentContext.has(
parentElement.tagName.toLowerCase(),
)
!element.attributes["required"] &&
!element.attributes["aria-required"]
) {
let grandParentElement = parentElement.parentElement;
if (grandParentElement) {
let context = getElementContext(grandParentElement, element.context);
if (context.length > 0) {
element.context = context;
}
}
}
let context = getElementContext(parentElement, element.context);
if (context.length > 0) {
element.context = context;
element.attributes["required"] = true;
}
}
}
if (!new_ctx) {
return [elements, resultArray];
}
resultArray = removeOrphanNode(resultArray);
resultArray.forEach((root) => {
trimDuplicatedText(root);
trimDuplicatedContext(root);
});
return [elements, resultArray];
}

View File

@@ -184,7 +184,7 @@ async def scrape_web_unsafe(
await remove_bounding_boxes(page)
await scroll_to_top(page, drow_boxes=False)
elements, element_tree = await get_interactable_element_tree(page)
elements, element_tree = await get_interactable_element_tree(page, browser_state.new_context_tree)
element_tree = cleanup_elements(copy.deepcopy(element_tree))
_build_element_links(elements)
@@ -211,15 +211,15 @@ async def scrape_web_unsafe(
)
async def get_interactable_element_tree(page: Page) -> tuple[list[dict], list[dict]]:
async def get_interactable_element_tree(page: Page, new_context_tree: bool) -> tuple[list[dict], list[dict]]:
"""
Get the element tree of the page, including all the elements that are interactable.
:param page: Page instance to get the element tree from.
:return: Tuple containing the element tree and a map of element IDs to elements.
"""
await page.evaluate(JS_FUNCTION_DEFS)
js_script = "() => buildTreeFromBody()"
elements, element_tree = await page.evaluate(js_script)
js_script = "(new_ctx) => buildTreeFromBody(new_ctx)"
elements, element_tree = await page.evaluate(js_script, new_context_tree)
return elements, element_tree