remove screenshot when check user goal (#1147)

This commit is contained in:
LawyZheng
2024-11-06 23:20:45 +08:00
committed by GitHub
parent c084764373
commit b62c2caae0
3 changed files with 94 additions and 45 deletions

View File

@@ -838,38 +838,26 @@ class ForgeAgent:
task_completes_on_download = task_block and task_block.complete_on_download and task.workflow_run_id task_completes_on_download = task_block and task_block.complete_on_download and task.workflow_run_id
if not has_decisive_action and not task_completes_on_download: if not has_decisive_action and not task_completes_on_download:
LOG.info("Checking if user goal is achieved after re-scraping the page") working_page = await browser_state.must_get_working_page()
# Check if navigation goal is achieved after re-scraping the page complete_action = await self.check_user_goal_complete(
new_scraped_page = await self._scrape_with_type( page=working_page,
scraped_page=scraped_page,
task=task, task=task,
step=step, step=step,
browser_state=browser_state,
scrape_type=ScrapeType.NORMAL,
organization=organization,
) )
if new_scraped_page is None: if complete_action is not None:
LOG.warning("Failed to scrape the page before checking user goal success, skipping check...") LOG.info("User goal achieved, executing complete action")
else: complete_action.organization_id = task.organization_id
working_page = await browser_state.must_get_working_page() complete_action.workflow_run_id = task.workflow_run_id
complete_action = await self.check_user_goal_complete( complete_action.task_id = task.task_id
page=working_page, complete_action.step_id = step.step_id
scraped_page=new_scraped_page, complete_action.step_order = step.order
task=task, complete_action.action_order = len(detailed_agent_step_output.actions_and_results)
step=step, complete_results = await ActionHandler.handle_action(
scraped_page, task, step, working_page, complete_action
) )
if complete_action is not None: detailed_agent_step_output.actions_and_results.append((complete_action, complete_results))
LOG.info("User goal achieved, executing complete action") await self.record_artifacts_after_action(task, step, browser_state)
complete_action.organization_id = task.organization_id
complete_action.workflow_run_id = task.workflow_run_id
complete_action.task_id = task.task_id
complete_action.step_id = step.step_id
complete_action.step_order = step.order
complete_action.action_order = len(detailed_agent_step_output.actions_and_results)
complete_results = await ActionHandler.handle_action(
scraped_page, task, step, working_page, complete_action
)
detailed_agent_step_output.actions_and_results.append((complete_action, complete_results))
await self.record_artifacts_after_action(task, step, browser_state)
# If no action errors return the agent state and output # If no action errors return the agent state and output
completed_step = await self.update_step( completed_step = await self.update_step(
step=step, step=step,
@@ -913,11 +901,19 @@ class ForgeAgent:
page: Page, scraped_page: ScrapedPage, task: Task, step: Step page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> CompleteAction | None: ) -> CompleteAction | None:
try: try:
LOG.info(
"Checking if user goal is achieved after re-scraping the page without screenshots",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
scraped_page_without_screenshots = await scraped_page.refresh(with_screenshot=False)
verification_prompt = prompt_engine.load_prompt( verification_prompt = prompt_engine.load_prompt(
"check-user-goal", "check-user-goal",
navigation_goal=task.navigation_goal, navigation_goal=task.navigation_goal,
navigation_payload=task.navigation_payload, navigation_payload=task.navigation_payload,
elements=scraped_page.build_element_tree(ElementTreeFormat.HTML), elements=scraped_page_without_screenshots.build_element_tree(ElementTreeFormat.HTML),
) )
# this prompt is critical to our agent so let's use the primary LLM API handler # this prompt is critical to our agent so let's use the primary LLM API handler
@@ -926,6 +922,9 @@ class ForgeAgent:
LOG.error( LOG.error(
"Invalid LLM response for user goal success verification, skipping verification", "Invalid LLM response for user goal success verification, skipping verification",
verification_response=verification_response, verification_response=verification_response,
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
) )
return None return None
@@ -940,7 +939,13 @@ class ForgeAgent:
) )
except Exception: except Exception:
LOG.error("LLM verification failed for complete action, skipping LLM verification", exc_info=True) LOG.error(
"LLM verification failed for complete action, skipping LLM verification",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
exc_info=True,
)
return None return None
async def record_artifacts_after_action(self, task: Task, step: Step, browser_state: BrowserState) -> None: async def record_artifacts_after_action(self, task: Task, step: Step, browser_state: BrowserState) -> None:
@@ -1039,7 +1044,7 @@ class ForgeAgent:
browser_state: BrowserState, browser_state: BrowserState,
scrape_type: ScrapeType, scrape_type: ScrapeType,
organization: Organization | None = None, organization: Organization | None = None,
) -> ScrapedPage | None: ) -> ScrapedPage:
if scrape_type == ScrapeType.NORMAL: if scrape_type == ScrapeType.NORMAL:
pass pass

View File

