verification code V2 - support verification code of multiple separate single character input fields (#683)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-08-08 02:17:15 +03:00
committed by GitHub
parent 78adb8b276
commit c872b1e4a8
5 changed files with 94 additions and 50 deletions

View File

@@ -11,12 +11,7 @@ from playwright._impl._errors import TargetClosedError
from playwright.async_api import Page
from skyvern import analytics
from skyvern.constants import (
SCRAPE_TYPE_ORDER,
SPECIAL_FIELD_VERIFICATION_CODE,
VERIFICATION_CODE_PLACEHOLDER,
ScrapeType,
)
from skyvern.constants import SCRAPE_TYPE_ORDER, SPECIAL_FIELD_VERIFICATION_CODE, ScrapeType
from skyvern.exceptions import (
BrowserStateMissingPage,
EmptyScrapePage,
@@ -53,7 +48,7 @@ from skyvern.webeye.actions.actions import (
WebAction,
parse_actions,
)
from skyvern.webeye.actions.handler import ActionHandler
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.browser_factory import BrowserState
@@ -548,6 +543,13 @@ class ForgeAgent:
step=step,
screenshots=scraped_page.screenshots,
)
json_response = await self.handle_potential_verification_code(
task,
step,
scraped_page,
browser_state,
json_response,
)
detailed_agent_step_output.llm_response = json_response
actions = parse_actions(task, json_response["actions"])
@@ -951,16 +953,6 @@ class ForgeAgent:
num_elements=len(scraped_page.elements),
url=task.url,
)
actions_and_results_str = await self._get_action_results(task)
# Generate the extract action prompt
navigation_goal = task.navigation_goal
starting_url = task.url
current_url = (
await browser_state.page.evaluate("() => document.location.href") if browser_state.page else starting_url
)
# TODO: we only use HTML element for now, introduce a way to switch in the future
element_tree_format = ElementTreeFormat.HTML
LOG.info(
@@ -971,18 +963,12 @@ class ForgeAgent:
)
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
final_navigation_payload = self._build_navigation_payload(task)
extract_action_prompt = prompt_engine.load_prompt(
"extract-action",
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
starting_url=starting_url,
current_url=current_url,
elements=element_tree_in_prompt,
data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
utc_datetime=datetime.utcnow().strftime("%Y-%m-%d %H:%M"),
extract_action_prompt = await self._build_extract_action_prompt(
task,
browser_state,
element_tree_in_prompt,
verification_code_check=bool(task.totp_verification_url),
expire_verification_code=True,
)
await app.ARTIFACT_MANAGER.create_artifact(
@@ -1013,26 +999,62 @@ class ForgeAgent:
return scraped_page, extract_action_prompt
async def _build_extract_action_prompt(
self,
task: Task,
browser_state: BrowserState,
element_tree_in_prompt: str,
verification_code_check: bool = False,
expire_verification_code: bool = False,
) -> str:
actions_and_results_str = await self._get_action_results(task)
# Generate the extract action prompt
navigation_goal = task.navigation_goal
starting_url = task.url
current_url = (
await browser_state.page.evaluate("() => document.location.href") if browser_state.page else starting_url
)
final_navigation_payload = self._build_navigation_payload(
task, expire_verification_code=expire_verification_code
)
return prompt_engine.load_prompt(
"extract-action",
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
starting_url=starting_url,
current_url=current_url,
elements=element_tree_in_prompt,
data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
utc_datetime=datetime.utcnow().strftime("%Y-%m-%d %H:%M"),
verification_code_check=verification_code_check,
)
def _build_navigation_payload(
self,
task: Task,
expire_verification_code: bool = False,
) -> dict[str, Any] | list | str | None:
final_navigation_payload = task.navigation_payload
if task.totp_verification_url:
current_context = skyvern_context.ensure_context()
verification_code = current_context.totp_codes.get(task.task_id)
if task.totp_verification_url and verification_code:
if (
isinstance(final_navigation_payload, dict)
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
):
final_navigation_payload[SPECIAL_FIELD_VERIFICATION_CODE] = VERIFICATION_CODE_PLACEHOLDER
final_navigation_payload[SPECIAL_FIELD_VERIFICATION_CODE] = verification_code
elif (
isinstance(final_navigation_payload, str)
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
):
final_navigation_payload = (
final_navigation_payload
+ "\n"
+ str({SPECIAL_FIELD_VERIFICATION_CODE: VERIFICATION_CODE_PLACEHOLDER})
final_navigation_payload + "\n" + str({SPECIAL_FIELD_VERIFICATION_CODE: verification_code})
)
if expire_verification_code:
current_context.totp_codes.pop(task.task_id)
return final_navigation_payload
async def _get_action_results(self, task: Task) -> str:
@@ -1552,6 +1574,40 @@ class ForgeAgent:
)
return None, None, next_step
async def handle_potential_verification_code(
self,
task: Task,
step: Step,
scraped_page: ScrapedPage,
browser_state: BrowserState,
json_response: dict[str, Any],
) -> dict[str, Any]:
# TODO: handle verifications and resend the request if needed
# parse the "need_verification_code" field from the response
need_verification_code = json_response.get("need_verification_code")
if need_verification_code and task.totp_verification_url and task.organization_id:
LOG.info("Need verification code", step_id=step.step_id)
verification_code = await poll_verification_code(
task.task_id, task.organization_id, url=task.totp_verification_url
)
current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = verification_code
element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
extract_action_prompt = await self._build_extract_action_prompt(
task,
browser_state,
element_tree_in_prompt,
verification_code_check=False,
expire_verification_code=False,
)
return await app.LLM_API_HANDLER(
prompt=extract_action_prompt,
step=step,
screenshots=scraped_page.screenshots,
)
return json_response
@staticmethod
async def get_task_errors(task: Task) -> list[UserDefinedError]:
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)