From defd761e58c6998a91866f0aacb48380daf8bd2c Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 11 Feb 2025 14:47:41 +0800 Subject: [PATCH] add functionality to cache task_run (#1755) --- ...1f0f795bd_add_task_run_org_run_id_index.py | 29 +++++++++++ skyvern/forge/agent.py | 8 +-- .../prompts/skyvern/single-click-action.j2 | 9 +++- skyvern/forge/sdk/db/client.py | 27 ++++++++++ skyvern/forge/sdk/db/models.py | 5 +- skyvern/webeye/actions/caching.py | 51 +++++++++++++++---- skyvern/webeye/scraper/scraper.py | 16 +++++- 7 files changed, 127 insertions(+), 18 deletions(-) create mode 100644 alembic/versions/2025_02_11_0641-b111f0f795bd_add_task_run_org_run_id_index.py diff --git a/alembic/versions/2025_02_11_0641-b111f0f795bd_add_task_run_org_run_id_index.py b/alembic/versions/2025_02_11_0641-b111f0f795bd_add_task_run_org_run_id_index.py new file mode 100644 index 00000000..939b755d --- /dev/null +++ b/alembic/versions/2025_02_11_0641-b111f0f795bd_add_task_run_org_run_id_index.py @@ -0,0 +1,29 @@ +"""add task_run_org_run_id_index + +Revision ID: b111f0f795bd +Revises: 60d0743274c9 +Create Date: 2025-02-11 06:41:35.336836+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b111f0f795bd" +down_revision: Union[str, None] = "60d0743274c9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index("task_run_org_run_id_index", "task_runs", ["organization_id", "run_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("task_run_org_run_id_index", table_name="task_runs") + # ### end Alembic commands ### diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index c9ef1ff4..6dd0a3b1 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -348,7 +348,7 @@ class ForgeAgent: step, browser_state, detailed_output, - ) = await self._initialize_execution_state(task, step, workflow_run, browser_session_id) + ) = await self.initialize_execution_state(task, step, workflow_run, browser_session_id) if ( not task.navigation_goal @@ -759,7 +759,7 @@ class ForgeAgent: ( scraped_page, extract_action_prompt, - ) = await self._build_and_record_step_prompt( + ) = await self.build_and_record_step_prompt( task, step, browser_state, @@ -1245,7 +1245,7 @@ class ForgeAgent: exc_info=True, ) - async def _initialize_execution_state( + async def initialize_execution_state( self, task: Task, step: Step, @@ -1322,7 +1322,7 @@ class ForgeAgent: scrape_exclude=app.scrape_exclude, ) - async def _build_and_record_step_prompt( + async def build_and_record_step_prompt( self, task: Task, step: Step, diff --git a/skyvern/forge/prompts/skyvern/single-click-action.j2 b/skyvern/forge/prompts/skyvern/single-click-action.j2 index f9c0e18d..55585377 100644 --- a/skyvern/forge/prompts/skyvern/single-click-action.j2 +++ b/skyvern/forge/prompts/skyvern/single-click-action.j2 @@ -12,7 +12,7 @@ Reply in JSON format with the following keys: "user_detail_query": str, // Think of this value as a Jeopardy question. Ask the user for the details you need for executing this action. Ask the question even if the details are disclosed in user instruction or user details. If you are clicking on something specific, ask about what to click on. If you're downloading a file and you have multiple options, ask the user which one to download. Otherwise, use null. Examples are: "What is the previous insurance provider of the user?", "Which invoice should I download?", "Does the user have any pets?". If the action doesn't require any user details, use null. "user_detail_answer": str, // The answer to the `user_detail_query`. The source of this answer can be user instruction or user details. "confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence - "action_type": str, // It's a string enum: "CLICK". "CLICK" is an element you'd like to click. + "action_type": str, // It's a string enum: "CLICK". "CLICK" type means there's an element you'd like to click. "id": str, // The id of the element to take action on. The id has to be one from the elements list. "download": bool, // If true, the browser will trigger a download by clicking the element. If false, the browser will click the element without triggering a download. }] @@ -25,7 +25,7 @@ HTML elements from `{{ current_url }}`: {{ elements }} ``` -User instruction: +User instruction (user's intention or self questioning to help figure out what to click): ``` {{ navigation_goal }} ``` @@ -33,7 +33,12 @@ User instruction: User details: ``` {{ navigation_payload_str }} +```{% if user_context %} + +Context of the big goal user wants to achieve: ``` +{{ user_context }} +```{% endif %} Current datetime, ISO format: ``` diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 4368be7c..5ad79a87 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -2672,3 +2672,30 @@ class AgentDB: await session.commit() await session.refresh(task_run) return TaskRun.model_validate(task_run) + + async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> TaskRun: + async with self.Session() as session: + task_run = await session.scalars( + select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id) + ).first() + if task_run: + task_run.cached = True + await session.commit() + await session.refresh(task_run) + return TaskRun.model_validate(task_run) + raise NotFoundError(f"TaskRun {run_id} not found") + + async def get_cached_task_run( + self, task_run_type: TaskRunType, url_hash: str | None = None, organization_id: str | None = None + ) -> TaskRun | None: + async with self.Session() as session: + query = select(TaskRunModel) + if task_run_type: + query = query.filter_by(task_run_type=task_run_type) + if url_hash: + query = query.filter_by(url_hash=url_hash) + if organization_id: + query = query.filter_by(organization_id=organization_id) + query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc()) + task_run = await session.scalars(query).first() + return TaskRun.model_validate(task_run) if task_run else None diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 74dda787..434bd1a1 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -614,7 +614,10 @@ class PersistentBrowserSessionModel(Base): class TaskRunModel(Base): __tablename__ = "task_runs" - __table_args__ = (Index("task_run_org_url_index", "organization_id", "url_hash", "cached"),) + __table_args__ = ( + Index("task_run_org_url_index", "organization_id", "url_hash", "cached"), + Index("task_run_org_run_id_index", "organization_id", "run_id"), + ) task_run_id = Column(String, primary_key=True, default=generate_task_run_id) organization_id = Column(String, nullable=False) diff --git a/skyvern/webeye/actions/caching.py b/skyvern/webeye/actions/caching.py index e712873f..43381a52 100644 --- a/skyvern/webeye/actions/caching.py +++ b/skyvern/webeye/actions/caching.py @@ -108,7 +108,7 @@ async def _retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPag LOG.info("Found cached actions to execute", actions=cached_actions_to_execute) - actions_queries: list[tuple[Action, str | None]] = [] + actions_queries: list[Action] = [] for idx, cached_action in enumerate(cached_actions_to_execute): updated_action = cached_action.model_copy() updated_action.status = ActionStatus.pending @@ -135,7 +135,7 @@ async def _retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPag "All elements with either no hash or multiple hashes should have been already filtered out" ) - actions_queries.append((updated_action, updated_action.intention)) + actions_queries.append(updated_action) # Check for unsupported actions before personalizing the actions # Classify the supported actions into two groups: @@ -155,10 +155,12 @@ async def _retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPag async def personalize_actions( task: Task, step: Step, - actions_queries: list[tuple[Action, str | None]], + actions_queries: list[Action], scraped_page: ScrapedPage, ) -> list[Action]: - queries_and_answers: dict[str, str | None] = {query: None for _, query in actions_queries if query} + queries_and_answers: dict[str, str | None] = { + action.intention: None for action in actions_queries if action.intention + } answered_queries: dict[str, str] = {} if queries_and_answers: @@ -168,9 +170,13 @@ async def personalize_actions( ) personalized_actions = [] - for action, query in actions_queries: + for action in actions_queries: + query = action.intention if query and (personalized_answer := answered_queries.get(query)): - personalized_actions.append(personalize_action(action, query, personalized_answer)) + current_personized_actions = await personalize_action( + action, query, personalized_answer, task, step, scraped_page + ) + personalized_actions.extend(current_personized_actions) else: personalized_actions.append(action) @@ -198,24 +204,49 @@ async def get_user_detail_answers( raise e -def personalize_action(action: Action, query: str, answer: str) -> Action: +async def personalize_action( + action: Action, + query: str, + answer: str, + task: Task, + step: Step, + scraped_page: ScrapedPage, +) -> list[Action]: action.intention = query action.response = answer if action.action_type == ActionType.INPUT_TEXT: action.text = answer + elif action.action_type == ActionType.UPLOAD_FILE: + action.file_url = answer + elif action.action_type == ActionType.CLICK: + # TODO: we only use cached action.intention. send the intention, navigation payload + navigation goal, html + # to small llm and make a decision of which elements to click. Not clicking anything is also an option here + return [action] + elif action.action_type == ActionType.SELECT_OPTION: + # TODO: send the selection action with the original/previous option value. Our current selection agent + # is already able to handle it + return [action] + elif action.action_type in [ + ActionType.COMPLETE, + ActionType.WAIT, + ActionType.TERMINATE, + ActionType.SOLVE_CAPTCHA, + ]: + return [action] else: raise CachedActionPlanError( f"Unsupported action type for personalization, fallback to no-cache mode: {action.action_type}" ) - return action + return [action] -def check_for_unsupported_actions(actions_queries: list[tuple[Action, str | None]]) -> None: +def check_for_unsupported_actions(actions_queries: list[Action]) -> None: supported_actions = [ActionType.INPUT_TEXT, ActionType.WAIT, ActionType.CLICK, ActionType.COMPLETE] supported_actions_with_query = [ActionType.INPUT_TEXT] - for action, query in actions_queries: + for action in actions_queries: + query = action.intention if action.action_type not in supported_actions: raise CachedActionPlanError( f"This action type does not support caching: {action.action_type}, fallback to no-cache mode" diff --git a/skyvern/webeye/scraper/scraper.py b/skyvern/webeye/scraper/scraper.py index 922c67a8..0b81d23c 100644 --- a/skyvern/webeye/scraper/scraper.py +++ b/skyvern/webeye/scraper/scraper.py @@ -282,6 +282,15 @@ class ScrapedPage(BaseModel): self.url = refreshed_page.url return self + async def generate_scraped_page_without_screenshots(self) -> Self: + return await scrape_website( + browser_state=self._browser_state, + url=self.url, + cleanup_element_tree=self._clean_up_func, + scrape_exclude=self._scrape_exclude, + take_screenshots=False, + ) + async def scrape_website( browser_state: BrowserState, @@ -289,6 +298,7 @@ async def scrape_website( cleanup_element_tree: CleanupElementTreeFunc, num_retry: int = 0, scrape_exclude: ScrapeExcludeFunc | None = None, + take_screenshots: bool = True, ) -> ScrapedPage: """ ************************************************************************************************ @@ -318,6 +328,7 @@ async def scrape_website( url=url, cleanup_element_tree=cleanup_element_tree, scrape_exclude=scrape_exclude, + take_screenshots=take_screenshots, ) except Exception as e: # NOTE: MAX_SCRAPING_RETRIES is set to 0 in both staging and production @@ -386,6 +397,7 @@ async def scrape_web_unsafe( url: str, cleanup_element_tree: CleanupElementTreeFunc, scrape_exclude: ScrapeExcludeFunc | None = None, + take_screenshots: bool = True, ) -> ScrapedPage: """ Asynchronous function that performs web scraping without any built-in error handling. This function is intended @@ -410,7 +422,9 @@ async def scrape_web_unsafe( LOG.info("Waiting for 5 seconds before scraping the website.") await asyncio.sleep(5) - screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True) + screenshots = [] + if take_screenshots: + screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True) elements, element_tree = await get_interactable_element_tree(page, scrape_exclude) element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))