From 16945e117f95d9317729d8feb7b950d30dca904f Mon Sep 17 00:00:00 2001 From: pedrohsdb Date: Mon, 26 Jan 2026 17:12:25 -0800 Subject: [PATCH] Fix multi-field TOTP support in cached script execution (#4537) --- .../script_generations/generate_script.py | 149 ++++++++++++++++-- .../script_generations/script_skyvern_page.py | 120 ++++++++++++++ .../core/script_generations/skyvern_page.py | 39 +++++ 3 files changed, 299 insertions(+), 9 deletions(-) diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index ee2691bd..69e26f5f 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -201,6 +201,10 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool: """ Determine whether an input/select action should be forced into proactive mode. Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive. + + NOTE: Multi-field TOTP sequences do NOT require proactive mode because we use + get_totp_digit() to provide the exact digit value. Using proactive mode would + cause the AI to override our value with its own generated one. """ if act.get("has_mini_agent", False): return True @@ -211,12 +215,117 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool: # ): # return True - if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"): - return True + # Multi-field TOTP sequences should NOT use proactive mode - we provide the + # exact digit via get_totp_digit() and want that value used directly + # if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"): + # return True return False +def _annotate_multi_field_totp_sequence(actions: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Detect and annotate multi-field TOTP sequences in the action list. + + Multi-field TOTP is when a 6-digit code needs to be split across 6 individual input fields. 
+ This function identifies such sequences and adds totp_timing_info with the action_index + so that each field gets the correct digit (e.g., totp_code[0], totp_code[1], etc.). + + Args: + actions: List of actions to analyze and annotate + + Returns: + The same actions list with totp_timing_info added to multi-field TOTP actions + """ + if len(actions) < 4: + return actions + + # Identify consecutive runs of single-digit TOTP inputs + # A multi-field TOTP sequence is 4+ consecutive INPUT_TEXT actions with single-digit text + # and the same field_name (typically 'totp_code') + consecutive_start = None + consecutive_count = 0 + totp_field_name = None + + for idx, act in enumerate(actions): + is_single_digit_totp = ( + act.get("action_type") == ActionType.INPUT_TEXT + and act.get("field_name") + and act.get("text") + and len(str(act.get("text", ""))) == 1 + and str(act.get("text", "")).isdigit() + ) + + if is_single_digit_totp: + current_field = act.get("field_name") + if consecutive_start is None: + # Start a new sequence + consecutive_start = idx + totp_field_name = current_field + consecutive_count = 1 + elif current_field == totp_field_name: + # Same field, continue the sequence + consecutive_count += 1 + else: + # Different field - finalize current sequence if valid, then start new one + if consecutive_count >= 4: + for seq_idx in range(consecutive_count): + actions[consecutive_start + seq_idx]["totp_timing_info"] = { + "is_totp_sequence": True, + "action_index": seq_idx, + "total_digits": consecutive_count, + "field_name": totp_field_name, + } + LOG.debug( + "Annotated multi-field TOTP sequence (field change)", + start_idx=consecutive_start, + count=consecutive_count, + field_name=totp_field_name, + ) + # Start new sequence with different field + consecutive_start = idx + totp_field_name = current_field + consecutive_count = 1 + else: + # End of consecutive sequence - check if it was a multi-field TOTP + if consecutive_count >= 4 and consecutive_start is not 
None: + # Annotate all actions in this sequence + for seq_idx in range(consecutive_count): + actions[consecutive_start + seq_idx]["totp_timing_info"] = { + "is_totp_sequence": True, + "action_index": seq_idx, + "total_digits": consecutive_count, + "field_name": totp_field_name, + } + LOG.debug( + "Annotated multi-field TOTP sequence for script generation", + start_idx=consecutive_start, + count=consecutive_count, + field_name=totp_field_name, + ) + consecutive_start = None + consecutive_count = 0 + totp_field_name = None + + # Handle sequence at end of actions list + if consecutive_count >= 4 and consecutive_start is not None: + for seq_idx in range(consecutive_count): + actions[consecutive_start + seq_idx]["totp_timing_info"] = { + "is_totp_sequence": True, + "action_index": seq_idx, + "total_digits": consecutive_count, + "field_name": totp_field_name, + } + LOG.debug( + "Annotated multi-field TOTP sequence for script generation (at end)", + start_idx=consecutive_start, + count=consecutive_count, + field_name=totp_field_name, + ) + + return actions + + def _safe_name(label: str) -> str: s = "".join(c if c.isalnum() else "_" for c in label).lower() if not s or s[0].isdigit() or keyword.iskeyword(s): @@ -389,13 +498,32 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output: elif method in ["type", "fill"]: # Use context.parameters if field_name is available, otherwise fallback to direct value if act.get("field_name"): - text_value = cst.Subscript( - value=cst.Attribute( - value=cst.Name("context"), - attr=cst.Name("parameters"), - ), - slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))], - ) + # Check if this is a multi-field TOTP sequence that needs digit indexing + totp_info = act.get("totp_timing_info") or {} + if totp_info.get("is_totp_sequence") and "action_index" in totp_info: + # Generate: await page.get_totp_digit(context, 'field_name', digit_index) + # This method properly resolves the TOTP code from 
credentials and returns the specific digit + text_value = cst.Await( + expression=cst.Call( + func=cst.Attribute( + value=cst.Name("page"), + attr=cst.Name("get_totp_digit"), + ), + args=[ + cst.Arg(value=cst.Name("context")), + cst.Arg(value=_value(act["field_name"])), + cst.Arg(value=_value(totp_info["action_index"])), + ], + ) + ) + else: + text_value = cst.Subscript( + value=cst.Attribute( + value=cst.Name("context"), + attr=cst.Name("parameters"), + ), + slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))], + ) else: text_value = _value(act["text"]) @@ -644,6 +772,9 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}" body_stmts: list[cst.BaseStatement] = [] + # Detect and annotate multi-field TOTP sequences so each fill gets the correct digit index + actions = _annotate_multi_field_totp_sequence(actions) + if block.get("url"): body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})")) diff --git a/skyvern/core/script_generations/script_skyvern_page.py b/skyvern/core/script_generations/script_skyvern_page.py index c8a4f6da..5dd3e8f2 100644 --- a/skyvern/core/script_generations/script_skyvern_page.py +++ b/skyvern/core/script_generations/script_skyvern_page.py @@ -5,7 +5,9 @@ import os from pathlib import Path from typing import Any, Callable +import pyotp import structlog +from cachetools import TTLCache from playwright.async_api import Page from skyvern.config import settings @@ -591,6 +593,124 @@ class ScriptSkyvernPage(SkyvernPage): return value + # Class-level cache for TOTP codes to ensure all digits in a sequence use the same code + # Key: (workflow_run_id, credential_key), Value: totp_code + # Uses TTLCache with 30-second expiry (aligned with TOTP rotation period) + # and max 100 entries to prevent unbounded memory growth + _totp_sequence_cache: TTLCache[tuple[str, str], str] 
async def get_totp_digit(
    self,
    context: Any,
    field_name: str,
    digit_index: int,
    totp_identifier: str | None = None,
    totp_url: str | None = None,
) -> str:
    """
    Get a specific digit from a TOTP code for multi-field TOTP inputs.

    Used by generated scripts where each input box takes one digit of the code.

    Resolution order:
      1. If ``context.parameters[field_name]`` is empty and a workflow run context
         exists, scan its values for ``cred_*`` dict entries that carry a ``totp``
         key and generate the code from the stored secret with ``pyotp``.
         On this path the code is cached in ``self._totp_sequence_cache`` under
         ``(workflow_run_id, credential_key)``: ``digit_index == 0`` discards any
         cached code and generates a fresh one; later indices reuse the cached
         code when it is still present, otherwise a new code is generated.
      2. If no code was produced, fall back to
         ``self.get_actual_value(raw_value, totp_identifier, totp_url)``.
         NOTE(review): this fallback path does not use the sequence cache.

    Args:
        context: The run context; ``context.parameters`` is read with ``.get``.
        field_name: The parameter name containing the TOTP code or credential reference.
        digit_index: Zero-based index of the digit to return.
        totp_identifier: Optional TOTP identifier forwarded to ``get_actual_value``.
        totp_url: Optional TOTP verification URL forwarded to ``get_actual_value``.

    Returns:
        The single character at ``digit_index``, or ``""`` when the index is
        negative or beyond the end of the resolved code.
    """
    totp_code = ""
    skyvern_ctx = skyvern_context.ensure_context()
    workflow_run_id = skyvern_ctx.workflow_run_id if skyvern_ctx else None

    LOG.info(
        "get_totp_digit called",
        field_name=field_name,
        digit_index=digit_index,
        workflow_run_id=workflow_run_id,
    )

    # Get the raw parameter value (may be credential reference like BW_TOTP)
    raw_value = context.parameters.get(field_name, "")

    # Only consult workflow credentials when the parameter itself is empty.
    workflow_run_context = None
    if not raw_value and skyvern_ctx and workflow_run_id:
        workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)

    if workflow_run_context:
        for key, value in workflow_run_context.values.items():
            if not (key.startswith("cred_") and isinstance(value, dict) and "totp" in value):
                continue

            cache_key = (workflow_run_id, key)

            if digit_index == 0:
                # New sequence: drop any stale code so a fresh one is generated below.
                # EAFP because an entry may expire between a membership test and
                # the delete on a TTL-based cache.
                try:
                    del self._totp_sequence_cache[cache_key]
                except KeyError:
                    pass
            else:
                try:
                    totp_code = self._totp_sequence_cache[cache_key]
                except KeyError:
                    # Cache miss or expired entry: fall through and generate.
                    pass
                else:
                    LOG.info(
                        "Using cached TOTP code for sequence",
                        field_name=field_name,
                        credential_key=key,
                        digit_index=digit_index,
                        totp_code_length=len(totp_code),
                    )
                    break

            # Generate new TOTP code (digit_index==0 or cache miss)
            totp_secret_id = value.get("totp")
            if not totp_secret_id:
                continue
            totp_secret_key = workflow_run_context.totp_secret_value_key(totp_secret_id)
            totp_secret = workflow_run_context.get_original_secret_value_or_none(totp_secret_key)
            if not totp_secret:
                continue

            try:
                totp_code = pyotp.TOTP(totp_secret).now()
            except Exception as e:
                # Best effort: a bad secret on one credential should not stop
                # the scan of the remaining credentials or the fallback below.
                LOG.warning(
                    "Failed to generate TOTP code",
                    credential_key=key,
                    error=str(e),
                )
                continue

            # Cache the code for subsequent digit requests in this sequence
            self._totp_sequence_cache[cache_key] = totp_code
            LOG.info(
                "Generated fresh TOTP and cached for sequence",
                field_name=field_name,
                credential_key=key,
                digit_index=digit_index,
                totp_code_length=len(totp_code),
            )
            break

    # If we still don't have a TOTP code, try resolving via get_actual_value
    if not totp_code:
        totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)

    # Reject negative indices explicitly: Python would otherwise index from the
    # end of the code and silently return the wrong digit.
    if 0 <= digit_index < len(totp_code):
        return totp_code[digit_index]
    LOG.warning(
        "TOTP digit index out of range",
        field_name=field_name,
        digit_index=digit_index,
        totp_code_length=len(totp_code),
    )
    return ""
prepend_scheme_and_validate_url(url) diff --git a/skyvern/core/script_generations/skyvern_page.py b/skyvern/core/script_generations/skyvern_page.py index 56def805..058b8369 100644 --- a/skyvern/core/script_generations/skyvern_page.py +++ b/skyvern/core/script_generations/skyvern_page.py @@ -123,6 +123,40 @@ class SkyvernPage(Page): ) -> str: return value + async def get_totp_digit( + self, + context: Any, + field_name: str, + digit_index: int, + totp_identifier: str | None = None, + totp_url: str | None = None, + ) -> str: + """ + Get a specific digit from a TOTP code for multi-field TOTP inputs. + + This method is used by generated scripts for multi-field TOTP where each + input field needs a single digit. It resolves the full TOTP code from + the credential and returns the specific digit. + + Args: + context: The run context containing parameters + field_name: The parameter name containing the TOTP code or credential reference + digit_index: The index of the digit to return (0-5 for a 6-digit TOTP) + totp_identifier: Optional TOTP identifier for polling + totp_url: Optional TOTP verification URL + + Returns: + The single digit at the specified index + """ + # Get the raw parameter value (may be credential reference like BW_TOTP) + raw_value = context.parameters.get(field_name, "") + # Resolve the actual TOTP code (this handles credential generation) + totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url) + # Return the specific digit + if digit_index < len(totp_code): + return totp_code[digit_index] + return "" + ######### Public Interfaces ######### @overload @@ -409,6 +443,11 @@ class SkyvernPage(Page): if context and context.ai_mode_override: ai = context.ai_mode_override + # For single-digit TOTP values (from multi-field TOTP inputs), force fallback mode + # so that we use the exact digit value instead of having AI generate a new one + if value and len(value) == 1 and value.isdigit() and ai == "proactive": + ai = "fallback" + # format 
the text with the actual value of the parameter if it's a secret when running a workflow if ai == "fallback": error_to_raise = None