import asyncio
from datetime import datetime, timedelta

import structlog
from pydantic import BaseModel, Field

from skyvern.config import settings
from skyvern.exceptions import FailedToGetTOTPVerificationCode, NoTOTPVerificationCodeFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.schemas.totp_codes import OTPType

LOG = structlog.get_logger()

# Seconds to wait between two consecutive OTP lookups while polling.
_POLL_INTERVAL_SECONDS = 10


class OTPValue(BaseModel):
    """An OTP value (a code or a magic link) obtained for a login flow."""

    value: str = Field(..., description="The value of the OTP code.")
    type: OTPType | None = Field(None, description="The type of the OTP code.")

    def get_otp_type(self) -> OTPType:
        """Return the explicit type if set; otherwise infer it from the value.

        Values starting with ``http://`` or ``https://`` (case-insensitive, after
        stripping whitespace) are treated as magic links; anything else as TOTP.
        """
        if self.type:
            return self.type
        value = self.value.strip().lower()
        if value.startswith("https://") or value.startswith("http://"):
            return OTPType.MAGIC_LINK
        return OTPType.TOTP


class OTPResultParsedByLLM(BaseModel):
    """Structured response expected from the ``parse-otp-login`` LLM prompt."""

    reasoning: str = Field(..., description="The reasoning of the OTP code.")
    otp_type: OTPType | None = Field(None, description="The type of the OTP code.")
    otp_value_found: bool = Field(..., description="Whether the OTP value is found.")
    otp_value: str | None = Field(None, description="The OTP value.")


async def parse_otp_login(
    content: str,
    organization_id: str,
    enforced_otp_type: OTPType | None = None,
) -> OTPValue | None:
    """Ask the secondary LLM to extract an OTP value from free-form content.

    Args:
        content: Raw text that may contain an OTP code or magic link.
        organization_id: Organization the LLM call is attributed to.
        enforced_otp_type: If given, its ``value`` is passed to the prompt
            template as ``enforced_otp_type``.

    Returns:
        The parsed OTP value, or None when the LLM reports no value was found.

    Raises:
        pydantic.ValidationError: if the LLM response does not match
            ``OTPResultParsedByLLM``.
    """
    prompt = prompt_engine.load_prompt(
        "parse-otp-login",
        content=content,
        enforced_otp_type=enforced_otp_type.value if enforced_otp_type else None,
    )
    resp = await app.SECONDARY_LLM_API_HANDLER(
        prompt=prompt, prompt_name="parse-otp-login", organization_id=organization_id
    )
    # NOTE(review): resp can contain the OTP itself, so this writes a secret to the logs.
    LOG.info("OTP Login Parser Response", resp=resp, enforced_otp_type=enforced_otp_type)
    otp_result = OTPResultParsedByLLM.model_validate(resp)
    if otp_result.otp_value_found and otp_result.otp_value:
        return OTPValue(value=otp_result.otp_value, type=otp_result.otp_type)
    return None


def _publish_verification_event(
    organization_id: str,
    payload: dict[str, str | None],
    failure_message: str,
) -> None:
    """Best-effort publish of a 2FA notification; failures are logged, never raised."""
    try:
        NotificationRegistryFactory.get_registry().publish(organization_id, payload)
    except Exception:
        LOG.warning(failure_message, exc_info=True)


async def _set_verification_waiting_state(
    organization_id: str,
    *,
    task_id: str | None,
    workflow_run_id: str | None,
    identifier: str | None,
    polling_started_at: datetime,
) -> None:
    """Flag the workflow run (preferred) or task as waiting for a verification code.

    On a successful DB update, a ``verification_code_required`` notification is
    published. Best-effort: every failure is logged as a warning and swallowed so
    that polling can still proceed. Does nothing when neither id is set.
    """
    if workflow_run_id:
        try:
            await app.DATABASE.update_workflow_run(
                workflow_run_id=workflow_run_id,
                waiting_for_verification_code=True,
                verification_code_identifier=identifier,
                verification_code_polling_started_at=polling_started_at,
            )
            LOG.info(
                "Set 2FA waiting state for workflow run",
                workflow_run_id=workflow_run_id,
                verification_code_identifier=identifier,
            )
        except Exception:
            LOG.warning("Failed to set 2FA waiting state for workflow run", exc_info=True)
        else:
            _publish_verification_event(
                organization_id,
                {
                    "type": "verification_code_required",
                    "workflow_run_id": workflow_run_id,
                    "task_id": task_id,
                    "identifier": identifier,
                    "polling_started_at": polling_started_at.isoformat(),
                },
                "Failed to publish 2FA required notification for workflow run",
            )
    elif task_id:
        try:
            await app.DATABASE.update_task_2fa_state(
                task_id=task_id,
                organization_id=organization_id,
                waiting_for_verification_code=True,
                verification_code_identifier=identifier,
                verification_code_polling_started_at=polling_started_at,
            )
            LOG.info(
                "Set 2FA waiting state for task",
                task_id=task_id,
                verification_code_identifier=identifier,
            )
        except Exception:
            LOG.warning("Failed to set 2FA waiting state for task", exc_info=True)
        else:
            _publish_verification_event(
                organization_id,
                {
                    "type": "verification_code_required",
                    "task_id": task_id,
                    "identifier": identifier,
                    "polling_started_at": polling_started_at.isoformat(),
                },
                "Failed to publish 2FA required notification for task",
            )


