Files
Dorod-Sky/skyvern/services/otp_service.py
2026-02-18 23:01:59 -05:00

371 lines
14 KiB
Python

import asyncio
from datetime import datetime, timedelta
import pyotp
import structlog
from pydantic import BaseModel, Field
from skyvern.config import settings
from skyvern.exceptions import FailedToGetTOTPVerificationCode, NoTOTPVerificationCodeFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.schemas.totp_codes import OTPType
# Module-level structlog logger used for logging throughout this module.
LOG = structlog.get_logger()
class OTPValue(BaseModel):
    # A one-time credential: either a numeric/text code or a magic-link URL,
    # optionally tagged with an explicit OTPType.
    value: str = Field(..., description="The value of the OTP code.")
    type: OTPType | None = Field(None, description="The type of the OTP code.")

    def get_otp_type(self) -> OTPType:
        """Return the OTP type, inferring it from ``value`` when ``type`` is unset.

        If ``type`` is truthy it is returned as-is. Otherwise a value that, after
        stripping whitespace and lower-casing, begins with ``https://`` or
        ``http://`` is classified as a magic link; anything else as a TOTP code.
        """
        if self.type:
            return self.type
        normalized_value = self.value.strip().lower()
        looks_like_link = normalized_value.startswith(("https://", "http://"))
        return OTPType.MAGIC_LINK if looks_like_link else OTPType.TOTP
class OTPResultParsedByLLM(BaseModel):
    # Schema the secondary LLM's "parse-otp-login" response is validated against
    # (parse_otp_login calls OTPResultParsedByLLM.model_validate on that response).
    # Free-text explanation produced by the LLM; required.
    reasoning: str = Field(..., description="The reasoning of the OTP code.")
    # OTP type reported by the LLM, if any.
    otp_type: OTPType | None = Field(None, description="The type of the OTP code.")
    # Whether the LLM reports having found an OTP value; required.
    otp_value_found: bool = Field(..., description="Whether the OTP value is found.")
    # The extracted value; parse_otp_login only uses it when otp_value_found is true
    # and this is non-empty.
    otp_value: str | None = Field(None, description="The OTP value.")
async def parse_otp_login(
    content: str,
    organization_id: str,
    enforced_otp_type: OTPType | None = None,
) -> OTPValue | None:
    """Extract an OTP value from free-form text using the secondary LLM.

    Args:
        content: Raw text (e.g. a message body) that may contain an OTP code or link.
        organization_id: Organization on whose behalf the LLM call is made.
        enforced_otp_type: When given, its ``.value`` is passed to the
            "parse-otp-login" prompt template as ``enforced_otp_type``.

    Returns:
        An ``OTPValue`` built from the LLM's answer, or ``None`` when the LLM
        reports no value (or an empty one).
    """
    prompt_name = "parse-otp-login"
    enforced_type_value = None if not enforced_otp_type else enforced_otp_type.value
    prompt = prompt_engine.load_prompt(
        prompt_name,
        content=content,
        enforced_otp_type=enforced_type_value,
    )
    llm_response = await app.SECONDARY_LLM_API_HANDLER(
        prompt=prompt,
        prompt_name=prompt_name,
        organization_id=organization_id,
    )
    LOG.info("OTP Login Parser Response", resp=llm_response, enforced_otp_type=enforced_otp_type)

    parsed = OTPResultParsedByLLM.model_validate(llm_response)
    if not (parsed.otp_value_found and parsed.otp_value):
        return None
    return OTPValue(value=parsed.otp_value, type=parsed.otp_type)
def try_generate_totp_from_credential(workflow_run_id: str | None) -> OTPValue | None:
    """Generate a TOTP code from a credential's TOTP secret in the workflow run context.

    Walks ``workflow_run_context.values`` looking for dict entries that carry a
    non-empty string under ``"totp"``. That string is resolved via
    ``totp_secret_value_key`` / ``get_original_secret_value_or_none`` to the
    real secret, and ``pyotp.TOTP(secret).now()`` produces the code. The first
    credential that yields a code wins; credentials whose generation raises are
    logged and skipped.

    Intended to be consulted before ``poll_otp_value`` so that a stored
    credential secret takes priority over ``totp_url`` / ``totp_identifier``.

    Returns:
        A TOTP ``OTPValue``, or ``None`` if there is no run id, no run context,
        or no credential produced a code.
    """
    if not workflow_run_id:
        return None
    run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
    if not run_context:
        return None

    for credential_key, credential in run_context.values.items():
        if not isinstance(credential, dict) or "totp" not in credential:
            continue
        secret_id = credential.get("totp")
        if not secret_id or not isinstance(secret_id, str):
            continue
        secret_lookup_key = run_context.totp_secret_value_key(secret_id)
        secret = run_context.get_original_secret_value_or_none(secret_lookup_key)
        if not secret:
            continue
        try:
            generated_code = pyotp.TOTP(secret).now()
            LOG.info(
                "Generated TOTP from credential secret",
                workflow_run_id=workflow_run_id,
                credential_key=credential_key,
            )
            return OTPValue(value=generated_code, type=OTPType.TOTP)
        except Exception:
            # Bad secret for this credential; fall through to the next one.
            LOG.warning(
                "Failed to generate TOTP from credential secret",
                workflow_run_id=workflow_run_id,
                credential_key=credential_key,
                exc_info=True,
            )
    return None
async def _set_verification_code_waiting_state(
    organization_id: str,
    task_id: str | None,
    workflow_run_id: str | None,
    identifier: str | None,
    polling_started_at: datetime,
) -> None:
    """Flag the workflow run (preferred) or task as waiting for a 2FA code and publish a notification.

    Best-effort: any failure is logged and swallowed so it never blocks polling.
    The notification is only published after the database update succeeds.
    """
    if workflow_run_id:
        try:
            await app.DATABASE.update_workflow_run(
                workflow_run_id=workflow_run_id,
                waiting_for_verification_code=True,
                verification_code_identifier=identifier,
                verification_code_polling_started_at=polling_started_at,
            )
            LOG.info(
                "Set 2FA waiting state for workflow run",
                workflow_run_id=workflow_run_id,
                verification_code_identifier=identifier,
            )
            try:
                NotificationRegistryFactory.get_registry().publish(
                    organization_id,
                    {
                        "type": "verification_code_required",
                        "workflow_run_id": workflow_run_id,
                        "task_id": task_id,
                        "identifier": identifier,
                        "polling_started_at": polling_started_at.isoformat(),
                    },
                )
            except Exception:
                LOG.warning("Failed to publish 2FA required notification for workflow run", exc_info=True)
        except Exception:
            LOG.warning("Failed to set 2FA waiting state for workflow run", exc_info=True)
    elif task_id:
        try:
            await app.DATABASE.update_task_2fa_state(
                task_id=task_id,
                organization_id=organization_id,
                waiting_for_verification_code=True,
                verification_code_identifier=identifier,
                verification_code_polling_started_at=polling_started_at,
            )
            LOG.info(
                "Set 2FA waiting state for task",
                task_id=task_id,
                verification_code_identifier=identifier,
            )
            try:
                NotificationRegistryFactory.get_registry().publish(
                    organization_id,
                    {
                        "type": "verification_code_required",
                        "task_id": task_id,
                        "identifier": identifier,
                        "polling_started_at": polling_started_at.isoformat(),
                    },
                )
            except Exception:
                LOG.warning("Failed to publish 2FA required notification for task", exc_info=True)
        except Exception:
            LOG.warning("Failed to set 2FA waiting state for task", exc_info=True)


async def _clear_verification_code_waiting_state(
    organization_id: str,
    task_id: str | None,
    workflow_run_id: str | None,
) -> None:
    """Clear the 2FA waiting flag on the workflow run (preferred) or task and publish a "resolved" notification.

    Best-effort: any failure is logged and swallowed. The notification is only
    published after the database update succeeds.
    """
    if workflow_run_id:
        try:
            await app.DATABASE.update_workflow_run(
                workflow_run_id=workflow_run_id,
                waiting_for_verification_code=False,
            )
            LOG.info("Cleared 2FA waiting state for workflow run", workflow_run_id=workflow_run_id)
            try:
                NotificationRegistryFactory.get_registry().publish(
                    organization_id,
                    {"type": "verification_code_resolved", "workflow_run_id": workflow_run_id, "task_id": task_id},
                )
            except Exception:
                LOG.warning("Failed to publish 2FA resolved notification for workflow run", exc_info=True)
        except Exception:
            LOG.warning("Failed to clear 2FA waiting state for workflow run", exc_info=True)
    elif task_id:
        try:
            await app.DATABASE.update_task_2fa_state(
                task_id=task_id,
                organization_id=organization_id,
                waiting_for_verification_code=False,
            )
            LOG.info("Cleared 2FA waiting state for task", task_id=task_id)
            try:
                NotificationRegistryFactory.get_registry().publish(
                    organization_id,
                    {"type": "verification_code_resolved", "task_id": task_id},
                )
            except Exception:
                LOG.warning("Failed to publish 2FA resolved notification for task", exc_info=True)
        except Exception:
            LOG.warning("Failed to clear 2FA waiting state for task", exc_info=True)


async def poll_otp_value(
    organization_id: str,
    task_id: str | None = None,
    workflow_id: str | None = None,
    workflow_run_id: str | None = None,
    workflow_permanent_id: str | None = None,
    totp_verification_url: str | None = None,
    totp_identifier: str | None = None,
) -> OTPValue | None:
    """Poll for an OTP value every 10 seconds until one is found or the timeout elapses.

    Source used on each iteration:
    - ``totp_verification_url`` set: request the code from that URL, signing the
      request with the organization's API token.
    - otherwise ``totp_identifier`` set: look up stored codes for that identifier,
      falling back to codes submitted for this task / workflow run.
    - otherwise: look up codes submitted for this task / workflow run (manual
      2FA input flow).

    While polling, the workflow run (or the task, when there is no run id) is
    flagged as waiting for a verification code and a notification is published.
    The flag is cleared, and a "resolved" notification is published, when polling
    ends for any reason. Failures in that bookkeeping are logged and ignored.

    Returns:
        The OTP value, or ``None`` if the organization has no valid API token.

    Raises:
        NoTOTPVerificationCodeFound: no value within
            ``settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS``.
        FailedToGetTOTPVerificationCode: propagated when the request to
            ``totp_verification_url`` raises.
    """
    timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
    start_datetime = datetime.utcnow()
    timeout_datetime = start_datetime + timeout
    # The org API token signs the request sent to totp_verification_url.
    org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
    if not org_token:
        LOG.error("Failed to get organization token when trying to get otp value")
        return None
    LOG.info(
        "Polling otp value",
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        workflow_permanent_id=workflow_permanent_id,
        totp_verification_url=totp_verification_url,
        totp_identifier=totp_identifier,
    )

    # Set the waiting state in the database when polling starts.
    await _set_verification_code_waiting_state(
        organization_id,
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        identifier=totp_identifier,
        polling_started_at=start_datetime,
    )
    try:
        while True:
            await asyncio.sleep(10)
            if datetime.utcnow() > timeout_datetime:
                LOG.warning("Polling otp value timed out")
                raise NoTOTPVerificationCodeFound(
                    task_id=task_id,
                    workflow_run_id=workflow_run_id,
                    workflow_id=workflow_permanent_id,
                    totp_verification_url=totp_verification_url,
                    totp_identifier=totp_identifier,
                )
            otp_value: OTPValue | None = None
            if totp_verification_url:
                otp_value = await _get_otp_value_from_url(
                    organization_id,
                    totp_verification_url,
                    org_token.token,
                    task_id=task_id,
                    workflow_run_id=workflow_run_id,
                    # Previously dropped; the URL helper includes it in the signed payload.
                    workflow_permanent_id=workflow_permanent_id,
                )
            elif totp_identifier:
                otp_value = await _get_otp_value_from_db(
                    organization_id,
                    totp_identifier,
                    task_id=task_id,
                    workflow_id=workflow_id,
                    workflow_run_id=workflow_run_id,
                )
                if not otp_value:
                    otp_value = await _get_otp_value_by_run(
                        organization_id,
                        task_id=task_id,
                        workflow_run_id=workflow_run_id,
                    )
            else:
                # No pre-configured TOTP — poll for manually submitted codes by run context.
                otp_value = await _get_otp_value_by_run(
                    organization_id,
                    task_id=task_id,
                    workflow_run_id=workflow_run_id,
                )
            if otp_value:
                # Log only the type: the value itself is a live credential (code or magic link).
                LOG.info("Got otp value", otp_type=otp_value.get_otp_type())
                return otp_value
    finally:
        # Clear the waiting state when polling completes (success, timeout, or error).
        await _clear_verification_code_waiting_state(
            organization_id,
            task_id=task_id,
            workflow_run_id=workflow_run_id,
        )
async def _get_otp_value_from_url(
    organization_id: str,
    url: str,
    api_key: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
    workflow_permanent_id: str | None = None,
) -> OTPValue | None:
    """Fetch an OTP value from a customer-provided verification URL.

    Posts a payload signed with ``api_key`` (via
    ``generate_skyvern_webhook_signature``) that contains whichever of
    ``task_id``, ``workflow_run_id`` and ``workflow_permanent_id`` are set. It then
    reads ``verification_code`` from the response. Integer codes are converted
    to ``str``. String values longer than 10 characters (e.g. a whole message
    body) are handed to ``parse_otp_login`` to extract the code or link. Shorter
    values are returned directly as TOTP codes.

    Returns:
        The OTP value, or ``None`` if the response is empty, has no
        ``verification_code``, or the LLM parser found nothing.

    Raises:
        FailedToGetTOTPVerificationCode: the HTTP request raised.
    """
    request_data: dict[str, str] = {}
    if task_id:
        request_data["task_id"] = task_id
    if workflow_run_id:
        request_data["workflow_run_id"] = workflow_run_id
    if workflow_permanent_id:
        request_data["workflow_permanent_id"] = workflow_permanent_id
    signed_data = generate_skyvern_webhook_signature(
        payload=request_data,
        api_key=api_key,
    )
    try:
        json_resp = await aiohttp_post(
            url=url,
            str_data=signed_data.signed_payload,
            headers=signed_data.headers,
            raise_exception=False,
            retry=2,
            retry_timeout=5,
        )
    except Exception as e:
        LOG.error("Failed to get otp value from url", exc_info=True)
        raise FailedToGetTOTPVerificationCode(
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            workflow_id=workflow_permanent_id,
            totp_verification_url=url,
            reason=str(e),
        ) from e
    if not json_resp:
        return None
    content = json_resp.get("verification_code", None)
    if not content:
        return None
    # A JSON number (e.g. {"verification_code": 123456}) would be rejected by the
    # str-typed OTPValue.value, since pydantic v2 does not coerce int -> str.
    if isinstance(content, int) and not isinstance(content, bool):
        content = str(content)

    # Default: treat the value as a TOTP code. If LLM parsing of a long value
    # raises, this default is what gets returned.
    otp_value: OTPValue | None = OTPValue(value=content, type=OTPType.TOTP)
    if isinstance(content, str) and len(content) > 10:
        try:
            otp_value = await parse_otp_login(content, organization_id)
        except Exception:
            LOG.warning("failed to parse content by LLM call", exc_info=True)
        if not otp_value:
            LOG.warning(
                "Failed to parse otp login from the totp url",
                content=content,
            )
            return None
    return otp_value
async def _get_otp_value_by_run(
    organization_id: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
) -> OTPValue | None:
    """Find an OTP code submitted for this task / workflow run.

    Backs the manual 2FA input flow, in which users enter codes through the UI
    without a pre-configured TOTP identifier or URL. Only the first row returned
    by ``get_otp_codes_by_run`` (queried with ``limit=1``) is considered. The
    ordering is defined by that query; presumably newest first (verify).

    Returns:
        The code as an ``OTPValue``, or ``None`` when nothing has been submitted.
    """
    matching_codes = await app.DATABASE.get_otp_codes_by_run(
        organization_id=organization_id,
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        limit=1,
    )
    if not matching_codes:
        return None
    first_code = matching_codes[0]
    return OTPValue(value=first_code.code, type=first_code.otp_type)
async def _get_otp_value_from_db(
    organization_id: str,
    totp_identifier: str,
    task_id: str | None = None,
    workflow_id: str | None = None,
    workflow_run_id: str | None = None,
) -> OTPValue | None:
    """Return the first stored OTP code for ``totp_identifier`` that fits this run context.

    A stored code is rejected when:
    - it is bound to a workflow run, a ``workflow_run_id`` is given, and they differ;
    - it is bound to a workflow, a ``workflow_id`` is given, and they differ;
    - it is bound to a task that differs from ``task_id``, including when
      ``task_id`` is None;
    - its ``expired_at`` is set and earlier than the current naive-UTC time.

    Codes are examined in the order ``get_otp_codes`` returns them.
    """
    stored_codes = await app.DATABASE.get_otp_codes(organization_id=organization_id, totp_identifier=totp_identifier)

    def _fits_context(candidate) -> bool:
        # Evaluated lazily, one candidate at a time, so utcnow() is read per candidate.
        if candidate.workflow_run_id and workflow_run_id and candidate.workflow_run_id != workflow_run_id:
            return False
        if candidate.workflow_id and workflow_id and candidate.workflow_id != workflow_id:
            return False
        if candidate.task_id and candidate.task_id != task_id:
            return False
        if candidate.expired_at and candidate.expired_at < datetime.utcnow():
            return False
        return True

    chosen = next((candidate for candidate in stored_codes if _fits_context(candidate)), None)
    if chosen is None:
        return None
    return OTPValue(value=chosen.code, type=chosen.otp_type)