@@ -1,7 +1,6 @@
from typing import Awaitable, Callable from typing import Awaitable, Callable
from fastapi import FastAPI from fastapi import FastAPI
from playwright.async_api import Frame, Page
from skyvern.forge.agent import ForgeAgent from skyvern.forge.agent import ForgeAgent
from skyvern.forge.agent_functions import AgentFunction from skyvern.forge.agent_functions import AgentFunction
@@ -17,6 +16,7 @@ from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
from skyvern.forge.sdk.workflow.service import WorkflowService from skyvern.forge.sdk.workflow.service import WorkflowService
from skyvern.webeye.browser_manager import BrowserManager from skyvern.webeye.browser_manager import BrowserManager
from skyvern.webeye.scraper.scraper import ScrapeExcludeFunc
SETTINGS_MANAGER = SettingsManager.get_settings() SETTINGS_MANAGER = SettingsManager.get_settings()
DATABASE = AgentDB( DATABASE = AgentDB(
@@ -37,7 +37,7 @@ SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService() WORKFLOW_SERVICE = WorkflowService()
AGENT_FUNCTION = AgentFunction() AGENT_FUNCTION = AgentFunction()
scrape_exclude: Callable[[Page, Frame], Awaitable[bool]] | None = None scrape_exclude: ScrapeExcludeFunc | None = None
authentication_function: Callable[[str], Awaitable[Organization]] | None = None authentication_function: Callable[[str], Awaitable[Organization]] | None = None
setup_api_app: Callable[[FastAPI], None] | None = None setup_api_app: Callable[[FastAPI], None] | None = None

View File

@@ -3,11 +3,11 @@ import copy
import json import json
from collections import defaultdict from collections import defaultdict
from enum import StrEnum from enum import StrEnum
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable, Self
import structlog import structlog
from playwright.async_api import Frame, Locator, Page from playwright.async_api import Frame, Locator, Page
from pydantic import BaseModel from pydantic import BaseModel, PrivateAttr
from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, SKYVERN_DIR, SKYVERN_ID_ATTR from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, SKYVERN_DIR, SKYVERN_ID_ATTR
from skyvern.exceptions import FailedToTakeScreenshot, UnknownElementTreeFormat from skyvern.exceptions import FailedToTakeScreenshot, UnknownElementTreeFormat
@@ -18,6 +18,7 @@ from skyvern.webeye.utils.page import SkyvernFrame
LOG = structlog.get_logger() LOG = structlog.get_logger()
CleanupElementTreeFunc = Callable[[Page | Frame, str, list[dict]], Awaitable[list[dict]]] CleanupElementTreeFunc = Callable[[Page | Frame, str, list[dict]], Awaitable[list[dict]]]
ScrapeExcludeFunc = Callable[[Page, Frame], Awaitable[bool]]
RESERVED_ATTRIBUTES = { RESERVED_ATTRIBUTES = {
"accept", # for input file "accept", # for input file
@@ -211,6 +212,26 @@ class ScrapedPage(BaseModel):
html: str html: str
extracted_text: str | None = None extracted_text: str | None = None
_browser_state: BrowserState = PrivateAttr()
_clean_up_func: CleanupElementTreeFunc = PrivateAttr()
_scrape_exclude: ScrapeExcludeFunc | None = PrivateAttr(default=None)
def __init__(self, **data: Any) -> None:
missing_attrs = [attr for attr in ["_browser_state", "_clean_up_func"] if attr not in data]
if len(missing_attrs) > 0:
raise ValueError(f"Missing required private attributes: {', '.join(missing_attrs)}")
# popup private attributes
browser_state = data.pop("_browser_state")
clean_up_func = data.pop("_clean_up_func")
scrape_exclude = data.pop("_scrape_exclude")
super().__init__(**data)
self._browser_state = browser_state
self._clean_up_func = clean_up_func
self._scrape_exclude = scrape_exclude
def build_element_tree(self, fmt: ElementTreeFormat = ElementTreeFormat.JSON) -> str: def build_element_tree(self, fmt: ElementTreeFormat = ElementTreeFormat.JSON) -> str:
if fmt == ElementTreeFormat.JSON: if fmt == ElementTreeFormat.JSON:
return json.dumps(self.element_tree_trimmed) return json.dumps(self.element_tree_trimmed)
@@ -220,13 +241,23 @@ class ScrapedPage(BaseModel):
raise UnknownElementTreeFormat(fmt=fmt) raise UnknownElementTreeFormat(fmt=fmt)
async def refresh(self, with_screenshot: bool = True) -> Self:
return await scrape_website(
browser_state=self._browser_state,
url=self.url,
cleanup_element_tree=self._clean_up_func,
scrape_exclude=self._scrape_exclude,
with_screenshot=with_screenshot,
)
async def scrape_website( async def scrape_website(
browser_state: BrowserState, browser_state: BrowserState,
url: str, url: str,
cleanup_element_tree: CleanupElementTreeFunc, cleanup_element_tree: CleanupElementTreeFunc,
num_retry: int = 0, num_retry: int = 0,
scrape_exclude: Callable[[Page, Frame], Awaitable[bool]] | None = None, scrape_exclude: ScrapeExcludeFunc | None = None,
with_screenshot: bool = True,
) -> ScrapedPage: ) -> ScrapedPage:
""" """
************************************************************************************************ ************************************************************************************************
@@ -251,7 +282,13 @@ async def scrape_website(
""" """
try: try:
num_retry += 1 num_retry += 1
return await scrape_web_unsafe(browser_state, url, cleanup_element_tree, scrape_exclude) return await scrape_web_unsafe(
browser_state=browser_state,
url=url,
cleanup_element_tree=cleanup_element_tree,
scrape_exclude=scrape_exclude,
with_screenshot=with_screenshot,
)
except Exception as e: except Exception as e:
# NOTE: MAX_SCRAPING_RETRIES is set to 0 in both staging and production # NOTE: MAX_SCRAPING_RETRIES is set to 0 in both staging and production
if num_retry > SettingsManager.get_settings().MAX_SCRAPING_RETRIES: if num_retry > SettingsManager.get_settings().MAX_SCRAPING_RETRIES:
@@ -272,6 +309,7 @@ async def scrape_website(
cleanup_element_tree, cleanup_element_tree,
num_retry=num_retry, num_retry=num_retry,
scrape_exclude=scrape_exclude, scrape_exclude=scrape_exclude,
with_screenshot=with_screenshot,
) )
@@ -318,7 +356,8 @@ async def scrape_web_unsafe(
browser_state: BrowserState, browser_state: BrowserState,
url: str, url: str,
cleanup_element_tree: CleanupElementTreeFunc, cleanup_element_tree: CleanupElementTreeFunc,
scrape_exclude: Callable[[Page, Frame], Awaitable[bool]] | None = None, scrape_exclude: ScrapeExcludeFunc | None = None,
with_screenshot: bool = True,
) -> ScrapedPage: ) -> ScrapedPage:
""" """
Asynchronous function that performs web scraping without any built-in error handling. This function is intended Asynchronous function that performs web scraping without any built-in error handling. This function is intended
@@ -331,10 +370,8 @@ async def scrape_web_unsafe(
:return: Tuple containing Page instance, base64 encoded screenshot, and page elements. :return: Tuple containing Page instance, base64 encoded screenshot, and page elements.
:note: This function does not handle exceptions. Ensure proper error handling in the calling context. :note: This function does not handle exceptions. Ensure proper error handling in the calling context.
""" """
# We only create a new page if one does not exist. This is to allow keeping the same page since we want to # browser state must have the page instance, otherwise we should not do scraping
# continue working on the same page that we're taking actions on. page = await browser_state.must_get_working_page()
# *This also means URL is only used when creating a new page, and not when using an existing page.
page = await browser_state.get_or_create_page(url)
# Take screenshots of the page with the bounding boxes. We will remove the bounding boxes later. # Take screenshots of the page with the bounding boxes. We will remove the bounding boxes later.
# Scroll to the top of the page and take a screenshot. # Scroll to the top of the page and take a screenshot.
# Scroll to the next page and take a screenshot until we reach the end of the page. # Scroll to the next page and take a screenshot until we reach the end of the page.
@@ -345,7 +382,11 @@ async def scrape_web_unsafe(
LOG.info("Waiting for 5 seconds before scraping the website.") LOG.info("Waiting for 5 seconds before scraping the website.")
await asyncio.sleep(5) await asyncio.sleep(5)
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True) screenshots: list[bytes] = []
# TODO: do we need to scroll to the button when we scrape without screenshots?
if with_screenshot:
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True)
elements, element_tree = await get_interactable_element_tree(page, scrape_exclude) elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree)) element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))
@@ -384,6 +425,9 @@ async def scrape_web_unsafe(
url=page.url, url=page.url,
html=html, html=html,
extracted_text=text_content, extracted_text=text_content,
_browser_state=browser_state,
_clean_up_func=cleanup_element_tree,
_scrape_exclude=scrape_exclude,
) )
@@ -391,7 +435,7 @@ async def get_interactable_element_tree_in_frame(
frames: list[Frame], frames: list[Frame],
elements: list[dict], elements: list[dict],
element_tree: list[dict], element_tree: list[dict],
scrape_exclude: Callable[[Page, Frame], Awaitable[bool]] | None = None, scrape_exclude: ScrapeExcludeFunc | None = None,
) -> tuple[list[dict], list[dict]]: ) -> tuple[list[dict], list[dict]]:
for frame in frames: for frame in frames:
if frame.is_detached(): if frame.is_detached():
@@ -445,7 +489,7 @@ async def get_interactable_element_tree_in_frame(
async def get_interactable_element_tree( async def get_interactable_element_tree(
page: Page, page: Page,
scrape_exclude: Callable[[Page, Frame], Awaitable[bool]] | None = None, scrape_exclude: ScrapeExcludeFunc | None = None,
) -> tuple[list[dict], list[dict]]: ) -> tuple[list[dict], list[dict]]:
""" """
Get the element tree of the page, including all the elements that are interactable. Get the element tree of the page, including all the elements that are interactable.