Pedro/support_multi_field_6_digit_totp (#3622)

This commit is contained in:
pedrohsdb
2025-10-06 16:37:35 -07:00
committed by GitHub
parent 7f6e5d2e36
commit 6fc56d9775
4 changed files with 415 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ import asyncio
import copy
import json
import os
import time
import urllib.parse
import uuid
from datetime import datetime, timedelta, timezone
@@ -10,6 +11,7 @@ from typing import Any, Awaitable, Callable, List
import pyotp
import structlog
from playwright._impl._errors import Error as PlaywrightError
from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError
from pydantic import BaseModel
@@ -22,6 +24,7 @@ from skyvern.constants import (
REPO_ROOT_DIR,
SKYVERN_ID_ATTR,
)
from skyvern.errors.errors import TOTPExpiredError
from skyvern.exceptions import (
DownloadFileMaxWaitingTime,
EmptySelect,
@@ -929,6 +932,128 @@ async def handle_click_to_download_file_action(
return [ActionSuccess(download_triggered=True)]
# TOTP timing constants
TOTP_TIME_STEP_SECONDS = 30
TOTP_EXPIRY_THRESHOLD_SECONDS = 20
async def _handle_multi_field_totp_sequence(
timing_info: dict[str, Any],
task: Task,
) -> list[ActionResult] | None:
"""
Handle TOTP generation and caching for multi-field TOTP sequences.
Returns:
ActionFailure if TOTP handling failed, None if successful
"""
action_index = timing_info["action_index"]
cache_key = f"{task.task_id}_totp_cache"
current_context = skyvern_context.ensure_context()
if action_index == 0:
# First digit: generate TOTP and cache it
totp_secret = timing_info["totp_secret"]
totp = pyotp.TOTP(totp_secret)
# Check current TOTP expiry time
current_time = int(time.time())
current_totp_valid_until = ((current_time // TOTP_TIME_STEP_SECONDS) + 1) * TOTP_TIME_STEP_SECONDS
seconds_until_expiry = current_totp_valid_until - current_time
# If less than threshold seconds until expiry, use the next TOTP
if seconds_until_expiry < TOTP_EXPIRY_THRESHOLD_SECONDS:
# Force generation of next TOTP by advancing time
next_time = current_totp_valid_until
current_totp = totp.at(next_time)
LOG.debug(
"Using multi-field TOTP flow - using NEXT TOTP due to <20s expiry",
task_id=task.task_id,
action_idx=action_index,
current_totp=totp.now(),
next_totp=current_totp,
seconds_until_expiry=seconds_until_expiry,
is_retry=timing_info.get("is_retry", False),
)
else:
# Use current TOTP
current_totp = totp.now()
current_context.totp_codes[cache_key] = current_totp
else:
# Subsequent digits: reuse cached TOTP
current_totp = current_context.totp_codes.get(cache_key)
if not current_totp:
# TOTP cache missing for subsequent digit - this should not happen
# If it does, something went wrong with the first digit, so fail the action
LOG.error(
"TOTP cache missing for subsequent digit - first digit may have failed",
task_id=task.task_id,
action_idx=action_index,
cache_key=cache_key,
)
return [ActionFailure(TOTPExpiredError())]
# Check if cached TOTP has expired
totp_secret = timing_info["totp_secret"]
totp = pyotp.TOTP(totp_secret)
# Get current time and calculate TOTP expiry
current_time = int(time.time())
totp_valid_until = ((current_time // TOTP_TIME_STEP_SECONDS) + 1) * TOTP_TIME_STEP_SECONDS
if current_time >= totp_valid_until:
LOG.error(
"Cached TOTP has expired during multi-field sequence",
task_id=task.task_id,
action_idx=action_index,
current_time=current_time,
totp_valid_until=totp_valid_until,
cached_totp=current_totp,
)
return [ActionFailure(TOTPExpiredError())]
LOG.debug(
"Using multi-field TOTP flow - reusing cached TOTP",
task_id=task.task_id,
action_idx=action_index,
totp=current_totp,
current_time=current_time,
totp_valid_until=totp_valid_until,
)
# Special handling for the 6th digit (action_index=5): wait if TOTP is not yet valid
if action_index == 5:
# Calculate when this TOTP becomes valid (valid_from time)
# If we used the next TOTP window, valid_from is the start of that window
totp_valid_from = totp_valid_until - TOTP_TIME_STEP_SECONDS
if current_time < totp_valid_from:
# TOTP is not yet valid, wait until it becomes valid
wait_seconds = totp_valid_from - current_time
LOG.debug(
"6th digit: TOTP not yet valid, waiting until valid_from",
task_id=task.task_id,
action_idx=action_index,
current_time=current_time,
totp_valid_from=totp_valid_from,
wait_seconds=wait_seconds,
totp=current_totp,
)
await asyncio.sleep(wait_seconds)
LOG.debug(
"6th digit: Finished waiting, TOTP is now valid",
task_id=task.task_id,
action_idx=action_index,
)
return None # Success
@TraceManager.traced_async(ignore_inputs=["scraped_page", "page"])
async def handle_input_text_action(
action: actions.InputTextAction,
@@ -954,9 +1079,17 @@ async def handle_input_text_action(
# before filling text, we need to validate if the element can be filled if it's not one of COMMON_INPUT_TAGS
tag_name = scraped_page.id_to_element_dict[action.element_id]["tagName"].lower()
text: str | None = await get_actual_value_of_parameter_if_secret(task, action.text)
if text is None:
return [ActionFailure(FailedToFetchSecret())]
# Check if this is multi-field TOTP first - if so, skip secret resolution
if action.totp_timing_info and action.totp_timing_info.get("is_totp_sequence"):
# For multi-field TOTP, we'll set text directly in the TOTP logic below
text: str = ""
else:
# For regular inputs, resolve secrets
text_result = await get_actual_value_of_parameter_if_secret(task, action.text)
if text_result is None:
return [ActionFailure(FailedToFetchSecret())]
text = text_result
is_totp_value = (
text == BitwardenConstants.TOTP or text == OnePasswordConstants.TOTP or text == AzureVaultConstants.TOTP
@@ -1198,6 +1331,32 @@ async def handle_input_text_action(
await skyvern_element.input(text)
return [ActionSuccess()]
# Handle TOTP generation for multi-field TOTP sequences
if action.totp_timing_info:
timing_info = action.totp_timing_info
if timing_info.get("is_totp_sequence"):
result = await _handle_multi_field_totp_sequence(timing_info, task)
if result is not None:
return result # Return ActionFailure if TOTP handling failed
# Extract the digit for this action index
current_totp = skyvern_context.ensure_context().totp_codes.get(f"{task.task_id}_totp_cache")
action_index = timing_info["action_index"]
if current_totp and len(current_totp) > action_index:
digit = current_totp[action_index]
action.text = digit
# Also update the text variable that will be used later
text = digit
else:
LOG.error(
"TOTP too short for action index",
task_id=task.task_id,
action_idx=action_index,
totp_length=len(current_totp) if current_totp else 0,
)
return [ActionFailure(TOTPExpiredError())]
try:
# TODO: not sure if this case will trigger auto-completion
if tag_name not in COMMON_INPUT_TAGS:
@@ -1246,18 +1405,57 @@ async def handle_input_text_action(
try:
await skyvern_element.input_sequentially(text=text)
finally:
incremental_element = await incremental_scraped.get_incremental_element_tree(
clean_and_remove_element_tree_factory(
task=task, step=step, check_filter_funcs=[check_existed_but_not_option_element_in_dom_factory(dom)]
task=task,
step=step,
check_filter_funcs=[check_existed_but_not_option_element_in_dom_factory(dom)],
),
)
if len(incremental_element) > 0:
auto_complete_hacky_flag = True
except PlaywrightError as inc_error:
# Handle Playwright-specific errors during incremental element processing (e.g., TOTP form auto-submit)
error_message = str(inc_error).lower()
if (
"execution context was destroyed" in error_message
or "navigation" in error_message
or "target closed" in error_message
):
# These are expected during page navigation/auto-submit, silently continue
LOG.debug(
"Playwright error during incremental element processing (likely page navigation)",
task_id=task.task_id,
step_id=step.step_id,
error_type=type(inc_error).__name__,
error_message=error_message,
)
else:
LOG.warning(
"Unexpected Playwright error during incremental element processing",
task_id=task.task_id,
step_id=step.step_id,
error_type=type(inc_error).__name__,
error_message=str(inc_error),
)
except Exception as inc_error:
# Handle any other unexpected errors during incremental element processing
LOG.warning(
"Unexpected error during incremental element processing",
task_id=task.task_id,
step_id=step.step_id,
error_type=type(inc_error).__name__,
error_message=str(inc_error),
)
finally:
# Always stop listening
await incremental_scraped.stop_listen_dom_increment()
return [ActionSuccess()]
except Exception as e:
# Handle any other unexpected errors during text input
LOG.exception(
"Failed to input the value or finish the auto completion",
task_id=task.task_id,
@@ -3738,6 +3936,21 @@ async def _get_input_or_select_context(
json_response = await app.PARSE_SELECT_LLM_API_HANDLER(
prompt=prompt, step=step, prompt_name="parse-input-or-select-context"
)
# Handle edge case where LLM returns list instead of dict
if isinstance(json_response, list):
LOG.warning(
"LLM returned list instead of dict for input/select context parsing",
step_id=step.step_id,
original_response_type=type(json_response).__name__,
original_response_length=len(json_response) if json_response else 0,
first_item_type=type(json_response[0]).__name__ if json_response else None,
first_item_keys=list(json_response[0].keys())
if json_response and isinstance(json_response[0], dict)
else None,
)
json_response = json_response[0] if json_response else {}
json_response["intention"] = action.intention
input_or_select_context = InputOrSelectContext.model_validate(json_response)
LOG.info(