Pedro/support_multi_field_6_digit_totp (#3622)

This commit is contained in:
pedrohsdb
2025-10-06 16:37:35 -07:00
committed by GitHub
parent 7f6e5d2e36
commit 6fc56d9775
4 changed files with 415 additions and 8 deletions

View File

@@ -45,3 +45,8 @@ class GetTOTPVerificationCodeError(SkyvernDefinedError):
class TimeoutGetTOTPVerificationCodeError(SkyvernDefinedError): class TimeoutGetTOTPVerificationCodeError(SkyvernDefinedError):
error_code: str = "OTP_TIMEOUT" error_code: str = "OTP_TIMEOUT"
reasoning: str = "Timeout getting TOTP verification code." reasoning: str = "Timeout getting TOTP verification code."
class TOTPExpiredError(SkyvernDefinedError):
    # Error payload reported when a multi-field TOTP entry cannot be completed with
    # the code generated for its first digit (the code is missing or no longer usable).
    error_code: str = "OTP_EXPIRED"
    reasoning: str = "TOTP verification code has expired during multi-field input sequence."

View File

@@ -1202,14 +1202,47 @@ class ForgeAgent:
# Do not verify the complete action when complete_verification is False # Do not verify the complete action when complete_verification is False
# set verified to True will skip the completion verification # set verified to True will skip the completion verification
action.verified = True action.verified = True
# Pass TOTP secret to handler for multi-field TOTP sequences
# Handler will generate TOTP at execution time
if (
action.action_type == ActionType.INPUT_TEXT
and self._is_multi_field_totp_sequence(actions)
and (totp_secret := skyvern_context.ensure_context().totp_codes.get(f"{task.task_id}_secret"))
):
# Pass TOTP secret to handler for execution-time generation
action.totp_timing_info = {
"is_totp_sequence": True,
"action_index": action_idx,
"totp_secret": totp_secret,
"is_retry": step.retry_index > 0,
}
results = await ActionHandler.handle_action(scraped_page, task, step, current_page, action) results = await ActionHandler.handle_action(scraped_page, task, step, current_page, action)
await app.AGENT_FUNCTION.post_action_execution() await app.AGENT_FUNCTION.post_action_execution()
detailed_agent_step_output.actions_and_results[action_idx] = ( detailed_agent_step_output.actions_and_results[action_idx] = (
action, action,
results, results,
) )
# wait random time between actions to avoid detection
await asyncio.sleep(random.uniform(0.5, 1.0)) # Determine wait time between actions
wait_time = random.uniform(0.5, 1.0)
# For multi-field TOTP sequences, use zero delay between all digits for fast execution
if action.action_type == ActionType.INPUT_TEXT and self._is_multi_field_totp_sequence(actions):
current_text = action.text if hasattr(action, "text") else None
if current_text and len(current_text) == 1 and current_text.isdigit():
# Zero delay between all TOTP digits for fast execution
wait_time = 0.0
LOG.debug(
"TOTP: zero delay for digit",
task_id=task.task_id,
action_idx=action_idx,
digit=current_text,
)
await asyncio.sleep(wait_time)
await self.record_artifacts_after_action(task, step, browser_state, engine) await self.record_artifacts_after_action(task, step, browser_state, engine)
for result in results: for result in results:
result.step_retry_number = step.retry_index result.step_retry_number = step.retry_index
@@ -1306,6 +1339,21 @@ class ForgeAgent:
action_results=action_results, action_results=action_results,
) )
# Clean up TOTP cache after multi-field TOTP sequence completion
if self._is_multi_field_totp_sequence(actions):
context = skyvern_context.ensure_context()
cache_key = f"{task.task_id}_totp_cache"
if cache_key in context.totp_codes:
context.totp_codes.pop(cache_key)
LOG.debug(
"Cleaned up TOTP cache after multi-field sequence completion",
task_id=task.task_id,
)
secret_key = f"{task.task_id}_secret"
if secret_key in context.totp_codes:
context.totp_codes.pop(secret_key)
# Check if Skyvern already returned a complete action, if so, don't run user goal check # Check if Skyvern already returned a complete action, if so, don't run user goal check
has_decisive_action = False has_decisive_action = False
if detailed_agent_step_output and detailed_agent_step_output.actions_and_results: if detailed_agent_step_output and detailed_agent_step_output.actions_and_results:
@@ -2065,7 +2113,7 @@ class ForgeAgent:
await SkyvernFrame.evaluate(frame=page, expression="() => document.location.href") if page else starting_url await SkyvernFrame.evaluate(frame=page, expression="() => document.location.href") if page else starting_url
) )
final_navigation_payload = self._build_navigation_payload( final_navigation_payload = self._build_navigation_payload(
task, expire_verification_code=expire_verification_code task, expire_verification_code=expire_verification_code, step=step, scraped_page=scraped_page
) )
task_type = task.task_type if task.task_type else TaskType.general task_type = task.task_type if task.task_type else TaskType.general
@@ -2167,12 +2215,124 @@ class ForgeAgent:
return full_prompt, use_caching return full_prompt, use_caching
def _should_process_totp(self, scraped_page: ScrapedPage | None) -> bool:
"""Detect TOTP pages by checking for multiple input fields or verification keywords."""
if not scraped_page:
return False
try:
# Count input fields that could be for TOTP (more flexible than maxlength="1")
input_fields = [
element
for element in scraped_page.elements
if element.get("tagName", "").lower() == "input"
and element.get("attributes", {}).get("type", "text").lower() in ["text", "number", "tel"]
]
# Check for multiple input fields (potential multi-field TOTP)
if len(input_fields) >= 4:
# Additional check: look for patterns that suggest multi-field TOTP
# Check if inputs are close together or have similar attributes
has_maxlength_1 = any(elem.get("attributes", {}).get("maxlength") == "1" for elem in input_fields)
# Check for input fields with numeric patterns (type="number", pattern for digits)
has_numeric_patterns = any(
elem.get("attributes", {}).get("type") == "number"
or elem.get("attributes", {}).get("pattern", "").isdigit()
or "digit" in elem.get("attributes", {}).get("pattern", "").lower()
for elem in input_fields
)
if has_maxlength_1 or has_numeric_patterns:
return True
# Check for TOTP-related keywords in page content
page_text = scraped_page.html.lower() if scraped_page.html else ""
totp_keywords = [
"verification code",
"authentication code",
"security code",
"2fa",
"two-factor",
"totp",
"authenticator",
"verification",
"enter code",
"verification number",
"security number",
]
keyword_matches = sum(1 for keyword in totp_keywords if keyword in page_text)
# If we have multiple TOTP keywords and multiple input fields, likely TOTP
if keyword_matches >= 2 and len(input_fields) >= 6:
return True
# Strong single keyword match with multiple inputs
strong_keywords = ["verification code", "authentication code", "2fa", "two-factor"]
if any(keyword in page_text for keyword in strong_keywords) and len(input_fields) >= 3:
return True
return False
except Exception:
return False
def _is_multi_field_totp_sequence(self, actions: list) -> bool:
"""
Check if the action sequence represents a multi-field TOTP input (6 single-digit fields).
Args:
actions: List of actions to analyze
Returns:
bool: True if this is a multi-field TOTP sequence
"""
# Must have at least 4 actions (minimum for TOTP)
if len(actions) < 4:
return False
# Check if we have multiple consecutive single-digit INPUT_TEXT actions
consecutive_single_digits = 0
max_consecutive = 0
for action in actions:
if (
action.action_type == ActionType.INPUT_TEXT
and hasattr(action, "text")
and action.text
and len(action.text) == 1
and action.text.isdigit()
):
consecutive_single_digits += 1
max_consecutive = max(max_consecutive, consecutive_single_digits)
else:
# If we hit a non-single-digit action, reset consecutive counter
consecutive_single_digits = 0
# Consider it a multi-field TOTP if we have 4+ consecutive single-digit inputs
# This is more reliable than just counting total single digits
# We use 4+ as the threshold to avoid false positives with single TOTP fields
is_multi_field_totp = max_consecutive >= 4
if is_multi_field_totp:
LOG.debug(
"Detected multi-field TOTP sequence",
max_consecutive=max_consecutive,
total_actions=len(actions),
)
return is_multi_field_totp
def _build_navigation_payload( def _build_navigation_payload(
self, self,
task: Task, task: Task,
expire_verification_code: bool = False, expire_verification_code: bool = False,
step: Step | None = None,
scraped_page: ScrapedPage | None = None,
) -> dict[str, Any] | list | str | None: ) -> dict[str, Any] | list | str | None:
final_navigation_payload = task.navigation_payload final_navigation_payload = task.navigation_payload
current_context = skyvern_context.ensure_context() current_context = skyvern_context.ensure_context()
verification_code = current_context.totp_codes.get(task.task_id) verification_code = current_context.totp_codes.get(task.task_id)
if (task.totp_verification_url or task.totp_identifier) and verification_code: if (task.totp_verification_url or task.totp_identifier) and verification_code:
@@ -2190,6 +2350,32 @@ class ForgeAgent:
) )
if expire_verification_code: if expire_verification_code:
current_context.totp_codes.pop(task.task_id) current_context.totp_codes.pop(task.task_id)
# Store TOTP secrets and provide placeholder TOTP for LLM to see format
# Only when on a TOTP page to avoid premature processing
if (
task.workflow_run_id
and step
and isinstance(final_navigation_payload, dict)
and self._should_process_totp(scraped_page)
):
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(task.workflow_run_id)
for key, value in list(final_navigation_payload.items()):
if isinstance(value, dict) and "totp" in value:
totp_placeholder = value.get("totp")
if totp_placeholder and isinstance(totp_placeholder, str):
totp_secret_key = workflow_run_context.totp_secret_value_key(totp_placeholder)
totp_secret = workflow_run_context.get_original_secret_value_or_none(totp_secret_key)
if totp_secret:
# Store TOTP secret for handler to use during execution
current_context = skyvern_context.ensure_context()
current_context.totp_codes[f"{task.task_id}_secret"] = totp_secret
# Send a placeholder TOTP for the LLM to see the format
final_navigation_payload[key]["totp"] = "123456"
return final_navigation_payload return final_navigation_payload
async def _get_action_results(self, task: Task, current_step: Step | None = None) -> str: async def _get_action_results(self, task: Task, current_step: Step | None = None) -> str:

View File

@@ -88,6 +88,9 @@ class Action(BaseModel):
is_checked: bool | None = None is_checked: bool | None = None
verified: bool = False verified: bool = False
# TOTP timing information for multi-field TOTP sequences
totp_timing_info: dict[str, Any] | None = None
created_at: datetime | None = None created_at: datetime | None = None
modified_at: datetime | None = None modified_at: datetime | None = None
created_by: str | None = None created_by: str | None = None

View File

@@ -2,6 +2,7 @@ import asyncio
import copy import copy
import json import json
import os import os
import time
import urllib.parse import urllib.parse
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@@ -10,6 +11,7 @@ from typing import Any, Awaitable, Callable, List
import pyotp import pyotp
import structlog import structlog
from playwright._impl._errors import Error as PlaywrightError
from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError
from pydantic import BaseModel from pydantic import BaseModel
@@ -22,6 +24,7 @@ from skyvern.constants import (
REPO_ROOT_DIR, REPO_ROOT_DIR,
SKYVERN_ID_ATTR, SKYVERN_ID_ATTR,
) )
from skyvern.errors.errors import TOTPExpiredError
from skyvern.exceptions import ( from skyvern.exceptions import (
DownloadFileMaxWaitingTime, DownloadFileMaxWaitingTime,
EmptySelect, EmptySelect,
@@ -929,6 +932,128 @@ async def handle_click_to_download_file_action(
return [ActionSuccess(download_triggered=True)] return [ActionSuccess(download_triggered=True)]
# TOTP timing constants
# Length of one TOTP validity window in seconds (pyotp's TOTP defaults to a 30s interval).
TOTP_TIME_STEP_SECONDS = 30
# When the first digit is typed with fewer than this many seconds left in the current
# window, the code of the NEXT window is generated instead.
TOTP_EXPIRY_THRESHOLD_SECONDS = 20


async def _handle_multi_field_totp_sequence(
    timing_info: dict[str, Any],
    task: Task,
) -> list[ActionResult] | None:
    """
    Generate (first digit) or validate (later digits) the cached TOTP of a multi-field sequence.

    For ``action_index == 0`` a code is generated from ``timing_info["totp_secret"]`` and stored
    in the Skyvern context under ``"{task_id}_totp_cache"``. If the current window is about to
    end, the code of the next window is stored instead.

    For later indices the cached code is compared with the codes of the current and the next
    window to find out which window it belongs to:
      * current window -> usable now;
      * next window    -> pre-generated; on the 6th digit (index 5) wait until that window opens;
      * anything else  -> the code has expired (or the secret changed) and the action fails.

    Args:
        timing_info: Dict attached to the action; reads ``action_index``, ``totp_secret`` and,
            for logging only, ``is_retry``.
        task: Task whose id namespaces the cache entry.

    Returns:
        ``None`` on success, or a one-element list holding ``ActionFailure(TOTPExpiredError())``
        when the cached code is missing or expired.

    NOTE(review): ``action_index`` is used as the digit position, which presumes the digit
    inputs start at action 0 and that the sequence has 6 digits - confirm against the caller.
    """
    action_index = timing_info["action_index"]
    cache_key = f"{task.task_id}_totp_cache"
    current_context = skyvern_context.ensure_context()

    if action_index == 0:
        # First digit: generate TOTP and cache it
        totp = pyotp.TOTP(timing_info["totp_secret"])
        current_time = int(time.time())
        next_window_start = ((current_time // TOTP_TIME_STEP_SECONDS) + 1) * TOTP_TIME_STEP_SECONDS
        seconds_until_expiry = next_window_start - current_time

        if seconds_until_expiry < TOTP_EXPIRY_THRESHOLD_SECONDS:
            # Too close to rollover: use the code of the upcoming window. The final digit
            # is held back below until that window has actually started.
            current_totp = totp.at(next_window_start)
            # One-time codes are intentionally kept out of the logs.
            LOG.debug(
                "Using multi-field TOTP flow - using NEXT TOTP due to <20s expiry",
                task_id=task.task_id,
                action_idx=action_index,
                seconds_until_expiry=seconds_until_expiry,
                is_retry=timing_info.get("is_retry", False),
            )
        else:
            # Use current TOTP (same timestamp as the window computation above)
            current_totp = totp.at(current_time)

        current_context.totp_codes[cache_key] = current_totp
        return None  # Success

    # Subsequent digits: reuse cached TOTP
    cached_totp = current_context.totp_codes.get(cache_key)
    if not cached_totp:
        # TOTP cache missing for subsequent digit - this should not happen.
        # If it does, something went wrong with the first digit, so fail the action.
        LOG.error(
            "TOTP cache missing for subsequent digit - first digit may have failed",
            task_id=task.task_id,
            action_idx=action_index,
            cache_key=cache_key,
        )
        return [ActionFailure(TOTPExpiredError())]

    totp = pyotp.TOTP(timing_info["totp_secret"])
    current_time = int(time.time())
    next_window_start = ((current_time // TOTP_TIME_STEP_SECONDS) + 1) * TOTP_TIME_STEP_SECONDS

    # The window of the cached code is not stored, so derive it by comparing against the
    # codes of the current and next windows. (Recomputing "valid until" from the current
    # time, as a naive check would, can never detect expiry.)
    if str(cached_totp) == totp.at(current_time):
        LOG.debug(
            "Using multi-field TOTP flow - reusing cached TOTP",
            task_id=task.task_id,
            action_idx=action_index,
            current_time=current_time,
            totp_valid_until=next_window_start,
        )
        return None  # Success

    if str(cached_totp) == totp.at(next_window_start):
        # Cached code belongs to the upcoming window. Earlier digits can be typed right
        # away; the 6th digit (action_index=5) must wait, since entering it may submit the form.
        if action_index == 5:
            # Integer seconds round the wait up by <1s, which keeps us safely inside the window.
            wait_seconds = next_window_start - current_time
            LOG.debug(
                "6th digit: TOTP not yet valid, waiting until valid_from",
                task_id=task.task_id,
                action_idx=action_index,
                current_time=current_time,
                totp_valid_from=next_window_start,
                wait_seconds=wait_seconds,
            )
            await asyncio.sleep(wait_seconds)
            LOG.debug(
                "6th digit: Finished waiting, TOTP is now valid",
                task_id=task.task_id,
                action_idx=action_index,
            )
        else:
            LOG.debug(
                "Using multi-field TOTP flow - reusing cached TOTP",
                task_id=task.task_id,
                action_idx=action_index,
                current_time=current_time,
                totp_valid_until=next_window_start + TOTP_TIME_STEP_SECONDS,
            )
        return None  # Success

    # Matches neither the current nor the next window: the cached code is stale.
    LOG.error(
        "Cached TOTP has expired during multi-field sequence",
        task_id=task.task_id,
        action_idx=action_index,
        current_time=current_time,
    )
    return [ActionFailure(TOTPExpiredError())]
@TraceManager.traced_async(ignore_inputs=["scraped_page", "page"]) @TraceManager.traced_async(ignore_inputs=["scraped_page", "page"])
async def handle_input_text_action( async def handle_input_text_action(
action: actions.InputTextAction, action: actions.InputTextAction,
@@ -954,9 +1079,17 @@ async def handle_input_text_action(
# before filling text, we need to validate if the element can be filled if it's not one of COMMON_INPUT_TAGS # before filling text, we need to validate if the element can be filled if it's not one of COMMON_INPUT_TAGS
tag_name = scraped_page.id_to_element_dict[action.element_id]["tagName"].lower() tag_name = scraped_page.id_to_element_dict[action.element_id]["tagName"].lower()
text: str | None = await get_actual_value_of_parameter_if_secret(task, action.text)
if text is None: # Check if this is multi-field TOTP first - if so, skip secret resolution
return [ActionFailure(FailedToFetchSecret())] if action.totp_timing_info and action.totp_timing_info.get("is_totp_sequence"):
# For multi-field TOTP, we'll set text directly in the TOTP logic below
text: str = ""
else:
# For regular inputs, resolve secrets
text_result = await get_actual_value_of_parameter_if_secret(task, action.text)
if text_result is None:
return [ActionFailure(FailedToFetchSecret())]
text = text_result
is_totp_value = ( is_totp_value = (
text == BitwardenConstants.TOTP or text == OnePasswordConstants.TOTP or text == AzureVaultConstants.TOTP text == BitwardenConstants.TOTP or text == OnePasswordConstants.TOTP or text == AzureVaultConstants.TOTP
@@ -1198,6 +1331,32 @@ async def handle_input_text_action(
await skyvern_element.input(text) await skyvern_element.input(text)
return [ActionSuccess()] return [ActionSuccess()]
# Handle TOTP generation for multi-field TOTP sequences
if action.totp_timing_info:
timing_info = action.totp_timing_info
if timing_info.get("is_totp_sequence"):
result = await _handle_multi_field_totp_sequence(timing_info, task)
if result is not None:
return result # Return ActionFailure if TOTP handling failed
# Extract the digit for this action index
current_totp = skyvern_context.ensure_context().totp_codes.get(f"{task.task_id}_totp_cache")
action_index = timing_info["action_index"]
if current_totp and len(current_totp) > action_index:
digit = current_totp[action_index]
action.text = digit
# Also update the text variable that will be used later
text = digit
else:
LOG.error(
"TOTP too short for action index",
task_id=task.task_id,
action_idx=action_index,
totp_length=len(current_totp) if current_totp else 0,
)
return [ActionFailure(TOTPExpiredError())]
try: try:
# TODO: not sure if this case will trigger auto-completion # TODO: not sure if this case will trigger auto-completion
if tag_name not in COMMON_INPUT_TAGS: if tag_name not in COMMON_INPUT_TAGS:
@@ -1246,18 +1405,57 @@ async def handle_input_text_action(
try: try:
await skyvern_element.input_sequentially(text=text) await skyvern_element.input_sequentially(text=text)
finally:
incremental_element = await incremental_scraped.get_incremental_element_tree( incremental_element = await incremental_scraped.get_incremental_element_tree(
clean_and_remove_element_tree_factory( clean_and_remove_element_tree_factory(
task=task, step=step, check_filter_funcs=[check_existed_but_not_option_element_in_dom_factory(dom)] task=task,
step=step,
check_filter_funcs=[check_existed_but_not_option_element_in_dom_factory(dom)],
), ),
) )
if len(incremental_element) > 0: if len(incremental_element) > 0:
auto_complete_hacky_flag = True auto_complete_hacky_flag = True
except PlaywrightError as inc_error:
# Handle Playwright-specific errors during incremental element processing (e.g., TOTP form auto-submit)
error_message = str(inc_error).lower()
if (
"execution context was destroyed" in error_message
or "navigation" in error_message
or "target closed" in error_message
):
# These are expected during page navigation/auto-submit, silently continue
LOG.debug(
"Playwright error during incremental element processing (likely page navigation)",
task_id=task.task_id,
step_id=step.step_id,
error_type=type(inc_error).__name__,
error_message=error_message,
)
else:
LOG.warning(
"Unexpected Playwright error during incremental element processing",
task_id=task.task_id,
step_id=step.step_id,
error_type=type(inc_error).__name__,
error_message=str(inc_error),
)
except Exception as inc_error:
# Handle any other unexpected errors during incremental element processing
LOG.warning(
"Unexpected error during incremental element processing",
task_id=task.task_id,
step_id=step.step_id,
error_type=type(inc_error).__name__,
error_message=str(inc_error),
)
finally:
# Always stop listening
await incremental_scraped.stop_listen_dom_increment() await incremental_scraped.stop_listen_dom_increment()
return [ActionSuccess()] return [ActionSuccess()]
except Exception as e: except Exception as e:
# Handle any other unexpected errors during text input
LOG.exception( LOG.exception(
"Failed to input the value or finish the auto completion", "Failed to input the value or finish the auto completion",
task_id=task.task_id, task_id=task.task_id,
@@ -3738,6 +3936,21 @@ async def _get_input_or_select_context(
json_response = await app.PARSE_SELECT_LLM_API_HANDLER( json_response = await app.PARSE_SELECT_LLM_API_HANDLER(
prompt=prompt, step=step, prompt_name="parse-input-or-select-context" prompt=prompt, step=step, prompt_name="parse-input-or-select-context"
) )
# Handle edge case where LLM returns list instead of dict
if isinstance(json_response, list):
LOG.warning(
"LLM returned list instead of dict for input/select context parsing",
step_id=step.step_id,
original_response_type=type(json_response).__name__,
original_response_length=len(json_response) if json_response else 0,
first_item_type=type(json_response[0]).__name__ if json_response else None,
first_item_keys=list(json_response[0].keys())
if json_response and isinstance(json_response[0], dict)
else None,
)
json_response = json_response[0] if json_response else {}
json_response["intention"] = action.intention json_response["intention"] = action.intention
input_or_select_context = InputOrSelectContext.model_validate(json_response) input_or_select_context = InputOrSelectContext.model_validate(json_response)
LOG.info( LOG.info(