add totp_code_required info to the InputTextAction json (#3415)

This commit is contained in:
Shuchang Zheng
2025-09-11 19:05:25 -07:00
committed by GitHub
parent 0e2aecc75d
commit f1aa653b82
8 changed files with 227 additions and 120 deletions

View File

@@ -269,7 +269,7 @@ def _make_decorator(block_label: str, block: dict[str, Any]) -> cst.Decorator:
)
def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.BaseStatement:
def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output: bool = False) -> cst.BaseStatement:
"""
Turn one Action dict into:
@@ -327,6 +327,29 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
),
)
)
if act.get("totp_code_required"):
if task.get("totp_identifier"):
args.append(
cst.Arg(
keyword=cst.Name("totp_identifier"),
value=cst.Name(task.get("totp_identifier")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if task.get("totp_url"):
args.append(
cst.Arg(
keyword=cst.Name("totp_url"),
value=cst.Name(task.get("totp_verification_url")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
elif method == "select_option":
args.append(
cst.Arg(
@@ -431,7 +454,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
# For extraction blocks, assign extract action results to output variable
assign_to_output = is_extraction_block and act["action_type"] == "extract"
body_stmts.append(_action_to_stmt(act, assign_to_output=assign_to_output))
body_stmts.append(_action_to_stmt(act, block, assign_to_output=assign_to_output))
# For extraction blocks, add return output statement if we have actions
if is_extraction_block and any(

View File

@@ -12,6 +12,8 @@ import structlog
from playwright.async_api import Page
from skyvern.config import settings
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
from skyvern.core.totp import poll_verification_code
from skyvern.exceptions import WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
@@ -351,8 +353,19 @@ class SkyvernPage:
intention: str | None = None,
data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
totp_identifier: str | None = None,
totp_url: str | None = None,
) -> str:
return await self._input_text(xpath, value, ai_infer, intention, data, timeout)
return await self._input_text(
xpath=xpath,
value=value,
ai_infer=ai_infer,
intention=intention,
data=data,
timeout=timeout,
totp_identifier=totp_identifier,
totp_url=totp_url,
)
@action_wrap(ActionType.INPUT_TEXT)
async def type(
@@ -363,8 +376,19 @@ class SkyvernPage:
intention: str | None = None,
data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
totp_identifier: str | None = None,
totp_url: str | None = None,
) -> str:
return await self._input_text(xpath, value, ai_infer, intention, data, timeout)
return await self._input_text(
xpath=xpath,
value=value,
ai_infer=ai_infer,
intention=intention,
data=data,
timeout=timeout,
totp_identifier=totp_identifier,
totp_url=totp_url,
)
async def _input_text(
self,
@@ -374,6 +398,8 @@ class SkyvernPage:
intention: str | None = None,
data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
totp_identifier: str | None = None,
totp_url: str | None = None,
) -> str:
"""Input text into an element identified by ``xpath``.
@@ -395,6 +421,24 @@ class SkyvernPage:
# Build the element tree of the current page for the prompt
# clean up empty data values
data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
if (totp_identifier or totp_url) and context and context.organization_id and context.task_id:
verification_code = await poll_verification_code(
organization_id=context.organization_id,
task_id=context.task_id,
workflow_run_id=context.workflow_run_id,
totp_identifier=totp_identifier,
totp_verification_url=totp_url,
)
if verification_code:
if isinstance(data, dict) and SPECIAL_FIELD_VERIFICATION_CODE not in data:
data[SPECIAL_FIELD_VERIFICATION_CODE] = verification_code
elif isinstance(data, str) and SPECIAL_FIELD_VERIFICATION_CODE not in data:
data = f"{data}\n" + str({SPECIAL_FIELD_VERIFICATION_CODE: verification_code})
elif isinstance(data, list):
data.append({SPECIAL_FIELD_VERIFICATION_CODE: verification_code})
else:
data = {SPECIAL_FIELD_VERIFICATION_CODE: verification_code}
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",

113
skyvern/core/totp.py Normal file
View File

@@ -0,0 +1,113 @@
import asyncio
import json
from datetime import datetime, timedelta
import structlog
from skyvern.config import settings
from skyvern.exceptions import NoTOTPVerificationCodeFound
from skyvern.forge import app
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
LOG = structlog.get_logger()
async def poll_verification_code(
organization_id: str,
task_id: str | None = None,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
workflow_permanent_id: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
) -> str | None:
timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
start_datetime = datetime.utcnow()
timeout_datetime = start_datetime + timeout
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api)
if not org_token:
LOG.error("Failed to get organization token when trying to get verification code")
return None
while True:
await asyncio.sleep(10)
# check timeout
if datetime.utcnow() > timeout_datetime:
LOG.warning("Polling verification code timed out")
raise NoTOTPVerificationCodeFound(
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_permanent_id,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
verification_code = None
if totp_verification_url:
verification_code = await _get_verification_code_from_url(
totp_verification_url,
org_token.token,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
elif totp_identifier:
verification_code = await _get_verification_code_from_db(
organization_id,
totp_identifier,
task_id=task_id,
workflow_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
)
if verification_code:
LOG.info("Got verification code", verification_code=verification_code)
return verification_code
async def _get_verification_code_from_url(
url: str,
api_key: str,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_permanent_id: str | None = None,
) -> str | None:
request_data = {}
if task_id:
request_data["task_id"] = task_id
if workflow_run_id:
request_data["workflow_run_id"] = workflow_run_id
if workflow_permanent_id:
request_data["workflow_permanent_id"] = workflow_permanent_id
payload = json.dumps(request_data)
signature = generate_skyvern_signature(
payload=payload,
api_key=api_key,
)
timestamp = str(int(datetime.utcnow().timestamp()))
headers = {
"x-skyvern-timestamp": timestamp,
"x-skyvern-signature": signature,
"Content-Type": "application/json",
}
json_resp = await aiohttp_post(url=url, data=request_data, headers=headers, raise_exception=False)
return json_resp.get("verification_code", None)
async def _get_verification_code_from_db(
organization_id: str,
totp_identifier: str,
task_id: str | None = None,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
) -> str | None:
totp_codes = await app.DATABASE.get_totp_codes(organization_id=organization_id, totp_identifier=totp_identifier)
for totp_code in totp_codes:
if totp_code.workflow_run_id and workflow_run_id and totp_code.workflow_run_id != workflow_run_id:
continue
if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id:
continue
if totp_code.task_id and totp_code.task_id != task_id:
continue
if totp_code.expired_at and totp_code.expired_at < datetime.utcnow():
continue
return totp_code.code
return None