async def _clear_verification_waiting_state(
    organization_id: str,
    *,
    task_id: str | None,
    workflow_run_id: str | None,
) -> None:
    """Clear the waiting-for-verification-code flag set by ``_set_verification_waiting_state``.

    On a successful DB update, a ``verification_code_resolved`` notification is
    published. Best-effort: every failure is logged as a warning and swallowed.
    """
    if workflow_run_id:
        try:
            await app.DATABASE.update_workflow_run(
                workflow_run_id=workflow_run_id,
                waiting_for_verification_code=False,
            )
            LOG.info("Cleared 2FA waiting state for workflow run", workflow_run_id=workflow_run_id)
        except Exception:
            LOG.warning("Failed to clear 2FA waiting state for workflow run", exc_info=True)
        else:
            _publish_verification_event(
                organization_id,
                {"type": "verification_code_resolved", "workflow_run_id": workflow_run_id, "task_id": task_id},
                "Failed to publish 2FA resolved notification for workflow run",
            )
    elif task_id:
        try:
            await app.DATABASE.update_task_2fa_state(
                task_id=task_id,
                organization_id=organization_id,
                waiting_for_verification_code=False,
            )
            LOG.info("Cleared 2FA waiting state for task", task_id=task_id)
        except Exception:
            LOG.warning("Failed to clear 2FA waiting state for task", exc_info=True)
        else:
            _publish_verification_event(
                organization_id,
                {"type": "verification_code_resolved", "task_id": task_id},
                "Failed to publish 2FA resolved notification for task",
            )


async def poll_otp_value(
    organization_id: str,
    task_id: str | None = None,
    workflow_id: str | None = None,
    workflow_run_id: str | None = None,
    workflow_permanent_id: str | None = None,
    totp_verification_url: str | None = None,
    totp_identifier: str | None = None,
) -> OTPValue | None:
    """Poll for an OTP value until one is found or the polling timeout elapses.

    Every ``_POLL_INTERVAL_SECONDS`` seconds one source is consulted:

    * ``totp_verification_url`` set: a signed POST to that URL.
    * otherwise, ``totp_identifier`` set: stored codes for that identifier,
      falling back to codes submitted for this task / workflow run.
    * otherwise: codes submitted for this task / workflow run (manual 2FA input).

    While polling, the workflow run (or the task, when there is no run id) is
    flagged as waiting for a verification code; the flag is cleared on every exit
    path (success, timeout or error).

    Returns:
        The OTP value, or None if the organization has no valid API token.

    Raises:
        NoTOTPVerificationCodeFound: if no code arrives within
            ``settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS`` minutes.
        FailedToGetTOTPVerificationCode: if the verification URL cannot be queried
            or returns an unusable ``verification_code``.
    """
    timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
    start_datetime = datetime.utcnow()
    timeout_datetime = start_datetime + timeout
    org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
    if not org_token:
        LOG.error("Failed to get organization token when trying to get otp value")
        return None
    LOG.info(
        "Polling otp value",
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        workflow_permanent_id=workflow_permanent_id,
        totp_verification_url=totp_verification_url,
        totp_identifier=totp_identifier,
    )

    # Set the waiting state in the database when polling starts so the UI can prompt for a code.
    await _set_verification_waiting_state(
        organization_id,
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        identifier=totp_identifier,
        polling_started_at=start_datetime,
    )
    try:
        while True:
            await asyncio.sleep(_POLL_INTERVAL_SECONDS)
            if datetime.utcnow() > timeout_datetime:
                LOG.warning("Polling otp value timed out")
                raise NoTOTPVerificationCodeFound(
                    task_id=task_id,
                    workflow_run_id=workflow_run_id,
                    workflow_id=workflow_permanent_id,
                    totp_verification_url=totp_verification_url,
                    totp_identifier=totp_identifier,
                )

            otp_value: OTPValue | None = None
            if totp_verification_url:
                # NOTE(review): workflow_permanent_id is not forwarded here although the
                # helper accepts it — confirm whether the endpoint should receive it.
                otp_value = await _get_otp_value_from_url(
                    organization_id,
                    totp_verification_url,
                    org_token.token,
                    task_id=task_id,
                    workflow_run_id=workflow_run_id,
                )
            else:
                if totp_identifier:
                    otp_value = await _get_otp_value_from_db(
                        organization_id,
                        totp_identifier,
                        task_id=task_id,
                        workflow_id=workflow_id,
                        workflow_run_id=workflow_run_id,
                    )
                if not otp_value:
                    # No identifier match (or none configured): look for codes submitted
                    # manually for this run context.
                    otp_value = await _get_otp_value_by_run(
                        organization_id,
                        task_id=task_id,
                        workflow_run_id=workflow_run_id,
                    )

            if otp_value:
                # NOTE(review): this logs the OTP secret itself.
                LOG.info("Got otp value", otp_value=otp_value)
                return otp_value
    finally:
        # Clear the waiting state when polling completes (success, timeout, or error).
        await _clear_verification_waiting_state(
            organization_id,
            task_id=task_id,
            workflow_run_id=workflow_run_id,
        )


