Fix multi-field TOTP support in cached script execution (#4537)

This commit is contained in:
pedrohsdb
2026-01-26 17:12:25 -08:00
committed by GitHub
parent b92447df35
commit 16945e117f
3 changed files with 299 additions and 9 deletions

View File

@@ -201,6 +201,10 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
"""
Determine whether an input/select action should be forced into proactive mode.
Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive.
NOTE: Multi-field TOTP sequences do NOT require proactive mode because we use
get_totp_digit() to provide the exact digit value. Using proactive mode would
cause the AI to override our value with its own generated one.
"""
if act.get("has_mini_agent", False):
return True
@@ -211,12 +215,117 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
# ):
# return True
if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
return True
# Multi-field TOTP sequences should NOT use proactive mode - we provide the
# exact digit via get_totp_digit() and want that value used directly
# if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
# return True
return False
def _annotate_multi_field_totp_sequence(actions: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Detect and annotate multi-field TOTP sequences in the action list.
Multi-field TOTP is when a 6-digit code needs to be split across 6 individual input fields.
This function identifies such sequences and adds totp_timing_info with the action_index
so that each field gets the correct digit (e.g., totp_code[0], totp_code[1], etc.).
Args:
actions: List of actions to analyze and annotate
Returns:
The same actions list with totp_timing_info added to multi-field TOTP actions
"""
if len(actions) < 4:
return actions
# Identify consecutive runs of single-digit TOTP inputs
# A multi-field TOTP sequence is 4+ consecutive INPUT_TEXT actions with single-digit text
# and the same field_name (typically 'totp_code')
consecutive_start = None
consecutive_count = 0
totp_field_name = None
for idx, act in enumerate(actions):
is_single_digit_totp = (
act.get("action_type") == ActionType.INPUT_TEXT
and act.get("field_name")
and act.get("text")
and len(str(act.get("text", ""))) == 1
and str(act.get("text", "")).isdigit()
)
if is_single_digit_totp:
current_field = act.get("field_name")
if consecutive_start is None:
# Start a new sequence
consecutive_start = idx
totp_field_name = current_field
consecutive_count = 1
elif current_field == totp_field_name:
# Same field, continue the sequence
consecutive_count += 1
else:
# Different field - finalize current sequence if valid, then start new one
if consecutive_count >= 4:
for seq_idx in range(consecutive_count):
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
"is_totp_sequence": True,
"action_index": seq_idx,
"total_digits": consecutive_count,
"field_name": totp_field_name,
}
LOG.debug(
"Annotated multi-field TOTP sequence (field change)",
start_idx=consecutive_start,
count=consecutive_count,
field_name=totp_field_name,
)
# Start new sequence with different field
consecutive_start = idx
totp_field_name = current_field
consecutive_count = 1
else:
# End of consecutive sequence - check if it was a multi-field TOTP
if consecutive_count >= 4 and consecutive_start is not None:
# Annotate all actions in this sequence
for seq_idx in range(consecutive_count):
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
"is_totp_sequence": True,
"action_index": seq_idx,
"total_digits": consecutive_count,
"field_name": totp_field_name,
}
LOG.debug(
"Annotated multi-field TOTP sequence for script generation",
start_idx=consecutive_start,
count=consecutive_count,
field_name=totp_field_name,
)
consecutive_start = None
consecutive_count = 0
totp_field_name = None
# Handle sequence at end of actions list
if consecutive_count >= 4 and consecutive_start is not None:
for seq_idx in range(consecutive_count):
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
"is_totp_sequence": True,
"action_index": seq_idx,
"total_digits": consecutive_count,
"field_name": totp_field_name,
}
LOG.debug(
"Annotated multi-field TOTP sequence for script generation (at end)",
start_idx=consecutive_start,
count=consecutive_count,
field_name=totp_field_name,
)
return actions
def _safe_name(label: str) -> str:
s = "".join(c if c.isalnum() else "_" for c in label).lower()
if not s or s[0].isdigit() or keyword.iskeyword(s):
@@ -389,13 +498,32 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
elif method in ["type", "fill"]:
# Use context.parameters if field_name is available, otherwise fallback to direct value
if act.get("field_name"):
text_value = cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
)
# Check if this is a multi-field TOTP sequence that needs digit indexing
totp_info = act.get("totp_timing_info") or {}
if totp_info.get("is_totp_sequence") and "action_index" in totp_info:
# Generate: await page.get_totp_digit(context, 'field_name', digit_index)
# This method properly resolves the TOTP code from credentials and returns the specific digit
text_value = cst.Await(
expression=cst.Call(
func=cst.Attribute(
value=cst.Name("page"),
attr=cst.Name("get_totp_digit"),
),
args=[
cst.Arg(value=cst.Name("context")),
cst.Arg(value=_value(act["field_name"])),
cst.Arg(value=_value(totp_info["action_index"])),
],
)
)
else:
text_value = cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
)
else:
text_value = _value(act["text"])
@@ -644,6 +772,9 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
body_stmts: list[cst.BaseStatement] = []
# Detect and annotate multi-field TOTP sequences so each fill gets the correct digit index
actions = _annotate_multi_field_totp_sequence(actions)
if block.get("url"):
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))

View File

@@ -5,7 +5,9 @@ import os
from pathlib import Path
from typing import Any, Callable
import pyotp
import structlog
from cachetools import TTLCache
from playwright.async_api import Page
from skyvern.config import settings
@@ -591,6 +593,124 @@ class ScriptSkyvernPage(SkyvernPage):
return value
# Class-level cache for TOTP codes to ensure all digits in a sequence use the same code.
# Key: (workflow_run_id, credential_key), Value: totp_code
# Entries are held in a TTLCache with ttl=30 seconds (chosen to match the TOTP
# rotation period) and maxsize=100 entries to prevent unbounded memory growth.
_totp_sequence_cache: TTLCache[tuple[str, str], str] = TTLCache(maxsize=100, ttl=30)
async def get_totp_digit(
    self,
    context: Any,
    field_name: str,
    digit_index: int,
    totp_identifier: str | None = None,
    totp_url: str | None = None,
) -> str:
    """
    Get a specific digit from a TOTP code for multi-field TOTP inputs.

    This method is used by generated scripts for multi-field TOTP where each
    input field needs a single digit. It resolves the full TOTP code and
    returns the character at ``digit_index``.

    Resolution order:
      1. If ``context.parameters[field_name]`` is empty and a workflow run is
         active, the workflow run context is scanned for a ``cred_*`` value that
         carries a ``totp`` secret. On this path, ``digit_index == 0`` discards
         any cached code and generates (and caches) a fresh one, while
         ``digit_index > 0`` reuses the cached code when one is present. This
         keeps every digit of one sequence on the same code even if filling
         crosses a TOTP rotation boundary.
      2. Otherwise, or if step 1 produced nothing, the raw parameter value is
         resolved through ``self.get_actual_value``.

    Args:
        context: The run context containing parameters
        field_name: The parameter name containing the TOTP code or credential reference
        digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
        totp_identifier: Optional TOTP identifier, forwarded to ``get_actual_value``
        totp_url: Optional TOTP verification URL, forwarded to ``get_actual_value``

    Returns:
        The single digit at the specified index, or "" when the index falls
        outside the resolved code (including negative indexes).
    """
    skyvern_ctx = skyvern_context.ensure_context()
    workflow_run_id = skyvern_ctx.workflow_run_id if skyvern_ctx else None

    LOG.info(
        "get_totp_digit called",
        field_name=field_name,
        digit_index=digit_index,
        workflow_run_id=workflow_run_id,
    )

    # Get the raw parameter value (may be credential reference like BW_TOTP)
    raw_value = context.parameters.get(field_name, "")

    totp_code = ""
    # Only fall back to credential lookup when the parameter itself is empty.
    # workflow_run_id can only be set when skyvern_ctx is, so one check covers both.
    if not raw_value and workflow_run_id:
        totp_code = self._credential_totp_for_sequence(workflow_run_id, field_name, digit_index)

    # If we still don't have a TOTP code, try resolving via get_actual_value
    if not totp_code:
        totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)

    # Reject negative indexes explicitly: they would otherwise wrap around to the
    # end of the code (or raise IndexError) instead of reporting out-of-range.
    if 0 <= digit_index < len(totp_code):
        return totp_code[digit_index]

    LOG.warning(
        "TOTP digit index out of range",
        field_name=field_name,
        digit_index=digit_index,
        totp_code_length=len(totp_code),
    )
    return ""


def _credential_totp_for_sequence(self, workflow_run_id: str, field_name: str, digit_index: int) -> str:
    """Return a cached or freshly generated TOTP code from the first usable credential, or ""."""
    workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
    if not workflow_run_context:
        return ""

    # Look for credential parameters in the workflow run context values
    for key, value in workflow_run_context.values.items():
        if not (key.startswith("cred_") and isinstance(value, dict) and "totp" in value):
            continue

        cache_key = (workflow_run_id, key)
        if digit_index == 0:
            # First digit of a new sequence: drop any stale code, then regenerate below.
            self._totp_sequence_cache.pop(cache_key, None)
        elif cache_key in self._totp_sequence_cache:
            # Later digits reuse the code generated for the first digit.
            cached_code = self._totp_sequence_cache[cache_key]
            LOG.info(
                "Using cached TOTP code for sequence",
                field_name=field_name,
                credential_key=key,
                digit_index=digit_index,
                totp_code_length=len(cached_code),
            )
            return cached_code

        # Generate new TOTP code (digit_index==0 or cache miss)
        totp_secret_id = value.get("totp")
        if not totp_secret_id:
            continue
        totp_secret_key = workflow_run_context.totp_secret_value_key(totp_secret_id)
        totp_secret = workflow_run_context.get_original_secret_value_or_none(totp_secret_key)
        if not totp_secret:
            continue

        try:
            totp_code = pyotp.TOTP(totp_secret).now()
        except Exception as e:
            # Best effort: a bad secret on one credential must not abort the scan.
            LOG.warning(
                "Failed to generate TOTP code",
                credential_key=key,
                error=str(e),
            )
            continue

        # Cache the code for subsequent digit requests in this sequence
        self._totp_sequence_cache[cache_key] = totp_code
        LOG.info(
            "Generated fresh TOTP and cached for sequence",
            field_name=field_name,
            credential_key=key,
            digit_index=digit_index,
            totp_code_length=len(totp_code),
        )
        return totp_code

    return ""
async def goto(self, url: str, **kwargs: Any) -> None:
url = render_template(url)
url = prepend_scheme_and_validate_url(url)

View File

@@ -123,6 +123,40 @@ class SkyvernPage(Page):
) -> str:
return value
async def get_totp_digit(
self,
context: Any,
field_name: str,
digit_index: int,
totp_identifier: str | None = None,
totp_url: str | None = None,
) -> str:
"""
Get a specific digit from a TOTP code for multi-field TOTP inputs.
This method is used by generated scripts for multi-field TOTP where each
input field needs a single digit. It resolves the full TOTP code from
the credential and returns the specific digit.
Args:
context: The run context containing parameters
field_name: The parameter name containing the TOTP code or credential reference
digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
totp_identifier: Optional TOTP identifier for polling
totp_url: Optional TOTP verification URL
Returns:
The single digit at the specified index
"""
# Get the raw parameter value (may be credential reference like BW_TOTP)
raw_value = context.parameters.get(field_name, "")
# Resolve the actual TOTP code (this handles credential generation)
totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)
# Return the specific digit
if digit_index < len(totp_code):
return totp_code[digit_index]
return ""
######### Public Interfaces #########
@overload
@@ -409,6 +443,11 @@ class SkyvernPage(Page):
if context and context.ai_mode_override:
ai = context.ai_mode_override
# For single-digit TOTP values (from multi-field TOTP inputs), force fallback mode
# so that we use the exact digit value instead of having AI generate a new one
if value and len(value) == 1 and value.isdigit() and ai == "proactive":
ai = "fallback"
# format the text with the actual value of the parameter if it's a secret when running a workflow
if ai == "fallback":
error_to_raise = None