Pedro/support_multi_field_6_digit_totp (#3622)
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -10,6 +11,7 @@ from typing import Any, Awaitable, Callable, List
|
||||
|
||||
import pyotp
|
||||
import structlog
|
||||
from playwright._impl._errors import Error as PlaywrightError
|
||||
from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -22,6 +24,7 @@ from skyvern.constants import (
|
||||
REPO_ROOT_DIR,
|
||||
SKYVERN_ID_ATTR,
|
||||
)
|
||||
from skyvern.errors.errors import TOTPExpiredError
|
||||
from skyvern.exceptions import (
|
||||
DownloadFileMaxWaitingTime,
|
||||
EmptySelect,
|
||||
@@ -929,6 +932,128 @@ async def handle_click_to_download_file_action(
|
||||
return [ActionSuccess(download_triggered=True)]
|
||||
|
||||
|
||||
# TOTP timing constants
|
||||
TOTP_TIME_STEP_SECONDS = 30
|
||||
TOTP_EXPIRY_THRESHOLD_SECONDS = 20
|
||||
|
||||
|
||||
async def _handle_multi_field_totp_sequence(
|
||||
timing_info: dict[str, Any],
|
||||
task: Task,
|
||||
) -> list[ActionResult] | None:
|
||||
"""
|
||||
Handle TOTP generation and caching for multi-field TOTP sequences.
|
||||
|
||||
Returns:
|
||||
ActionFailure if TOTP handling failed, None if successful
|
||||
"""
|
||||
action_index = timing_info["action_index"]
|
||||
cache_key = f"{task.task_id}_totp_cache"
|
||||
current_context = skyvern_context.ensure_context()
|
||||
|
||||
if action_index == 0:
|
||||
# First digit: generate TOTP and cache it
|
||||
totp_secret = timing_info["totp_secret"]
|
||||
totp = pyotp.TOTP(totp_secret)
|
||||
|
||||
# Check current TOTP expiry time
|
||||
current_time = int(time.time())
|
||||
current_totp_valid_until = ((current_time // TOTP_TIME_STEP_SECONDS) + 1) * TOTP_TIME_STEP_SECONDS
|
||||
seconds_until_expiry = current_totp_valid_until - current_time
|
||||
|
||||
# If less than threshold seconds until expiry, use the next TOTP
|
||||
if seconds_until_expiry < TOTP_EXPIRY_THRESHOLD_SECONDS:
|
||||
# Force generation of next TOTP by advancing time
|
||||
next_time = current_totp_valid_until
|
||||
current_totp = totp.at(next_time)
|
||||
|
||||
LOG.debug(
|
||||
"Using multi-field TOTP flow - using NEXT TOTP due to <20s expiry",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
current_totp=totp.now(),
|
||||
next_totp=current_totp,
|
||||
seconds_until_expiry=seconds_until_expiry,
|
||||
is_retry=timing_info.get("is_retry", False),
|
||||
)
|
||||
else:
|
||||
# Use current TOTP
|
||||
current_totp = totp.now()
|
||||
|
||||
current_context.totp_codes[cache_key] = current_totp
|
||||
else:
|
||||
# Subsequent digits: reuse cached TOTP
|
||||
current_totp = current_context.totp_codes.get(cache_key)
|
||||
if not current_totp:
|
||||
# TOTP cache missing for subsequent digit - this should not happen
|
||||
# If it does, something went wrong with the first digit, so fail the action
|
||||
LOG.error(
|
||||
"TOTP cache missing for subsequent digit - first digit may have failed",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
cache_key=cache_key,
|
||||
)
|
||||
return [ActionFailure(TOTPExpiredError())]
|
||||
|
||||
# Check if cached TOTP has expired
|
||||
totp_secret = timing_info["totp_secret"]
|
||||
totp = pyotp.TOTP(totp_secret)
|
||||
|
||||
# Get current time and calculate TOTP expiry
|
||||
current_time = int(time.time())
|
||||
totp_valid_until = ((current_time // TOTP_TIME_STEP_SECONDS) + 1) * TOTP_TIME_STEP_SECONDS
|
||||
|
||||
if current_time >= totp_valid_until:
|
||||
LOG.error(
|
||||
"Cached TOTP has expired during multi-field sequence",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
current_time=current_time,
|
||||
totp_valid_until=totp_valid_until,
|
||||
cached_totp=current_totp,
|
||||
)
|
||||
return [ActionFailure(TOTPExpiredError())]
|
||||
|
||||
LOG.debug(
|
||||
"Using multi-field TOTP flow - reusing cached TOTP",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
totp=current_totp,
|
||||
current_time=current_time,
|
||||
totp_valid_until=totp_valid_until,
|
||||
)
|
||||
|
||||
# Special handling for the 6th digit (action_index=5): wait if TOTP is not yet valid
|
||||
if action_index == 5:
|
||||
# Calculate when this TOTP becomes valid (valid_from time)
|
||||
# If we used the next TOTP window, valid_from is the start of that window
|
||||
totp_valid_from = totp_valid_until - TOTP_TIME_STEP_SECONDS
|
||||
|
||||
if current_time < totp_valid_from:
|
||||
# TOTP is not yet valid, wait until it becomes valid
|
||||
wait_seconds = totp_valid_from - current_time
|
||||
|
||||
LOG.debug(
|
||||
"6th digit: TOTP not yet valid, waiting until valid_from",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
current_time=current_time,
|
||||
totp_valid_from=totp_valid_from,
|
||||
wait_seconds=wait_seconds,
|
||||
totp=current_totp,
|
||||
)
|
||||
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
LOG.debug(
|
||||
"6th digit: Finished waiting, TOTP is now valid",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
)
|
||||
|
||||
return None # Success
|
||||
|
||||
|
||||
@TraceManager.traced_async(ignore_inputs=["scraped_page", "page"])
|
||||
async def handle_input_text_action(
|
||||
action: actions.InputTextAction,
|
||||
@@ -954,9 +1079,17 @@ async def handle_input_text_action(
|
||||
|
||||
# before filling text, we need to validate if the element can be filled if it's not one of COMMON_INPUT_TAGS
|
||||
tag_name = scraped_page.id_to_element_dict[action.element_id]["tagName"].lower()
|
||||
text: str | None = await get_actual_value_of_parameter_if_secret(task, action.text)
|
||||
if text is None:
|
||||
return [ActionFailure(FailedToFetchSecret())]
|
||||
|
||||
# Check if this is multi-field TOTP first - if so, skip secret resolution
|
||||
if action.totp_timing_info and action.totp_timing_info.get("is_totp_sequence"):
|
||||
# For multi-field TOTP, we'll set text directly in the TOTP logic below
|
||||
text: str = ""
|
||||
else:
|
||||
# For regular inputs, resolve secrets
|
||||
text_result = await get_actual_value_of_parameter_if_secret(task, action.text)
|
||||
if text_result is None:
|
||||
return [ActionFailure(FailedToFetchSecret())]
|
||||
text = text_result
|
||||
|
||||
is_totp_value = (
|
||||
text == BitwardenConstants.TOTP or text == OnePasswordConstants.TOTP or text == AzureVaultConstants.TOTP
|
||||
@@ -1198,6 +1331,32 @@ async def handle_input_text_action(
|
||||
await skyvern_element.input(text)
|
||||
return [ActionSuccess()]
|
||||
|
||||
# Handle TOTP generation for multi-field TOTP sequences
|
||||
if action.totp_timing_info:
|
||||
timing_info = action.totp_timing_info
|
||||
if timing_info.get("is_totp_sequence"):
|
||||
result = await _handle_multi_field_totp_sequence(timing_info, task)
|
||||
if result is not None:
|
||||
return result # Return ActionFailure if TOTP handling failed
|
||||
|
||||
# Extract the digit for this action index
|
||||
current_totp = skyvern_context.ensure_context().totp_codes.get(f"{task.task_id}_totp_cache")
|
||||
action_index = timing_info["action_index"]
|
||||
|
||||
if current_totp and len(current_totp) > action_index:
|
||||
digit = current_totp[action_index]
|
||||
action.text = digit
|
||||
# Also update the text variable that will be used later
|
||||
text = digit
|
||||
else:
|
||||
LOG.error(
|
||||
"TOTP too short for action index",
|
||||
task_id=task.task_id,
|
||||
action_idx=action_index,
|
||||
totp_length=len(current_totp) if current_totp else 0,
|
||||
)
|
||||
return [ActionFailure(TOTPExpiredError())]
|
||||
|
||||
try:
|
||||
# TODO: not sure if this case will trigger auto-completion
|
||||
if tag_name not in COMMON_INPUT_TAGS:
|
||||
@@ -1246,18 +1405,57 @@ async def handle_input_text_action(
|
||||
|
||||
try:
|
||||
await skyvern_element.input_sequentially(text=text)
|
||||
finally:
|
||||
|
||||
incremental_element = await incremental_scraped.get_incremental_element_tree(
|
||||
clean_and_remove_element_tree_factory(
|
||||
task=task, step=step, check_filter_funcs=[check_existed_but_not_option_element_in_dom_factory(dom)]
|
||||
task=task,
|
||||
step=step,
|
||||
check_filter_funcs=[check_existed_but_not_option_element_in_dom_factory(dom)],
|
||||
),
|
||||
)
|
||||
if len(incremental_element) > 0:
|
||||
auto_complete_hacky_flag = True
|
||||
except PlaywrightError as inc_error:
|
||||
# Handle Playwright-specific errors during incremental element processing (e.g., TOTP form auto-submit)
|
||||
error_message = str(inc_error).lower()
|
||||
if (
|
||||
"execution context was destroyed" in error_message
|
||||
or "navigation" in error_message
|
||||
or "target closed" in error_message
|
||||
):
|
||||
# These are expected during page navigation/auto-submit, silently continue
|
||||
LOG.debug(
|
||||
"Playwright error during incremental element processing (likely page navigation)",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_type=type(inc_error).__name__,
|
||||
error_message=error_message,
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
"Unexpected Playwright error during incremental element processing",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_type=type(inc_error).__name__,
|
||||
error_message=str(inc_error),
|
||||
)
|
||||
except Exception as inc_error:
|
||||
# Handle any other unexpected errors during incremental element processing
|
||||
LOG.warning(
|
||||
"Unexpected error during incremental element processing",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_type=type(inc_error).__name__,
|
||||
error_message=str(inc_error),
|
||||
)
|
||||
finally:
|
||||
# Always stop listening
|
||||
await incremental_scraped.stop_listen_dom_increment()
|
||||
|
||||
return [ActionSuccess()]
|
||||
except Exception as e:
|
||||
# Handle any other unexpected errors during text input
|
||||
|
||||
LOG.exception(
|
||||
"Failed to input the value or finish the auto completion",
|
||||
task_id=task.task_id,
|
||||
@@ -3738,6 +3936,21 @@ async def _get_input_or_select_context(
|
||||
json_response = await app.PARSE_SELECT_LLM_API_HANDLER(
|
||||
prompt=prompt, step=step, prompt_name="parse-input-or-select-context"
|
||||
)
|
||||
|
||||
# Handle edge case where LLM returns list instead of dict
|
||||
if isinstance(json_response, list):
|
||||
LOG.warning(
|
||||
"LLM returned list instead of dict for input/select context parsing",
|
||||
step_id=step.step_id,
|
||||
original_response_type=type(json_response).__name__,
|
||||
original_response_length=len(json_response) if json_response else 0,
|
||||
first_item_type=type(json_response[0]).__name__ if json_response else None,
|
||||
first_item_keys=list(json_response[0].keys())
|
||||
if json_response and isinstance(json_response[0], dict)
|
||||
else None,
|
||||
)
|
||||
json_response = json_response[0] if json_response else {}
|
||||
|
||||
json_response["intention"] = action.intention
|
||||
input_or_select_context = InputOrSelectContext.model_validate(json_response)
|
||||
LOG.info(
|
||||
|
||||
Reference in New Issue
Block a user