diff --git a/skyvern/constants.py b/skyvern/constants.py
index de23234a..372868c6 100644
--- a/skyvern/constants.py
+++ b/skyvern/constants.py
@@ -32,3 +32,4 @@ class ScrapeType(StrEnum):
 
 
 SCRAPE_TYPE_ORDER = [ScrapeType.NORMAL, ScrapeType.NORMAL, ScrapeType.RELOAD]
+DEFAULT_MAX_TOKENS = 100000
diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py
index d8f21106..c842c9f6 100644
--- a/skyvern/services/task_v2_service.py
+++ b/skyvern/services/task_v2_service.py
@@ -55,7 +55,7 @@ from skyvern.forge.sdk.workflow.models.yaml import (
 from skyvern.schemas.runs import ProxyLocation, RunType
 from skyvern.utils.prompt_engine import load_prompt_with_elements
 from skyvern.webeye.browser_factory import BrowserState
-from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
+from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
 from skyvern.webeye.utils.page import SkyvernFrame
 
 LOG = structlog.get_logger()
@@ -453,7 +453,6 @@ async def run_task_v2_helper(
             app.AGENT_FUNCTION.cleanup_element_tree_factory(),
             scrape_exclude=app.scrape_exclude,
         )
-        element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
         if page is None:
             page = await browser_state.get_working_page()
     except Exception:
@@ -545,7 +544,7 @@ async def run_task_v2_helper(
             workflow_permanent_id=workflow.workflow_permanent_id,
             workflow_run_id=workflow_run_id,
             current_url=current_url,
-            element_tree_in_prompt=element_tree_in_prompt,
+            scraped_page=scraped_page,
             data_extraction_goal=plan,
             task_history=task_history,
         )
@@ -1084,20 +1083,22 @@ async def _generate_extraction_task(
     workflow_permanent_id: str,
     workflow_run_id: str,
     current_url: str,
-    element_tree_in_prompt: str,
+    scraped_page: ScrapedPage,
     data_extraction_goal: str,
     task_history: list[dict] | None = None,
 ) -> tuple[ExtractionBlock, list[BLOCK_YAML_TYPES], list[PARAMETER_YAML_TYPES]]:
     LOG.info("Generating extraction task", data_extraction_goal=data_extraction_goal, current_url=current_url)
     # extract the data
     context = skyvern_context.ensure_context()
-    generate_extraction_task_prompt = prompt_engine.load_prompt(
-        "task_v2_generate_extraction_task",
+    generate_extraction_task_prompt = load_prompt_with_elements(
+        scraped_page=scraped_page,
+        prompt_engine=prompt_engine,
+        template_name="task_v2_generate_extraction_task",
         current_url=current_url,
-        elements=element_tree_in_prompt,
         data_extraction_goal=data_extraction_goal,
         local_datetime=datetime.now(context.tz_info).isoformat(),
     )
+
     generate_extraction_task_response = await app.LLM_API_HANDLER(
         generate_extraction_task_prompt,
         task_v2=task_v2,
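Note on the hunks above: `_generate_extraction_task` now receives the `ScrapedPage` itself instead of a pre-rendered `element_tree_in_prompt` string, deferring rendering to `load_prompt_with_elements` so the loader can fall back to a cheaper tree when the full one blows the token budget. A minimal sketch of the calling-convention change (an illustrative fragment reusing names from this patch, not the complete call site):

```python
# Before: the representation was fixed up front, whatever its size.
#   elements = scraped_page.build_element_tree(ElementTreeFormat.HTML)
#   prompt = prompt_engine.load_prompt("task_v2_generate_extraction_task", elements=elements, ...)

# After: hand over the ScrapedPage and let the loader choose the rendering.
prompt = load_prompt_with_elements(
    scraped_page=scraped_page,
    prompt_engine=prompt_engine,
    template_name="task_v2_generate_extraction_task",
    data_extraction_goal=data_extraction_goal,
)
```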
diff --git a/skyvern/utils/prompt_engine.py b/skyvern/utils/prompt_engine.py
index f0038025..1f1ad701 100644
--- a/skyvern/utils/prompt_engine.py
+++ b/skyvern/utils/prompt_engine.py
@@ -2,11 +2,11 @@ from typing import Any
 
 import structlog
 
+from skyvern.constants import DEFAULT_MAX_TOKENS
 from skyvern.forge.sdk.prompting import PromptEngine
 from skyvern.utils.token_counter import count_tokens
 from skyvern.webeye.scraper.scraper import ScrapedPage
 
-DEFAULT_MAX_TOKENS = 100000
 
 LOG = structlog.get_logger()
 
@@ -14,13 +14,20 @@ def load_prompt_with_elements(
     scraped_page: ScrapedPage,
     prompt_engine: PromptEngine,
     template_name: str,
+    html_need_skyvern_attrs: bool = True,
     **kwargs: Any,
 ) -> str:
-    prompt = prompt_engine.load_prompt(template_name, elements=scraped_page.build_element_tree(), **kwargs)
+    prompt = prompt_engine.load_prompt(
+        template_name,
+        elements=scraped_page.build_element_tree(html_need_skyvern_attrs=html_need_skyvern_attrs),
+        **kwargs,
+    )
     token_count = count_tokens(prompt)
     if token_count > DEFAULT_MAX_TOKENS:
         # get rid of all the secondary elements like SVG, etc
-        economy_elements_tree = scraped_page.build_economy_elements_tree()
+        economy_elements_tree = scraped_page.build_economy_elements_tree(
+            html_need_skyvern_attrs=html_need_skyvern_attrs
+        )
         prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs)
         economy_token_count = count_tokens(prompt)
         LOG.warning(
@@ -33,7 +40,10 @@ def load_prompt_with_elements(
         if economy_token_count > DEFAULT_MAX_TOKENS:
             # !!! HACK alert
             # dump the last 1/3 of the html context and keep the first 2/3 of the html context
-            economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(percent_to_keep=2 / 3)
+            economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(
+                html_need_skyvern_attrs=html_need_skyvern_attrs,
+                percent_to_keep=2 / 3,
+            )
             prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree_dumped, **kwargs)
             token_count_after_dump = count_tokens(prompt)
             LOG.warning(
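Taken together, `load_prompt_with_elements` now implements a three-step degradation ladder, with the new `html_need_skyvern_attrs` flag threaded through every rendering path. A self-contained toy of the ladder (the renderers and `count_tokens` below are stand-ins, not the real Skyvern utilities):

```python
# Toy model of the fallback in load_prompt_with_elements, assuming the
# 100k budget from DEFAULT_MAX_TOKENS. All callables here are stubs.
from typing import Callable

MAX_TOKENS = 100_000

def count_tokens(text: str) -> int:
    return len(text.split())  # stub; the real counter lives in skyvern.utils.token_counter

def fit_elements_into_budget(
    render_full: Callable[[], str],          # full trimmed element tree
    render_economy: Callable[[float], str],  # tree minus secondary elements (SVG, ...)
) -> str:
    prompt = render_full()                   # step 1: try the full tree
    if count_tokens(prompt) <= MAX_TOKENS:
        return prompt
    prompt = render_economy(1.0)             # step 2: drop secondary elements
    if count_tokens(prompt) <= MAX_TOKENS:
        return prompt
    return render_economy(2 / 3)             # step 3: keep the first 2/3, lossy by design
```

Step 3 can silently drop relevant elements, which is presumably why the existing HACK-alert warning logs are kept.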
+ template_name="parse-input-or-select-context", action_reasoning=action.reasoning, element_id=action.element_id, - elements=dom.scraped_page.build_element_tree(ElementTreeFormat.HTML), ) json_response = await app.SECONDARY_LLM_API_HANDLER( prompt=prompt, step=step, prompt_name="parse-input-or-select-context" @@ -2785,20 +2787,15 @@ async def extract_information_for_navigation_goal( 1. JSON representation of what the user is seeing 2. The scraped page """ - prompt_template = "extract-information" - - # TODO: we only use HTML element for now, introduce a way to switch in the future - element_tree_format = ElementTreeFormat.HTML - element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format, html_need_skyvern_attrs=False) - scraped_page_refreshed = await scraped_page.refresh() - context = ensure_context() - extract_information_prompt = prompt_engine.load_prompt( - prompt_template, + extract_information_prompt = load_prompt_with_elements( + scraped_page=scraped_page_refreshed, + prompt_engine=prompt_engine, + template_name="extract-information", + html_need_skyvern_attrs=False, navigation_goal=task.navigation_goal, navigation_payload=task.navigation_payload, - elements=element_tree_in_prompt, data_extraction_goal=task.data_extraction_goal, extracted_information_schema=task.extracted_information_schema, current_url=scraped_page_refreshed.url, diff --git a/skyvern/webeye/scraper/scraper.py b/skyvern/webeye/scraper/scraper.py index 47ce3267..262f8adb 100644 --- a/skyvern/webeye/scraper/scraper.py +++ b/skyvern/webeye/scraper/scraper.py @@ -10,10 +10,11 @@ from playwright.async_api import Frame, Locator, Page from pydantic import BaseModel, PrivateAttr from skyvern.config import settings -from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, SKYVERN_DIR, SKYVERN_ID_ATTR +from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, DEFAULT_MAX_TOKENS, SKYVERN_DIR, SKYVERN_ID_ATTR from skyvern.exceptions import FailedToTakeScreenshot, ScrapingFailed, UnknownElementTreeFormat from skyvern.forge.sdk.api.crypto import calculate_sha256 from skyvern.forge.sdk.core import skyvern_context +from skyvern.utils.token_counter import count_tokens from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.utils.page import SkyvernFrame @@ -230,6 +231,7 @@ class ScrapedPage(BaseModel): element_tree: list[dict] element_tree_trimmed: list[dict] economy_element_tree: list[dict] | None = None + last_used_element_tree: list[dict] | None = None screenshots: list[bytes] url: str html: str @@ -258,6 +260,7 @@ class ScrapedPage(BaseModel): def build_element_tree( self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True ) -> str: + self.last_used_element_tree = self.element_tree_trimmed if fmt == ElementTreeFormat.JSON: return json.dumps(self.element_tree_trimmed) @@ -291,6 +294,7 @@ class ScrapedPage(BaseModel): self.economy_element_tree = economy_elements final_element_tree = self.economy_element_tree[: int(len(self.economy_element_tree) * percent_to_keep)] + self.last_used_element_tree = final_element_tree if fmt == ElementTreeFormat.JSON: return json.dumps(final_element_tree) @@ -488,13 +492,26 @@ async def scrape_web_unsafe( LOG.info("Waiting for 5 seconds before scraping the website.") await asyncio.sleep(5) - screenshots = [] - if take_screenshots: - screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=draw_boxes) - elements, element_tree = await get_interactable_element_tree(page, scrape_exclude) 
diff --git a/skyvern/webeye/scraper/scraper.py b/skyvern/webeye/scraper/scraper.py
index 47ce3267..262f8adb 100644
--- a/skyvern/webeye/scraper/scraper.py
+++ b/skyvern/webeye/scraper/scraper.py
@@ -10,10 +10,11 @@ from playwright.async_api import Frame, Locator, Page
 from pydantic import BaseModel, PrivateAttr
 
 from skyvern.config import settings
-from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, SKYVERN_DIR, SKYVERN_ID_ATTR
+from skyvern.constants import BUILDING_ELEMENT_TREE_TIMEOUT_MS, DEFAULT_MAX_TOKENS, SKYVERN_DIR, SKYVERN_ID_ATTR
 from skyvern.exceptions import FailedToTakeScreenshot, ScrapingFailed, UnknownElementTreeFormat
 from skyvern.forge.sdk.api.crypto import calculate_sha256
 from skyvern.forge.sdk.core import skyvern_context
+from skyvern.utils.token_counter import count_tokens
 from skyvern.webeye.browser_factory import BrowserState
 from skyvern.webeye.utils.page import SkyvernFrame
 
@@ -230,6 +231,7 @@ class ScrapedPage(BaseModel):
     element_tree: list[dict]
     element_tree_trimmed: list[dict]
     economy_element_tree: list[dict] | None = None
+    last_used_element_tree: list[dict] | None = None
     screenshots: list[bytes]
     url: str
     html: str
@@ -258,6 +260,7 @@ class ScrapedPage(BaseModel):
     def build_element_tree(
         self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
     ) -> str:
+        self.last_used_element_tree = self.element_tree_trimmed
         if fmt == ElementTreeFormat.JSON:
             return json.dumps(self.element_tree_trimmed)
 
@@ -291,6 +294,7 @@ class ScrapedPage(BaseModel):
             self.economy_element_tree = economy_elements
 
         final_element_tree = self.economy_element_tree[: int(len(self.economy_element_tree) * percent_to_keep)]
+        self.last_used_element_tree = final_element_tree
         if fmt == ElementTreeFormat.JSON:
             return json.dumps(final_element_tree)
 
@@ -488,13 +492,26 @@ async def scrape_web_unsafe(
         LOG.info("Waiting for 5 seconds before scraping the website.")
         await asyncio.sleep(5)
 
-    screenshots = []
-    if take_screenshots:
-        screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=draw_boxes)
-
     elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
     element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))
-
+    element_tree_trimmed = trim_element_tree(copy.deepcopy(element_tree))
+
+    screenshots = []
+    if take_screenshots:
+        element_tree_trimmed_html_str = "".join(
+            json_to_html(element, need_skyvern_attrs=False) for element in element_tree_trimmed
+        )
+        token_count = count_tokens(element_tree_trimmed_html_str)
+        max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
+        if token_count > DEFAULT_MAX_TOKENS:
+            max_screenshot_number = min(max_screenshot_number, 1)
+
+        screenshots = await SkyvernFrame.take_split_screenshots(
+            page=page,
+            url=url,
+            draw_boxes=draw_boxes,
+            max_number=max_screenshot_number,
+        )
     id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids = build_element_dict(
         elements
    )
@@ -524,7 +541,7 @@ async def scrape_web_unsafe(
         id_to_element_hash=id_to_element_hash,
         hash_to_element_ids=hash_to_element_ids,
         element_tree=element_tree,
-        element_tree_trimmed=trim_element_tree(copy.deepcopy(element_tree)),
+        element_tree_trimmed=element_tree_trimmed,
         screenshots=screenshots,
         url=page.url,
         html=html,
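Two behavioral changes hide in this last file. First, `last_used_element_tree` records whichever tree variant was most recently rendered into a prompt, presumably so callers can tell exactly what the LLM saw. Second, screenshots are now gated on the token count of the trimmed tree: if the DOM alone already exceeds the budget, the scraper caps itself at a single screenshot instead of `settings.MAX_NUM_SCREENSHOTS`. The gate, isolated as a runnable sketch (the token counter and the configured maximum are stubs, not the real settings):

```python
# Isolated version of the screenshot gate added in scrape_web_unsafe.
TOKEN_BUDGET = 100_000  # mirrors DEFAULT_MAX_TOKENS

def count_tokens(text: str) -> int:
    return len(text) // 4  # stub: rough 4-chars-per-token heuristic

def max_screenshots(tree_html: str, configured_max: int = 10) -> int:
    # A tree that already blows the token budget means a huge page; more
    # screenshots would only inflate the multimodal payload further.
    if count_tokens(tree_html) > TOKEN_BUDGET:
        return min(configured_max, 1)
    return configured_max

print(max_screenshots("<div></div>" * 10))      # small page -> 10
print(max_screenshots("<div></div>" * 50_000))  # huge page  -> 1
```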