support magic link login (#3702)
This commit is contained in:
173
skyvern/services/otp_service.py
Normal file
173
skyvern/services/otp_service.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import FailedToGetTOTPVerificationCode, NoTOTPVerificationCodeFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class OTPValue(BaseModel):
|
||||
value: str = Field(..., description="The value of the OTP code.")
|
||||
type: OTPType | None = Field(None, description="The type of the OTP code.")
|
||||
|
||||
def get_otp_type(self) -> OTPType:
|
||||
if self.type:
|
||||
return self.type
|
||||
value = self.value.strip().lower()
|
||||
if value.startswith("https://") or value.startswith("http://"):
|
||||
return OTPType.MAGIC_LINK
|
||||
return OTPType.TOTP
|
||||
|
||||
|
||||
class OTPResultParsedByLLM(BaseModel):
|
||||
reasoning: str = Field(..., description="The reasoning of the OTP code.")
|
||||
otp_type: OTPType | None = Field(None, description="The type of the OTP code.")
|
||||
otp_value_found: bool = Field(..., description="Whether the OTP value is found.")
|
||||
otp_value: str | None = Field(None, description="The OTP value.")
|
||||
|
||||
|
||||
async def parse_otp_login(content: str, organization_id: str) -> OTPValue | None:
|
||||
prompt = prompt_engine.load_prompt("parse-otp-login", content=content)
|
||||
resp = await app.SECONDARY_LLM_API_HANDLER(
|
||||
prompt=prompt, prompt_name="parse-otp-login", organization_id=organization_id
|
||||
)
|
||||
LOG.info("OTP Login Parser Response", resp=resp)
|
||||
otp_result = OTPResultParsedByLLM.model_validate(resp)
|
||||
if otp_result.otp_value_found and otp_result.otp_value:
|
||||
return OTPValue(value=otp_result.otp_value, type=otp_result.otp_type)
|
||||
return None
|
||||
|
||||
|
||||
async def poll_otp_value(
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
workflow_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
totp_verification_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
) -> OTPValue | None:
|
||||
timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
|
||||
start_datetime = datetime.utcnow()
|
||||
timeout_datetime = start_datetime + timeout
|
||||
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
|
||||
if not org_token:
|
||||
LOG.error("Failed to get organization token when trying to get otp value")
|
||||
return None
|
||||
LOG.info(
|
||||
"Polling otp value",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
)
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
# check timeout
|
||||
if datetime.utcnow() > timeout_datetime:
|
||||
LOG.warning("Polling otp value timed out")
|
||||
raise NoTOTPVerificationCodeFound(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_permanent_id,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
)
|
||||
otp_value: OTPValue | None = None
|
||||
if totp_verification_url:
|
||||
otp_value = await _get_otp_value_from_url(
|
||||
totp_verification_url,
|
||||
org_token.token,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
elif totp_identifier:
|
||||
otp_value = await _get_otp_value_from_db(
|
||||
organization_id,
|
||||
totp_identifier,
|
||||
task_id=task_id,
|
||||
workflow_id=workflow_permanent_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if otp_value:
|
||||
LOG.info("Got otp value", otp_value=otp_value)
|
||||
return otp_value
|
||||
|
||||
|
||||
async def _get_otp_value_from_url(
|
||||
url: str,
|
||||
api_key: str,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
) -> OTPValue | None:
|
||||
request_data = {}
|
||||
if task_id:
|
||||
request_data["task_id"] = task_id
|
||||
if workflow_run_id:
|
||||
request_data["workflow_run_id"] = workflow_run_id
|
||||
if workflow_permanent_id:
|
||||
request_data["workflow_permanent_id"] = workflow_permanent_id
|
||||
payload = json.dumps(request_data)
|
||||
signature = generate_skyvern_signature(
|
||||
payload=payload,
|
||||
api_key=api_key,
|
||||
)
|
||||
timestamp = str(int(datetime.utcnow().timestamp()))
|
||||
headers = {
|
||||
"x-skyvern-timestamp": timestamp,
|
||||
"x-skyvern-signature": signature,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
json_resp = await aiohttp_post(url=url, data=request_data, headers=headers, raise_exception=False)
|
||||
except Exception as e:
|
||||
LOG.error("Failed to get otp value from url", exc_info=True)
|
||||
raise FailedToGetTOTPVerificationCode(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_permanent_id,
|
||||
totp_verification_url=url,
|
||||
reason=str(e),
|
||||
)
|
||||
code = json_resp.get("verification_code", None)
|
||||
if code:
|
||||
return OTPValue(value=code, type=OTPType.TOTP)
|
||||
|
||||
magic_link = json_resp.get("magic_link", None)
|
||||
if magic_link:
|
||||
return OTPValue(value=magic_link, type=OTPType.MAGIC_LINK)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_otp_value_from_db(
|
||||
organization_id: str,
|
||||
totp_identifier: str,
|
||||
task_id: str | None = None,
|
||||
workflow_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> OTPValue | None:
|
||||
totp_codes = await app.DATABASE.get_otp_codes(organization_id=organization_id, totp_identifier=totp_identifier)
|
||||
for totp_code in totp_codes:
|
||||
if totp_code.workflow_run_id and workflow_run_id and totp_code.workflow_run_id != workflow_run_id:
|
||||
continue
|
||||
if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id:
|
||||
continue
|
||||
if totp_code.task_id and totp_code.task_id != task_id:
|
||||
continue
|
||||
if totp_code.expired_at and totp_code.expired_at < datetime.utcnow():
|
||||
continue
|
||||
return OTPValue(value=totp_code.code, type=totp_code.otp_type)
|
||||
return None
|
||||
Reference in New Issue
Block a user