async def _get_otp_value_from_url(
    organization_id: str,
    url: str,
    api_key: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
    workflow_permanent_id: str | None = None,
) -> OTPValue | None:
    """Fetch an OTP value from a customer-provided verification URL.

    Sends a payload signed with ``api_key`` containing whichever of the run ids are
    set, and reads ``verification_code`` from the JSON response. Content longer
    than 10 characters is handed to the LLM parser to extract the code or magic
    link; shorter content is used directly as a TOTP code.

    Returns:
        The OTP value, or None if the response is empty, carries no code, or the
        LLM finds no code in it.

    Raises:
        FailedToGetTOTPVerificationCode: if the request raises, or the response's
            ``verification_code`` is neither a string nor an integer.
    """
    request_data: dict[str, str] = {}
    if task_id:
        request_data["task_id"] = task_id
    if workflow_run_id:
        request_data["workflow_run_id"] = workflow_run_id
    if workflow_permanent_id:
        request_data["workflow_permanent_id"] = workflow_permanent_id
    signed_data = generate_skyvern_webhook_signature(
        payload=request_data,
        api_key=api_key,
    )
    try:
        json_resp = await aiohttp_post(
            url=url,
            str_data=signed_data.signed_payload,
            headers=signed_data.headers,
            raise_exception=False,
            retry=2,
            retry_timeout=5,
        )
    except Exception as e:
        LOG.error("Failed to get otp value from url", exc_info=True)
        raise FailedToGetTOTPVerificationCode(
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            workflow_id=workflow_permanent_id,
            totp_verification_url=url,
            reason=str(e),
        ) from e
    if not json_resp:
        return None
    content = json_resp.get("verification_code", None)
    if not content:
        return None

    # A JSON number arrives as int; OTPValue.value is a str field and pydantic v2 does
    # not coerce int -> str, so normalize it here (leading zeros are already lost if
    # the endpoint sent a number). bool is excluded because it subclasses int.
    if isinstance(content, int) and not isinstance(content, bool):
        content = str(content)
    if not isinstance(content, str):
        # Previously surfaced as a pydantic ValidationError; raise the domain error instead.
        raise FailedToGetTOTPVerificationCode(
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            workflow_id=workflow_permanent_id,
            totp_verification_url=url,
            reason=f"Unsupported verification_code type: {type(content).__name__}",
        )

    otp_value: OTPValue | None = OTPValue(value=content, type=OTPType.TOTP)
    if len(content) > 10:
        try:
            otp_value = await parse_otp_login(content, organization_id)
        except Exception:
            # Best-effort: keep the raw content as a TOTP code if the LLM call fails.
            LOG.warning("faile to parse content by LLM call", exc_info=True)

    if not otp_value:
        LOG.warning(
            "Failed to parse otp login from the totp url",
            content=content,
        )
        return None
    return otp_value


async def _get_otp_value_by_run(
    organization_id: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
) -> OTPValue | None:
    """Look up OTP codes by task_id/workflow_run_id when no totp_identifier is configured.

    Used for the manual 2FA input flow where users submit codes through the UI
    without pre-configured TOTP credentials. Only the first code returned by the
    database is considered.
    """
    # NOTE(review): unlike _get_otp_value_from_db there is no expired_at check here —
    # confirm get_otp_codes_by_run filters expired codes and orders newest first.
    codes = await app.DATABASE.get_otp_codes_by_run(
        organization_id=organization_id,
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        limit=1,
    )
    if codes:
        code = codes[0]
        return OTPValue(value=code.code, type=code.otp_type)
    return None


async def _get_otp_value_from_db(
    organization_id: str,
    totp_identifier: str,
    task_id: str | None = None,
    workflow_id: str | None = None,
    workflow_run_id: str | None = None,
) -> OTPValue | None:
    """Return the first stored, unexpired code for ``totp_identifier`` usable by this run.

    A code bound to a different workflow run / workflow is skipped only when both
    the code and the caller carry that id. A code bound to a task is skipped unless
    it matches ``task_id`` exactly (including when ``task_id`` is None).
    """
    totp_codes = await app.DATABASE.get_otp_codes(organization_id=organization_id, totp_identifier=totp_identifier)
    now = datetime.utcnow()
    for totp_code in totp_codes:
        if totp_code.workflow_run_id and workflow_run_id and totp_code.workflow_run_id != workflow_run_id:
            continue
        if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id:
            continue
        # Stricter than the checks above: a task-bound code never matches a caller without that task.
        if totp_code.task_id and totp_code.task_id != task_id:
            continue
        if totp_code.expired_at and totp_code.expired_at < now:
            continue
        return OTPValue(value=totp_code.code, type=totp_code.otp_type)
    return None