Files
Dorod-Sky/skyvern/services/otp_service.py
2025-11-04 11:29:14 +08:00

182 lines
6.6 KiB
Python

import asyncio
from datetime import datetime, timedelta
import structlog
from pydantic import BaseModel, Field
from skyvern.config import settings
from skyvern.exceptions import FailedToGetTOTPVerificationCode, NoTOTPVerificationCodeFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.totp_codes import OTPType
LOG = structlog.get_logger()
class OTPValue(BaseModel):
value: str = Field(..., description="The value of the OTP code.")
type: OTPType | None = Field(None, description="The type of the OTP code.")
def get_otp_type(self) -> OTPType:
if self.type:
return self.type
value = self.value.strip().lower()
if value.startswith("https://") or value.startswith("http://"):
return OTPType.MAGIC_LINK
return OTPType.TOTP
class OTPResultParsedByLLM(BaseModel):
reasoning: str = Field(..., description="The reasoning of the OTP code.")
otp_type: OTPType | None = Field(None, description="The type of the OTP code.")
otp_value_found: bool = Field(..., description="Whether the OTP value is found.")
otp_value: str | None = Field(None, description="The OTP value.")
async def parse_otp_login(content: str, organization_id: str) -> OTPValue | None:
prompt = prompt_engine.load_prompt("parse-otp-login", content=content)
resp = await app.SECONDARY_LLM_API_HANDLER(
prompt=prompt, prompt_name="parse-otp-login", organization_id=organization_id
)
LOG.info("OTP Login Parser Response", resp=resp)
otp_result = OTPResultParsedByLLM.model_validate(resp)
if otp_result.otp_value_found and otp_result.otp_value:
return OTPValue(value=otp_result.otp_value, type=otp_result.otp_type)
return None
async def poll_otp_value(
organization_id: str,
task_id: str | None = None,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
workflow_permanent_id: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
) -> OTPValue | None:
timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
start_datetime = datetime.utcnow()
timeout_datetime = start_datetime + timeout
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
if not org_token:
LOG.error("Failed to get organization token when trying to get otp value")
return None
LOG.info(
"Polling otp value",
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_permanent_id=workflow_permanent_id,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
while True:
await asyncio.sleep(10)
# check timeout
if datetime.utcnow() > timeout_datetime:
LOG.warning("Polling otp value timed out")
raise NoTOTPVerificationCodeFound(
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_permanent_id,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
otp_value: OTPValue | None = None
if totp_verification_url:
otp_value = await _get_otp_value_from_url(
organization_id,
totp_verification_url,
org_token.token,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
elif totp_identifier:
otp_value = await _get_otp_value_from_db(
organization_id,
totp_identifier,
task_id=task_id,
workflow_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
)
if otp_value:
LOG.info("Got otp value", otp_value=otp_value)
return otp_value
async def _get_otp_value_from_url(
organization_id: str,
url: str,
api_key: str,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_permanent_id: str | None = None,
) -> OTPValue | None:
request_data = {}
if task_id:
request_data["task_id"] = task_id
if workflow_run_id:
request_data["workflow_run_id"] = workflow_run_id
if workflow_permanent_id:
request_data["workflow_permanent_id"] = workflow_permanent_id
signed_data = generate_skyvern_webhook_signature(
payload=request_data,
api_key=api_key,
)
try:
json_resp = await aiohttp_post(url=url, data=request_data, headers=signed_data.headers, raise_exception=False)
except Exception as e:
LOG.error("Failed to get otp value from url", exc_info=True)
raise FailedToGetTOTPVerificationCode(
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_permanent_id,
totp_verification_url=url,
reason=str(e),
)
if not json_resp:
return None
content = json_resp.get("verification_code", None)
if not content:
return None
otp_value: OTPValue | None = OTPValue(value=content, type=OTPType.TOTP)
if isinstance(content, str) and len(content) > 10:
try:
otp_value = await parse_otp_login(content, organization_id)
except Exception:
LOG.warning("faile to parse content by LLM call", exc_info=True)
if not otp_value:
LOG.warning(
"Failed to parse otp login from the totp url",
content=content,
)
return None
return otp_value
async def _get_otp_value_from_db(
organization_id: str,
totp_identifier: str,
task_id: str | None = None,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
) -> OTPValue | None:
totp_codes = await app.DATABASE.get_otp_codes(organization_id=organization_id, totp_identifier=totp_identifier)
for totp_code in totp_codes:
if totp_code.workflow_run_id and workflow_run_id and totp_code.workflow_run_id != workflow_run_id:
continue
if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id:
continue
if totp_code.task_id and totp_code.task_id != task_id:
continue
if totp_code.expired_at and totp_code.expired_at < datetime.utcnow():
continue
return OTPValue(value=totp_code.code, type=totp_code.otp_type)
return None