Fix multi-field TOTP support in cached script execution (#4537)
This commit is contained in:
@@ -201,6 +201,10 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
|
|||||||
"""
|
"""
|
||||||
Determine whether an input/select action should be forced into proactive mode.
|
Determine whether an input/select action should be forced into proactive mode.
|
||||||
Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive.
|
Mirrors runtime logic that treats some inputs as mini-agent flows or TOTP-sensitive.
|
||||||
|
|
||||||
|
NOTE: Multi-field TOTP sequences do NOT require proactive mode because we use
|
||||||
|
get_totp_digit() to provide the exact digit value. Using proactive mode would
|
||||||
|
cause the AI to override our value with its own generated one.
|
||||||
"""
|
"""
|
||||||
if act.get("has_mini_agent", False):
|
if act.get("has_mini_agent", False):
|
||||||
return True
|
return True
|
||||||
@@ -211,12 +215,117 @@ def _requires_mini_agent(act: dict[str, Any]) -> bool:
|
|||||||
# ):
|
# ):
|
||||||
# return True
|
# return True
|
||||||
|
|
||||||
if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
|
# Multi-field TOTP sequences should NOT use proactive mode - we provide the
|
||||||
return True
|
# exact digit via get_totp_digit() and want that value used directly
|
||||||
|
# if act.get("totp_timing_info") and act.get("totp_timing_info", {}).get("is_totp_sequence"):
|
||||||
|
# return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _annotate_multi_field_totp_sequence(actions: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Detect and annotate multi-field TOTP sequences in the action list.
|
||||||
|
|
||||||
|
Multi-field TOTP is when a 6-digit code needs to be split across 6 individual input fields.
|
||||||
|
This function identifies such sequences and adds totp_timing_info with the action_index
|
||||||
|
so that each field gets the correct digit (e.g., totp_code[0], totp_code[1], etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: List of actions to analyze and annotate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same actions list with totp_timing_info added to multi-field TOTP actions
|
||||||
|
"""
|
||||||
|
if len(actions) < 4:
|
||||||
|
return actions
|
||||||
|
|
||||||
|
# Identify consecutive runs of single-digit TOTP inputs
|
||||||
|
# A multi-field TOTP sequence is 4+ consecutive INPUT_TEXT actions with single-digit text
|
||||||
|
# and the same field_name (typically 'totp_code')
|
||||||
|
consecutive_start = None
|
||||||
|
consecutive_count = 0
|
||||||
|
totp_field_name = None
|
||||||
|
|
||||||
|
for idx, act in enumerate(actions):
|
||||||
|
is_single_digit_totp = (
|
||||||
|
act.get("action_type") == ActionType.INPUT_TEXT
|
||||||
|
and act.get("field_name")
|
||||||
|
and act.get("text")
|
||||||
|
and len(str(act.get("text", ""))) == 1
|
||||||
|
and str(act.get("text", "")).isdigit()
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_single_digit_totp:
|
||||||
|
current_field = act.get("field_name")
|
||||||
|
if consecutive_start is None:
|
||||||
|
# Start a new sequence
|
||||||
|
consecutive_start = idx
|
||||||
|
totp_field_name = current_field
|
||||||
|
consecutive_count = 1
|
||||||
|
elif current_field == totp_field_name:
|
||||||
|
# Same field, continue the sequence
|
||||||
|
consecutive_count += 1
|
||||||
|
else:
|
||||||
|
# Different field - finalize current sequence if valid, then start new one
|
||||||
|
if consecutive_count >= 4:
|
||||||
|
for seq_idx in range(consecutive_count):
|
||||||
|
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
|
||||||
|
"is_totp_sequence": True,
|
||||||
|
"action_index": seq_idx,
|
||||||
|
"total_digits": consecutive_count,
|
||||||
|
"field_name": totp_field_name,
|
||||||
|
}
|
||||||
|
LOG.debug(
|
||||||
|
"Annotated multi-field TOTP sequence (field change)",
|
||||||
|
start_idx=consecutive_start,
|
||||||
|
count=consecutive_count,
|
||||||
|
field_name=totp_field_name,
|
||||||
|
)
|
||||||
|
# Start new sequence with different field
|
||||||
|
consecutive_start = idx
|
||||||
|
totp_field_name = current_field
|
||||||
|
consecutive_count = 1
|
||||||
|
else:
|
||||||
|
# End of consecutive sequence - check if it was a multi-field TOTP
|
||||||
|
if consecutive_count >= 4 and consecutive_start is not None:
|
||||||
|
# Annotate all actions in this sequence
|
||||||
|
for seq_idx in range(consecutive_count):
|
||||||
|
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
|
||||||
|
"is_totp_sequence": True,
|
||||||
|
"action_index": seq_idx,
|
||||||
|
"total_digits": consecutive_count,
|
||||||
|
"field_name": totp_field_name,
|
||||||
|
}
|
||||||
|
LOG.debug(
|
||||||
|
"Annotated multi-field TOTP sequence for script generation",
|
||||||
|
start_idx=consecutive_start,
|
||||||
|
count=consecutive_count,
|
||||||
|
field_name=totp_field_name,
|
||||||
|
)
|
||||||
|
consecutive_start = None
|
||||||
|
consecutive_count = 0
|
||||||
|
totp_field_name = None
|
||||||
|
|
||||||
|
# Handle sequence at end of actions list
|
||||||
|
if consecutive_count >= 4 and consecutive_start is not None:
|
||||||
|
for seq_idx in range(consecutive_count):
|
||||||
|
actions[consecutive_start + seq_idx]["totp_timing_info"] = {
|
||||||
|
"is_totp_sequence": True,
|
||||||
|
"action_index": seq_idx,
|
||||||
|
"total_digits": consecutive_count,
|
||||||
|
"field_name": totp_field_name,
|
||||||
|
}
|
||||||
|
LOG.debug(
|
||||||
|
"Annotated multi-field TOTP sequence for script generation (at end)",
|
||||||
|
start_idx=consecutive_start,
|
||||||
|
count=consecutive_count,
|
||||||
|
field_name=totp_field_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
|
||||||
def _safe_name(label: str) -> str:
|
def _safe_name(label: str) -> str:
|
||||||
s = "".join(c if c.isalnum() else "_" for c in label).lower()
|
s = "".join(c if c.isalnum() else "_" for c in label).lower()
|
||||||
if not s or s[0].isdigit() or keyword.iskeyword(s):
|
if not s or s[0].isdigit() or keyword.iskeyword(s):
|
||||||
@@ -389,13 +498,32 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
|
|||||||
elif method in ["type", "fill"]:
|
elif method in ["type", "fill"]:
|
||||||
# Use context.parameters if field_name is available, otherwise fallback to direct value
|
# Use context.parameters if field_name is available, otherwise fallback to direct value
|
||||||
if act.get("field_name"):
|
if act.get("field_name"):
|
||||||
text_value = cst.Subscript(
|
# Check if this is a multi-field TOTP sequence that needs digit indexing
|
||||||
value=cst.Attribute(
|
totp_info = act.get("totp_timing_info") or {}
|
||||||
value=cst.Name("context"),
|
if totp_info.get("is_totp_sequence") and "action_index" in totp_info:
|
||||||
attr=cst.Name("parameters"),
|
# Generate: await page.get_totp_digit(context, 'field_name', digit_index)
|
||||||
),
|
# This method properly resolves the TOTP code from credentials and returns the specific digit
|
||||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
|
text_value = cst.Await(
|
||||||
)
|
expression=cst.Call(
|
||||||
|
func=cst.Attribute(
|
||||||
|
value=cst.Name("page"),
|
||||||
|
attr=cst.Name("get_totp_digit"),
|
||||||
|
),
|
||||||
|
args=[
|
||||||
|
cst.Arg(value=cst.Name("context")),
|
||||||
|
cst.Arg(value=_value(act["field_name"])),
|
||||||
|
cst.Arg(value=_value(totp_info["action_index"])),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_value = cst.Subscript(
|
||||||
|
value=cst.Attribute(
|
||||||
|
value=cst.Name("context"),
|
||||||
|
attr=cst.Name("parameters"),
|
||||||
|
),
|
||||||
|
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
text_value = _value(act["text"])
|
text_value = _value(act["text"])
|
||||||
|
|
||||||
@@ -644,6 +772,9 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
|||||||
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
|
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
|
||||||
body_stmts: list[cst.BaseStatement] = []
|
body_stmts: list[cst.BaseStatement] = []
|
||||||
|
|
||||||
|
# Detect and annotate multi-field TOTP sequences so each fill gets the correct digit index
|
||||||
|
actions = _annotate_multi_field_totp_sequence(actions)
|
||||||
|
|
||||||
if block.get("url"):
|
if block.get("url"):
|
||||||
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
|
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import pyotp
|
||||||
import structlog
|
import structlog
|
||||||
|
from cachetools import TTLCache
|
||||||
from playwright.async_api import Page
|
from playwright.async_api import Page
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
@@ -591,6 +593,124 @@ class ScriptSkyvernPage(SkyvernPage):
|
|||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
# Class-level cache for TOTP codes to ensure all digits in a sequence use the same code
|
||||||
|
# Key: (workflow_run_id, credential_key), Value: totp_code
|
||||||
|
# Uses TTLCache with 30-second expiry (aligned with TOTP rotation period)
|
||||||
|
# and max 100 entries to prevent unbounded memory growth
|
||||||
|
_totp_sequence_cache: TTLCache[tuple[str, str], str] = TTLCache(maxsize=100, ttl=30)
|
||||||
|
|
||||||
|
async def get_totp_digit(
|
||||||
|
self,
|
||||||
|
context: Any,
|
||||||
|
field_name: str,
|
||||||
|
digit_index: int,
|
||||||
|
totp_identifier: str | None = None,
|
||||||
|
totp_url: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get a specific digit from a TOTP code for multi-field TOTP inputs.
|
||||||
|
|
||||||
|
This method is used by generated scripts for multi-field TOTP where each
|
||||||
|
input field needs a single digit. It resolves the full TOTP code from
|
||||||
|
the credential and returns the specific digit.
|
||||||
|
|
||||||
|
IMPORTANT: When digit_index == 0, a fresh TOTP code is generated and cached.
|
||||||
|
For digit_index > 0, the cached code is used. This ensures all 6 digits
|
||||||
|
of a multi-field TOTP use the same code even if filling spans TOTP rotation
|
||||||
|
boundaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The run context containing parameters
|
||||||
|
field_name: The parameter name containing the TOTP code or credential reference
|
||||||
|
digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
|
||||||
|
totp_identifier: Optional TOTP identifier for polling
|
||||||
|
totp_url: Optional TOTP verification URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The single digit at the specified index
|
||||||
|
"""
|
||||||
|
totp_code = ""
|
||||||
|
skyvern_ctx = skyvern_context.ensure_context()
|
||||||
|
workflow_run_id = skyvern_ctx.workflow_run_id if skyvern_ctx else None
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"get_totp_digit called",
|
||||||
|
field_name=field_name,
|
||||||
|
digit_index=digit_index,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the raw parameter value (may be credential reference like BW_TOTP)
|
||||||
|
raw_value = context.parameters.get(field_name, "")
|
||||||
|
|
||||||
|
# If the direct field_name parameter is empty, try to find a credential TOTP
|
||||||
|
# by looking at the workflow run context for credential parameters
|
||||||
|
if not raw_value and skyvern_ctx and workflow_run_id:
|
||||||
|
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||||
|
if workflow_run_context:
|
||||||
|
# Look for credential parameters in the workflow run context values
|
||||||
|
for key, value in workflow_run_context.values.items():
|
||||||
|
if key.startswith("cred_") and isinstance(value, dict) and "totp" in value:
|
||||||
|
cache_key = (workflow_run_id, key)
|
||||||
|
|
||||||
|
# For digit_index == 0, clear any stale cache and generate fresh TOTP
|
||||||
|
# For digit_index > 0, use cached code if available
|
||||||
|
if digit_index == 0:
|
||||||
|
# Clear stale cache for new sequence, fall through to generate
|
||||||
|
if cache_key in self._totp_sequence_cache:
|
||||||
|
del self._totp_sequence_cache[cache_key]
|
||||||
|
elif cache_key in self._totp_sequence_cache:
|
||||||
|
# Use cached value for digit_index > 0
|
||||||
|
totp_code = self._totp_sequence_cache[cache_key]
|
||||||
|
LOG.info(
|
||||||
|
"Using cached TOTP code for sequence",
|
||||||
|
field_name=field_name,
|
||||||
|
credential_key=key,
|
||||||
|
digit_index=digit_index,
|
||||||
|
totp_code_length=len(totp_code),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Generate new TOTP code (digit_index==0 or cache miss)
|
||||||
|
totp_secret_id = value.get("totp")
|
||||||
|
if totp_secret_id:
|
||||||
|
totp_secret_key = workflow_run_context.totp_secret_value_key(totp_secret_id)
|
||||||
|
totp_secret = workflow_run_context.get_original_secret_value_or_none(totp_secret_key)
|
||||||
|
if totp_secret:
|
||||||
|
try:
|
||||||
|
totp_code = pyotp.TOTP(totp_secret).now()
|
||||||
|
# Cache the code for subsequent digit requests in this sequence
|
||||||
|
self._totp_sequence_cache[cache_key] = totp_code
|
||||||
|
LOG.info(
|
||||||
|
"Generated fresh TOTP and cached for sequence",
|
||||||
|
field_name=field_name,
|
||||||
|
credential_key=key,
|
||||||
|
digit_index=digit_index,
|
||||||
|
totp_code_length=len(totp_code),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
LOG.warning(
|
||||||
|
"Failed to generate TOTP code",
|
||||||
|
credential_key=key,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we still don't have a TOTP code, try resolving via get_actual_value
|
||||||
|
if not totp_code:
|
||||||
|
totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)
|
||||||
|
|
||||||
|
# Return the specific digit
|
||||||
|
if digit_index < len(totp_code):
|
||||||
|
return totp_code[digit_index]
|
||||||
|
LOG.warning(
|
||||||
|
"TOTP digit index out of range",
|
||||||
|
field_name=field_name,
|
||||||
|
digit_index=digit_index,
|
||||||
|
totp_code_length=len(totp_code),
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
async def goto(self, url: str, **kwargs: Any) -> None:
|
async def goto(self, url: str, **kwargs: Any) -> None:
|
||||||
url = render_template(url)
|
url = render_template(url)
|
||||||
url = prepend_scheme_and_validate_url(url)
|
url = prepend_scheme_and_validate_url(url)
|
||||||
|
|||||||
@@ -123,6 +123,40 @@ class SkyvernPage(Page):
|
|||||||
) -> str:
|
) -> str:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
async def get_totp_digit(
|
||||||
|
self,
|
||||||
|
context: Any,
|
||||||
|
field_name: str,
|
||||||
|
digit_index: int,
|
||||||
|
totp_identifier: str | None = None,
|
||||||
|
totp_url: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get a specific digit from a TOTP code for multi-field TOTP inputs.
|
||||||
|
|
||||||
|
This method is used by generated scripts for multi-field TOTP where each
|
||||||
|
input field needs a single digit. It resolves the full TOTP code from
|
||||||
|
the credential and returns the specific digit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The run context containing parameters
|
||||||
|
field_name: The parameter name containing the TOTP code or credential reference
|
||||||
|
digit_index: The index of the digit to return (0-5 for a 6-digit TOTP)
|
||||||
|
totp_identifier: Optional TOTP identifier for polling
|
||||||
|
totp_url: Optional TOTP verification URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The single digit at the specified index
|
||||||
|
"""
|
||||||
|
# Get the raw parameter value (may be credential reference like BW_TOTP)
|
||||||
|
raw_value = context.parameters.get(field_name, "")
|
||||||
|
# Resolve the actual TOTP code (this handles credential generation)
|
||||||
|
totp_code = await self.get_actual_value(raw_value, totp_identifier, totp_url)
|
||||||
|
# Return the specific digit
|
||||||
|
if digit_index < len(totp_code):
|
||||||
|
return totp_code[digit_index]
|
||||||
|
return ""
|
||||||
|
|
||||||
######### Public Interfaces #########
|
######### Public Interfaces #########
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -409,6 +443,11 @@ class SkyvernPage(Page):
|
|||||||
if context and context.ai_mode_override:
|
if context and context.ai_mode_override:
|
||||||
ai = context.ai_mode_override
|
ai = context.ai_mode_override
|
||||||
|
|
||||||
|
# For single-digit TOTP values (from multi-field TOTP inputs), force fallback mode
|
||||||
|
# so that we use the exact digit value instead of having AI generate a new one
|
||||||
|
if value and len(value) == 1 and value.isdigit() and ai == "proactive":
|
||||||
|
ai = "fallback"
|
||||||
|
|
||||||
# format the text with the actual value of the parameter if it's a secret when running a workflow
|
# format the text with the actual value of the parameter if it's a secret when running a workflow
|
||||||
if ai == "fallback":
|
if ai == "fallback":
|
||||||
error_to_raise = None
|
error_to_raise = None
|
||||||
|
|||||||
Reference in New Issue
Block a user