Add complete action verification (#845)

This commit is contained in:
Kerem Yilmaz
2024-09-17 18:59:40 -07:00
committed by GitHub
parent 20154020dd
commit d19ff2bd69
3 changed files with 115 additions and 1 deletions

View File

@@ -50,7 +50,7 @@ from skyvern.webeye.actions.actions import (
WebAction,
parse_actions,
)
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
from skyvern.webeye.actions.handler import ActionHandler, handle_complete_action, poll_verification_code
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.browser_factory import BrowserState
@@ -773,6 +773,36 @@ class ForgeAgent:
step_retry=step.retry_index,
action_results=action_results,
)
if app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"CHECK_USER_GOAL_SUCCESS_EVERY_STEP",
task.workflow_run_id or task.task_id,
properties={
"organization_id": task.organization_id,
"organization_created_at": str(organization.created_at) if organization else None,
},
):
LOG.info("Checking if user goal is achieved after re-scraping the page")
# Check if navigation goal is achieved after re-scraping the page
new_scraped_page = await self._scrape_with_type(
task=task,
step=step,
browser_state=browser_state,
scrape_type=ScrapeType.NORMAL,
organization=organization,
)
if new_scraped_page is None:
LOG.warning("Failed to scrape the page before checking user goal success, skipping check...")
else:
working_page = await browser_state.get_working_page()
result_tuple = await self.check_user_goal_success(
page=working_page,
scraped_page=new_scraped_page,
task=task,
step=step,
)
if result_tuple is not None:
complete_action, action_results = result_tuple
detailed_agent_step_output.actions_and_results.append((complete_action, action_results))
# If no action errors return the agent state and output
completed_step = await self.update_step(
step=step,
@@ -811,6 +841,55 @@ class ForgeAgent:
)
return failed_step, detailed_agent_step_output.get_clean_detailed_output()
@staticmethod
async def check_user_goal_success(
    page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> tuple[CompleteAction, list[ActionResult]] | None:
    """Ask the secondary LLM whether the user's goal has already been achieved.

    Renders the "check-user-goal" prompt from the task's navigation goal,
    navigation payload and the scraped page's HTML element tree, sends it
    together with split screenshots of ``page`` to
    ``app.SECONDARY_LLM_API_HANDLER`` and, when the response reports the goal
    as achieved, builds a ``CompleteAction`` and runs it through
    ``handle_complete_action``.

    Args:
        page: Page that is screenshotted and handed to the complete-action handler.
        scraped_page: Scrape result whose element tree is embedded in the prompt.
        task: Supplies the navigation goal/payload and the data extraction goal.
        step: Inspected for an existing ``CompleteAction`` and passed to the LLM handler.

    Returns:
        ``(complete_action, action_results)`` when the goal is reported as
        achieved. ``None`` when the step output already holds a
        ``CompleteAction``, when the LLM response is invalid, when the goal is
        not achieved, or when anything raises -- failures are logged here and
        never propagated to the caller.
    """
    try:
        # Check if Skyvern already returned a complete action, if so, don't run verification
        if step.output and step.output.actions_and_results:
            if any(isinstance(action, CompleteAction) for action, _ in step.output.actions_and_results):
                return None

        verification_prompt = prompt_engine.load_prompt(
            "check-user-goal",
            navigation_goal=task.navigation_goal,
            navigation_payload=task.navigation_payload,
            elements=scraped_page.build_element_tree(ElementTreeFormat.HTML),
        )
        screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=page.url)

        verification_llm_api_handler = app.SECONDARY_LLM_API_HANDLER
        verification_response = await verification_llm_api_handler(
            prompt=verification_prompt, step=step, screenshots=screenshots
        )

        # The flag must be a real boolean: a non-empty string such as "false" is
        # truthy and would otherwise be mistaken for "goal achieved".
        response_is_valid = (
            "user_goal_achieved" in verification_response
            and "reasoning" in verification_response
            and isinstance(verification_response["user_goal_achieved"], bool)
        )
        if not response_is_valid:
            LOG.error(
                "Invalid LLM response for user goal success verification, skipping verification",
                verification_response=verification_response,
            )
            return None

        # We don't want to return a complete action if the user goal is not achieved since we're checking at every step
        if not verification_response["user_goal_achieved"]:
            return None

        # Only build the action once we know it will actually be executed.
        complete_action = CompleteAction(
            reasoning=verification_response["reasoning"],
            data_extraction_goal=task.data_extraction_goal,
        )

        LOG.info("User goal achieved, executing complete action")
        action_results = await handle_complete_action(complete_action, page, scraped_page, task, step)
        return complete_action, action_results
    except Exception:
        # Best-effort check: verification must never fail the step itself.
        LOG.error("LLM verification failed for complete action, skipping LLM verification", exc_info=True)
        return None
async def record_artifacts_after_action(self, task: Task, step: Step, browser_state: BrowserState) -> None:
working_page = await browser_state.get_working_page()
if not working_page: