Fix multi-field TOTP support in cached script execution (#4537)

This commit is contained in:
pedrohsdb
2026-01-26 17:12:25 -08:00
committed by GitHub
parent b92447df35
commit 16945e117f
3 changed files with 299 additions and 9 deletions

View File

@@ -201,6 +201,10 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
""" """
Determine whether an input/select action should be forced into proactive mode. Determine whether an input/select action should be forced into proactive mode.
Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive. Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive.
NOTE: Multi-field TOTP sequences do NOT require proactive mode because we use
get_totp_digit() to provide the exact digit value. Using proactive mode would
cause the AI to override our value with its own generated one.
""" """
if act.get("has_mini_agent", False): if act.get("has_mini_agent", False):
return True return True
@@ -211,12 +215,117 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
# ): # ):
# return True # return True
if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"): # Multi-field TOTP sequences should NOT use proactive mode - we provide the
return True # exact digit via get_totp_digit() and want that value used directly
# if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
# return True
return False return False
def _annotate_multi_field_totp_sequence(actions: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Detect and annotate multi-field TOTP sequences in the action list.
Multi-field TOTP is when a 6-digit code needs to be split across 6 individual input fields.
This function identifies such sequences and adds totp_timing_info with the action_index
so that each field gets the correct digit (e.g., totp_code[0], totp_code[1], etc.).
Args:
actions: List of actions to analyze and annotate
Returns:
The same actions list with totp_timing_info added to multi-field TOTP actions
"""
if len(actions) < 4:
return actions
# Identify consecutive runs of single-digit TOTP inputs
# A multi-field TOTP sequence is 4+ consecutive INPUT_TEXT actions with single-digit text
# and the same field_name (typically 'totp_code')
consecutive_start = None
consecutive_count = 0
totp_field_name = None
for idx, act in enumerate(actions):
is_single_digit_totp = (
act.get("action_type") == ActionType.INPUT_TEXT
and act.get("field_name")
and act.get("text")
and len(str(act.get("text", ""))) == 1
and str(act.get("text", "")).isdigit()
)
if is_single_digit_totp:
current_field = act.get("field_name")
if consecutive_start is None:
# Start a new sequence
consecutive_start = idx
totp_field_name = current_field
consecutive_count = 1
elif current_field == totp_field_name:
# Same field, continue the sequence
consecutive_count += 1
else:
# Different field - finalize current sequence if valid, then start new one
if consecutive_count >= 4:
for seq_idx in range(consecutive_count):
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
"is_totp_sequence": True,
"action_index": seq_idx,
"total_digits": consecutive_count,
"field_name": totp_field_name,
}
LOG.debug(
"Annotated multi-field TOTP sequence (field change)",
start_idx=consecutive_start,
count=consecutive_count,
field_name=totp_field_name,
)
# Start new sequence with different field
consecutive_start = idx
totp_field_name = current_field
consecutive_count = 1
else:
# End of consecutive sequence - check if it was a multi-field TOTP
if consecutive_count >= 4 and consecutive_start is not None:
# Annotate all actions in this sequence
for seq_idx in range(consecutive_count):
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
"is_totp_sequence": True,
"action_index": seq_idx,
"total_digits": consecutive_count,
"field_name": totp_field_name,
}
LOG.debug(
"Annotated multi-field TOTP sequence for script generation",
start_idx=consecutive_start,
count=consecutive_count,
field_name=totp_field_name,
)
consecutive_start = None
consecutive_count = 0
totp_field_name = None
# Handle sequence at end of actions list
if consecutive_count >= 4 and consecutive_start is not None:
for seq_idx in range(consecutive_count):
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
"is_totp_sequence": True,
"action_index": seq_idx,
"total_digits": consecutive_count,
"field_name": totp_field_name,
}
LOG.debug(
"Annotated multi-field TOTP sequence for script generation (at end)",
start_idx=consecutive_start,
count=consecutive_count,
field_name=totp_field_name,
)
return actions
def _safe_name(label: str) -> str: def _safe_name(label: str) -> str:
s = "".join(c if c.isalnum() else "_" for c in label).lower() s = "".join(c if c.isalnum() else "_" for c in label).lower()
if not s or s[0].isdigit() or keyword.iskeyword(s): if not s or s[0].isdigit() or keyword.iskeyword(s):
@@ -389,13 +498,32 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
elif method in ["type", "fill"]: elif method in ["type", "fill"]:
# Use context.parameters if field_name is available, otherwise fallback to direct value # Use context.parameters if field_name is available, otherwise fallback to direct value
if act.get("field_name"): if act.get("field_name"):
text_value = cst.Subscript( # Check if this is a multi-field TOTP sequence that needs digit indexing
value=cst.Attribute( totp_info = act.get("totp_timing_info") or {}
value=cst.Name("context"), if totp_info.get("is_totp_sequence") and "action_index" in totp_info:
attr=cst.Name("parameters"), # Generate: await page.get_totp_digit(context, 'field_name', digit_index)
), # This method properly resolves the TOTP code from credentials and returns the specific digit
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))], text_value = cst.Await(
) expression=cst.Call(
func=cst.Attribute(
value=cst.Name("page"),
attr=cst.Name("get_totp_digit"),
),
args=[
cst.Arg(value=cst.Name("context")),
cst.Arg(value=_value(act["field_name"])),
cst.Arg(value=_value(totp_info["action_index"])),
],
)
)
else:
text_value = cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
)
else: else:
text_value = _value(act["text"]) text_value = _value(act["text"])
@@ -644,6 +772,9 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}" cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
body_stmts: list[cst.BaseStatement] = [] body_stmts: list[cst.BaseStatement] = []
# Detect and annotate multi-field TOTP sequences so each fill gets the correct digit index
actions = _annotate_multi_field_totp_sequence(actions)
if block.get("url"): if block.get("url"):
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})")) body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))

View File

@@ -5,7 +5,9 @@ import os
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable
import pyotp
import structlog import structlog
from cachetools import TTLCache
from playwright.async_api import Page from playwright.async_api import Page
from skyvern.config import settings from skyvern.config import settings
@@ -591,6 +593,124 @@ class ScriptSkyvernPage(SkyvernPage):
return value return value
# Class-level cache for TOTP codes so that all digits in a multi-field sequence can be
# taken from the same code.
# Key: (workflow_run_id, credential_key), Value: the full TOTP code string.
# Entries use a 30-second TTL, the same length as the usual TOTP rotation period.
# NOTE(review): the TTL is presumably measured from insertion, so it matches the period
# length but is not aligned to the rotation boundary - confirm against cachetools.
# At most 100 entries are kept to prevent unbounded memory growth.
_totp_sequence_cache: TTLCache[tuple[str, str], str] = TTLCache(maxsize=100, ttl=30)
async def get_totp_digit(
    self,
    context: Any,
    field_name: str,
    digit_index: int,
    totp_identifier: str | None = None,
    totp_url: str | None = None,
) -> str:
    """
    Get a specific digit from a TOTP code for multi-field TOTP inputs.

    This method is used by generated scripts for multi-field TOTP where each
    input field needs a single digit. It resolves the full TOTP code and
    returns the character at ``digit_index``.

    Resolution order:
      1. If ``context.parameters[field_name]`` is empty and a workflow run context
         exists, the first ``cred_*`` value that has a ``totp`` entry is used to
         generate a code with pyotp. On this path a call with ``digit_index == 0``
         drops any cached code and generates a fresh one; calls with
         ``digit_index > 0`` reuse the cached code when it is still present, so
         the digits of one sequence come from the same code.
      2. Otherwise, or if step 1 produced nothing, the raw parameter value is
         resolved through ``get_actual_value``.
         NOTE(review): this path does not go through the sequence cache, so it
         relies on ``get_actual_value`` returning a stable code across calls.

    Args:
        context: The run context containing parameters
        field_name: The parameter name containing the TOTP code or credential reference
        digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
        totp_identifier: Optional TOTP identifier for polling
        totp_url: Optional TOTP verification URL

    Returns:
        The single digit at the specified index, or "" when the index is negative
        or beyond the end of the resolved code.
    """
    totp_code = ""
    skyvern_ctx = skyvern_context.ensure_context()
    workflow_run_id = skyvern_ctx.workflow_run_id if skyvern_ctx else None

    LOG.info(
        "get_totp_digit called",
        field_name=field_name,
        digit_index=digit_index,
        workflow_run_id=workflow_run_id,
    )

    # Get the raw parameter value (may be credential reference like BW_TOTP)
    raw_value = context.parameters.get(field_name, "")

    # Only consult workflow credentials when the direct parameter is empty.
    workflow_run_context = None
    if not raw_value and skyvern_ctx and workflow_run_id:
        workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)

    if workflow_run_context:
        for key, value in workflow_run_context.values.items():
            if not (key.startswith("cred_") and isinstance(value, dict) and "totp" in value):
                continue

            cache_key = (workflow_run_id, key)

            if digit_index == 0:
                # A new sequence starts: discard whatever an earlier sequence cached.
                self._totp_sequence_cache.pop(cache_key, None)
            else:
                # Single lookup instead of "in" followed by indexing, so an entry
                # expiring between the two steps cannot surface as a KeyError.
                cached_code = self._totp_sequence_cache.get(cache_key)
                if cached_code is not None:
                    totp_code = cached_code
                    LOG.info(
                        "Using cached TOTP code for sequence",
                        field_name=field_name,
                        credential_key=key,
                        digit_index=digit_index,
                        totp_code_length=len(totp_code),
                    )
                    break

            # Generate new TOTP code (digit_index==0 or cache miss)
            totp_secret_id = value.get("totp")
            if not totp_secret_id:
                continue
            totp_secret_key = workflow_run_context.totp_secret_value_key(totp_secret_id)
            totp_secret = workflow_run_context.get_original_secret_value_or_none(totp_secret_key)
            if not totp_secret:
                continue

            try:
                fresh_code = pyotp.TOTP(totp_secret).now()
            except Exception as e:
                # Best effort: a bad secret on one credential should not stop us
                # from trying the remaining credentials or the fallback below.
                LOG.warning(
                    "Failed to generate TOTP code",
                    credential_key=key,
                    error=str(e),
                )
                continue

            totp_code = fresh_code
            # Cache the code for subsequent digit requests in this sequence
            self._totp_sequence_cache[cache_key] = totp_code
            LOG.info(
                "Generated fresh TOTP and cached for sequence",
                field_name=field_name,
                credential_key=key,
                digit_index=digit_index,
                totp_code_length=len(totp_code),
            )
            break

    # If we still don't have a TOTP code, try resolving via get_actual_value
    if not totp_code:
        totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)

    # The lower bound matters: a negative index would otherwise pass the length
    # check and either return a digit counted from the end or raise IndexError
    # on an empty code.
    if 0 <= digit_index < len(totp_code):
        return totp_code[digit_index]

    LOG.warning(
        "TOTP digit index out of range",
        field_name=field_name,
        digit_index=digit_index,
        totp_code_length=len(totp_code),
    )
    return ""
async def goto(self, url: str, **kwargs: Any) -> None: async def goto(self, url: str, **kwargs: Any) -> None:
url = render_template(url) url = render_template(url)
url = prepend_scheme_and_validate_url(url) url = prepend_scheme_and_validate_url(url)

View File

@@ -123,6 +123,40 @@ class SkyvernPage(Page):
) -> str: ) -> str:
return value return value
async def get_totp_digit(
self,
context: Any,
field_name: str,
digit_index: int,
totp_identifier: str | None = None,
totp_url: str | None = None,
) -> str:
"""
Get a specific digit from a TOTP code for multi-field TOTP inputs.
This method is used by generated scripts for multi-field TOTP where each
input field needs a single digit. It resolves the full TOTP code from
the credential and returns the specific digit.
Args:
context: The run context containing parameters
field_name: The parameter name containing the TOTP code or credential reference
digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
totp_identifier: Optional TOTP identifier for polling
totp_url: Optional TOTP verification URL
Returns:
The single digit at the specified index
"""
# Get the raw parameter value (may be credential reference like BW_TOTP)
raw_value = context.parameters.get(field_name, "")
# Resolve the actual TOTP code (this handles credential generation)
totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)
# Return the specific digit
if digit_index < len(totp_code):
return totp_code[digit_index]
return ""
######### Public Interfaces ######### ######### Public Interfaces #########
@overload @overload
@@ -409,6 +443,11 @@ class SkyvernPage(Page):
if context and context.ai_mode_override: if context and context.ai_mode_override:
ai = context.ai_mode_override ai = context.ai_mode_override
# For single-digit TOTP values (from multi-field TOTP inputs), force fallback mode
# so that we use the exact digit value instead of having AI generate a new one
if value and len(value) == 1 and value.isdigit() and ai == "proactive":
ai = "fallback"
# format the text with the actual value of the parameter if it's a secret when running a workflow # format the text with the actual value of the parameter if it's a secret when running a workflow
if ai == "fallback": if ai == "fallback":
error_to_raise = None error_to_raise = None