Fix multi-field TOTP support in cached script execution (#4537)
This commit is contained in:
@@ -201,6 +201,10 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Determine whether an input/select action should be forced into proactive mode.
|
||||
Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive.
|
||||
|
||||
NOTE: Multi-field TOTP sequences do NOT require proactive mode because we use
|
||||
get_totp_digit() to provide the exact digit value. Using proactive mode would
|
||||
cause the AI to override our value with its own generated one.
|
||||
"""
|
||||
if act.get("has_mini_agent", False):
|
||||
return True
|
||||
@@ -211,12 +215,117 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
|
||||
# ):
|
||||
# return True
|
||||
|
||||
if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
|
||||
return True
|
||||
# Multi-field TOTP sequences should NOT use proactive mode - we provide the
|
||||
# exact digit via get_totp_digit() and want that value used directly
|
||||
# if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
|
||||
# return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _annotate_multi_field_totp_sequence(actions: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detect and annotate multi-field TOTP sequences in the action list.
|
||||
|
||||
Multi-field TOTP is when a 6-digit code needs to be split across 6 individual input fields.
|
||||
This function identifies such sequences and adds totp_timing_info with the action_index
|
||||
so that each field gets the correct digit (e.g., totp_code[0], totp_code[1], etc.).
|
||||
|
||||
Args:
|
||||
actions: List of actions to analyze and annotate
|
||||
|
||||
Returns:
|
||||
The same actions list with totp_timing_info added to multi-field TOTP actions
|
||||
"""
|
||||
if len(actions) < 4:
|
||||
return actions
|
||||
|
||||
# Identify consecutive runs of single-digit TOTP inputs
|
||||
# A multi-field TOTP sequence is 4+ consecutive INPUT_TEXT actions with single-digit text
|
||||
# and the same field_name (typically 'totp_code')
|
||||
consecutive_start = None
|
||||
consecutive_count = 0
|
||||
totp_field_name = None
|
||||
|
||||
for idx, act in enumerate(actions):
|
||||
is_single_digit_totp = (
|
||||
act.get("action_type") == ActionType.INPUT_TEXT
|
||||
and act.get("field_name")
|
||||
and act.get("text")
|
||||
and len(str(act.get("text", ""))) == 1
|
||||
and str(act.get("text", "")).isdigit()
|
||||
)
|
||||
|
||||
if is_single_digit_totp:
|
||||
current_field = act.get("field_name")
|
||||
if consecutive_start is None:
|
||||
# Start a new sequence
|
||||
consecutive_start = idx
|
||||
totp_field_name = current_field
|
||||
consecutive_count = 1
|
||||
elif current_field == totp_field_name:
|
||||
# Same field, continue the sequence
|
||||
consecutive_count += 1
|
||||
else:
|
||||
# Different field - finalize current sequence if valid, then start new one
|
||||
if consecutive_count >= 4:
|
||||
for seq_idx in range(consecutive_count):
|
||||
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
|
||||
"is_totp_sequence": True,
|
||||
"action_index": seq_idx,
|
||||
"total_digits": consecutive_count,
|
||||
"field_name": totp_field_name,
|
||||
}
|
||||
LOG.debug(
|
||||
"Annotated multi-field TOTP sequence (field change)",
|
||||
start_idx=consecutive_start,
|
||||
count=consecutive_count,
|
||||
field_name=totp_field_name,
|
||||
)
|
||||
# Start new sequence with different field
|
||||
consecutive_start = idx
|
||||
totp_field_name = current_field
|
||||
consecutive_count = 1
|
||||
else:
|
||||
# End of consecutive sequence - check if it was a multi-field TOTP
|
||||
if consecutive_count >= 4 and consecutive_start is not None:
|
||||
# Annotate all actions in this sequence
|
||||
for seq_idx in range(consecutive_count):
|
||||
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
|
||||
"is_totp_sequence": True,
|
||||
"action_index": seq_idx,
|
||||
"total_digits": consecutive_count,
|
||||
"field_name": totp_field_name,
|
||||
}
|
||||
LOG.debug(
|
||||
"Annotated multi-field TOTP sequence for script generation",
|
||||
start_idx=consecutive_start,
|
||||
count=consecutive_count,
|
||||
field_name=totp_field_name,
|
||||
)
|
||||
consecutive_start = None
|
||||
consecutive_count = 0
|
||||
totp_field_name = None
|
||||
|
||||
# Handle sequence at end of actions list
|
||||
if consecutive_count >= 4 and consecutive_start is not None:
|
||||
for seq_idx in range(consecutive_count):
|
||||
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
|
||||
"is_totp_sequence": True,
|
||||
"action_index": seq_idx,
|
||||
"total_digits": consecutive_count,
|
||||
"field_name": totp_field_name,
|
||||
}
|
||||
LOG.debug(
|
||||
"Annotated multi-field TOTP sequence for script generation (at end)",
|
||||
start_idx=consecutive_start,
|
||||
count=consecutive_count,
|
||||
field_name=totp_field_name,
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def _safe_name(label: str) -> str:
|
||||
s = "".join(c if c.isalnum() else "_" for c in label).lower()
|
||||
if not s or s[0].isdigit() or keyword.iskeyword(s):
|
||||
@@ -389,13 +498,32 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
|
||||
elif method in ["type", "fill"]:
|
||||
# Use context.parameters if field_name is available, otherwise fallback to direct value
|
||||
if act.get("field_name"):
|
||||
text_value = cst.Subscript(
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("parameters"),
|
||||
),
|
||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
|
||||
)
|
||||
# Check if this is a multi-field TOTP sequence that needs digit indexing
|
||||
totp_info = act.get("totp_timing_info") or {}
|
||||
if totp_info.get("is_totp_sequence") and "action_index" in totp_info:
|
||||
# Generate: await page.get_totp_digit(context, 'field_name', digit_index)
|
||||
# This method properly resolves the TOTP code from credentials and returns the specific digit
|
||||
text_value = cst.Await(
|
||||
expression=cst.Call(
|
||||
func=cst.Attribute(
|
||||
value=cst.Name("page"),
|
||||
attr=cst.Name("get_totp_digit"),
|
||||
),
|
||||
args=[
|
||||
cst.Arg(value=cst.Name("context")),
|
||||
cst.Arg(value=_value(act["field_name"])),
|
||||
cst.Arg(value=_value(totp_info["action_index"])),
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
text_value = cst.Subscript(
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("parameters"),
|
||||
),
|
||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
|
||||
)
|
||||
else:
|
||||
text_value = _value(act["text"])
|
||||
|
||||
@@ -644,6 +772,9 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
|
||||
body_stmts: list[cst.BaseStatement] = []
|
||||
|
||||
# Detect and annotate multi-field TOTP sequences so each fill gets the correct digit index
|
||||
actions = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
if block.get("url"):
|
||||
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
import pyotp
|
||||
import structlog
|
||||
from cachetools import TTLCache
|
||||
from playwright.async_api import Page
|
||||
|
||||
from skyvern.config import settings
|
||||
@@ -591,6 +593,124 @@ class ScriptSkyvernPage(SkyvernPage):
|
||||
|
||||
return value
|
||||
|
||||
# Class-level cache for TOTP codes to ensure all digits in a sequence use the same code
|
||||
# Key: (workflow_run_id, credential_key), Value: totp_code
|
||||
# Uses TTLCache with 30-second expiry (aligned with TOTP rotation period)
|
||||
# and max 100 entries to prevent unbounded memory growth
|
||||
_totp_sequence_cache: TTLCache[tuple[str, str], str] = TTLCache(maxsize=100, ttl=30)
|
||||
|
||||
async def get_totp_digit(
|
||||
self,
|
||||
context: Any,
|
||||
field_name: str,
|
||||
digit_index: int,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get a specific digit from a TOTP code for multi-field TOTP inputs.
|
||||
|
||||
This method is used by generated scripts for multi-field TOTP where each
|
||||
input field needs a single digit. It resolves the full TOTP code from
|
||||
the credential and returns the specific digit.
|
||||
|
||||
IMPORTANT: When digit_index == 0, a fresh TOTP code is generated and cached.
|
||||
For digit_index > 0, the cached code is used. This ensures all 6 digits
|
||||
of a multi-field TOTP use the same code even if filling spans TOTP rotation
|
||||
boundaries.
|
||||
|
||||
Args:
|
||||
context: The run context containing parameters
|
||||
field_name: The parameter name containing the TOTP code or credential reference
|
||||
digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
|
||||
totp_identifier: Optional TOTP identifier for polling
|
||||
totp_url: Optional TOTP verification URL
|
||||
|
||||
Returns:
|
||||
The single digit at the specified index
|
||||
"""
|
||||
totp_code = ""
|
||||
skyvern_ctx = skyvern_context.ensure_context()
|
||||
workflow_run_id = skyvern_ctx.workflow_run_id if skyvern_ctx else None
|
||||
|
||||
LOG.info(
|
||||
"get_totp_digit called",
|
||||
field_name=field_name,
|
||||
digit_index=digit_index,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
# Get the raw parameter value (may be credential reference like BW_TOTP)
|
||||
raw_value = context.parameters.get(field_name, "")
|
||||
|
||||
# If the direct field_name parameter is empty, try to find a credential TOTP
|
||||
# by looking at the workflow run context for credential parameters
|
||||
if not raw_value and skyvern_ctx and workflow_run_id:
|
||||
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
if workflow_run_context:
|
||||
# Look for credential parameters in the workflow run context values
|
||||
for key, value in workflow_run_context.values.items():
|
||||
if key.startswith("cred_") and isinstance(value, dict) and "totp" in value:
|
||||
cache_key = (workflow_run_id, key)
|
||||
|
||||
# For digit_index == 0, clear any stale cache and generate fresh TOTP
|
||||
# For digit_index > 0, use cached code if available
|
||||
if digit_index == 0:
|
||||
# Clear stale cache for new sequence, fall through to generate
|
||||
if cache_key in self._totp_sequence_cache:
|
||||
del self._totp_sequence_cache[cache_key]
|
||||
elif cache_key in self._totp_sequence_cache:
|
||||
# Use cached value for digit_index > 0
|
||||
totp_code = self._totp_sequence_cache[cache_key]
|
||||
LOG.info(
|
||||
"Using cached TOTP code for sequence",
|
||||
field_name=field_name,
|
||||
credential_key=key,
|
||||
digit_index=digit_index,
|
||||
totp_code_length=len(totp_code),
|
||||
)
|
||||
break
|
||||
|
||||
# Generate new TOTP code (digit_index==0 or cache miss)
|
||||
totp_secret_id = value.get("totp")
|
||||
if totp_secret_id:
|
||||
totp_secret_key = workflow_run_context.totp_secret_value_key(totp_secret_id)
|
||||
totp_secret = workflow_run_context.get_original_secret_value_or_none(totp_secret_key)
|
||||
if totp_secret:
|
||||
try:
|
||||
totp_code = pyotp.TOTP(totp_secret).now()
|
||||
# Cache the code for subsequent digit requests in this sequence
|
||||
self._totp_sequence_cache[cache_key] = totp_code
|
||||
LOG.info(
|
||||
"Generated fresh TOTP and cached for sequence",
|
||||
field_name=field_name,
|
||||
credential_key=key,
|
||||
digit_index=digit_index,
|
||||
totp_code_length=len(totp_code),
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
LOG.warning(
|
||||
"Failed to generate TOTP code",
|
||||
credential_key=key,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# If we still don't have a TOTP code, try resolving via get_actual_value
|
||||
if not totp_code:
|
||||
totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)
|
||||
|
||||
# Return the specific digit
|
||||
if digit_index < len(totp_code):
|
||||
return totp_code[digit_index]
|
||||
LOG.warning(
|
||||
"TOTP digit index out of range",
|
||||
field_name=field_name,
|
||||
digit_index=digit_index,
|
||||
totp_code_length=len(totp_code),
|
||||
)
|
||||
return ""
|
||||
|
||||
async def goto(self, url: str, **kwargs: Any) -> None:
|
||||
url = render_template(url)
|
||||
url = prepend_scheme_and_validate_url(url)
|
||||
|
||||
@@ -123,6 +123,40 @@ class SkyvernPage(Page):
|
||||
) -> str:
|
||||
return value
|
||||
|
||||
async def get_totp_digit(
|
||||
self,
|
||||
context: Any,
|
||||
field_name: str,
|
||||
digit_index: int,
|
||||
totp_identifier: str | None = None,
|
||||
totp_url: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get a specific digit from a TOTP code for multi-field TOTP inputs.
|
||||
|
||||
This method is used by generated scripts for multi-field TOTP where each
|
||||
input field needs a single digit. It resolves the full TOTP code from
|
||||
the credential and returns the specific digit.
|
||||
|
||||
Args:
|
||||
context: The run context containing parameters
|
||||
field_name: The parameter name containing the TOTP code or credential reference
|
||||
digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
|
||||
totp_identifier: Optional TOTP identifier for polling
|
||||
totp_url: Optional TOTP verification URL
|
||||
|
||||
Returns:
|
||||
The single digit at the specified index
|
||||
"""
|
||||
# Get the raw parameter value (may be credential reference like BW_TOTP)
|
||||
raw_value = context.parameters.get(field_name, "")
|
||||
# Resolve the actual TOTP code (this handles credential generation)
|
||||
totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)
|
||||
# Return the specific digit
|
||||
if digit_index < len(totp_code):
|
||||
return totp_code[digit_index]
|
||||
return ""
|
||||
|
||||
######### Public Interfaces #########
|
||||
|
||||
@overload
|
||||
@@ -409,6 +443,11 @@ class SkyvernPage(Page):
|
||||
if context and context.ai_mode_override:
|
||||
ai = context.ai_mode_override
|
||||
|
||||
# For single-digit TOTP values (from multi-field TOTP inputs), force fallback mode
|
||||
# so that we use the exact digit value instead of having AI generate a new one
|
||||
if value and len(value) == 1 and value.isdigit() and ai == "proactive":
|
||||
ai = "fallback"
|
||||
|
||||
# format the text with the actual value of the parameter if it's a secret when running a workflow
|
||||
if ai == "fallback":
|
||||
error_to_raise = None
|
||||
|
||||
Reference in New Issue
Block a user