182 lines
6.6 KiB
Python
182 lines
6.6 KiB
Python
import asyncio
|
|
from datetime import datetime, timedelta
|
|
|
|
import structlog
|
|
from pydantic import BaseModel, Field
|
|
|
|
from skyvern.config import settings
|
|
from skyvern.exceptions import FailedToGetTOTPVerificationCode, NoTOTPVerificationCodeFound
|
|
from skyvern.forge import app
|
|
from skyvern.forge.prompts import prompt_engine
|
|
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
|
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
|
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
|
|
|
LOG = structlog.get_logger()
|
|
|
|
|
|
class OTPValue(BaseModel):
|
|
value: str = Field(..., description="The value of the OTP code.")
|
|
type: OTPType | None = Field(None, description="The type of the OTP code.")
|
|
|
|
def get_otp_type(self) -> OTPType:
|
|
if self.type:
|
|
return self.type
|
|
value = self.value.strip().lower()
|
|
if value.startswith("https://") or value.startswith("http://"):
|
|
return OTPType.MAGIC_LINK
|
|
return OTPType.TOTP
|
|
|
|
|
|
class OTPResultParsedByLLM(BaseModel):
|
|
reasoning: str = Field(..., description="The reasoning of the OTP code.")
|
|
otp_type: OTPType | None = Field(None, description="The type of the OTP code.")
|
|
otp_value_found: bool = Field(..., description="Whether the OTP value is found.")
|
|
otp_value: str | None = Field(None, description="The OTP value.")
|
|
|
|
|
|
async def parse_otp_login(content: str, organization_id: str) -> OTPValue | None:
|
|
prompt = prompt_engine.load_prompt("parse-otp-login", content=content)
|
|
resp = await app.SECONDARY_LLM_API_HANDLER(
|
|
prompt=prompt, prompt_name="parse-otp-login", organization_id=organization_id
|
|
)
|
|
LOG.info("OTP Login Parser Response", resp=resp)
|
|
otp_result = OTPResultParsedByLLM.model_validate(resp)
|
|
if otp_result.otp_value_found and otp_result.otp_value:
|
|
return OTPValue(value=otp_result.otp_value, type=otp_result.otp_type)
|
|
return None
|
|
|
|
|
|
async def poll_otp_value(
|
|
organization_id: str,
|
|
task_id: str | None = None,
|
|
workflow_id: str | None = None,
|
|
workflow_run_id: str | None = None,
|
|
workflow_permanent_id: str | None = None,
|
|
totp_verification_url: str | None = None,
|
|
totp_identifier: str | None = None,
|
|
) -> OTPValue | None:
|
|
timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
|
|
start_datetime = datetime.utcnow()
|
|
timeout_datetime = start_datetime + timeout
|
|
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
|
|
if not org_token:
|
|
LOG.error("Failed to get organization token when trying to get otp value")
|
|
return None
|
|
LOG.info(
|
|
"Polling otp value",
|
|
task_id=task_id,
|
|
workflow_run_id=workflow_run_id,
|
|
workflow_permanent_id=workflow_permanent_id,
|
|
totp_verification_url=totp_verification_url,
|
|
totp_identifier=totp_identifier,
|
|
)
|
|
while True:
|
|
await asyncio.sleep(10)
|
|
# check timeout
|
|
if datetime.utcnow() > timeout_datetime:
|
|
LOG.warning("Polling otp value timed out")
|
|
raise NoTOTPVerificationCodeFound(
|
|
task_id=task_id,
|
|
workflow_run_id=workflow_run_id,
|
|
workflow_id=workflow_permanent_id,
|
|
totp_verification_url=totp_verification_url,
|
|
totp_identifier=totp_identifier,
|
|
)
|
|
otp_value: OTPValue | None = None
|
|
if totp_verification_url:
|
|
otp_value = await _get_otp_value_from_url(
|
|
organization_id,
|
|
totp_verification_url,
|
|
org_token.token,
|
|
task_id=task_id,
|
|
workflow_run_id=workflow_run_id,
|
|
)
|
|
elif totp_identifier:
|
|
otp_value = await _get_otp_value_from_db(
|
|
organization_id,
|
|
totp_identifier,
|
|
task_id=task_id,
|
|
workflow_id=workflow_permanent_id,
|
|
workflow_run_id=workflow_run_id,
|
|
)
|
|
if otp_value:
|
|
LOG.info("Got otp value", otp_value=otp_value)
|
|
return otp_value
|
|
|
|
|
|
async def _get_otp_value_from_url(
|
|
organization_id: str,
|
|
url: str,
|
|
api_key: str,
|
|
task_id: str | None = None,
|
|
workflow_run_id: str | None = None,
|
|
workflow_permanent_id: str | None = None,
|
|
) -> OTPValue | None:
|
|
request_data = {}
|
|
if task_id:
|
|
request_data["task_id"] = task_id
|
|
if workflow_run_id:
|
|
request_data["workflow_run_id"] = workflow_run_id
|
|
if workflow_permanent_id:
|
|
request_data["workflow_permanent_id"] = workflow_permanent_id
|
|
signed_data = generate_skyvern_webhook_signature(
|
|
payload=request_data,
|
|
api_key=api_key,
|
|
)
|
|
try:
|
|
json_resp = await aiohttp_post(url=url, data=request_data, headers=signed_data.headers, raise_exception=False)
|
|
except Exception as e:
|
|
LOG.error("Failed to get otp value from url", exc_info=True)
|
|
raise FailedToGetTOTPVerificationCode(
|
|
task_id=task_id,
|
|
workflow_run_id=workflow_run_id,
|
|
workflow_id=workflow_permanent_id,
|
|
totp_verification_url=url,
|
|
reason=str(e),
|
|
)
|
|
if not json_resp:
|
|
return None
|
|
|
|
content = json_resp.get("verification_code", None)
|
|
if not content:
|
|
return None
|
|
|
|
otp_value: OTPValue | None = OTPValue(value=content, type=OTPType.TOTP)
|
|
if isinstance(content, str) and len(content) > 10:
|
|
try:
|
|
otp_value = await parse_otp_login(content, organization_id)
|
|
except Exception:
|
|
LOG.warning("faile to parse content by LLM call", exc_info=True)
|
|
|
|
if not otp_value:
|
|
LOG.warning(
|
|
"Failed to parse otp login from the totp url",
|
|
content=content,
|
|
)
|
|
return None
|
|
|
|
return otp_value
|
|
|
|
|
|
async def _get_otp_value_from_db(
|
|
organization_id: str,
|
|
totp_identifier: str,
|
|
task_id: str | None = None,
|
|
workflow_id: str | None = None,
|
|
workflow_run_id: str | None = None,
|
|
) -> OTPValue | None:
|
|
totp_codes = await app.DATABASE.get_otp_codes(organization_id=organization_id, totp_identifier=totp_identifier)
|
|
for totp_code in totp_codes:
|
|
if totp_code.workflow_run_id and workflow_run_id and totp_code.workflow_run_id != workflow_run_id:
|
|
continue
|
|
if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id:
|
|
continue
|
|
if totp_code.task_id and totp_code.task_id != task_id:
|
|
continue
|
|
if totp_code.expired_at and totp_code.expired_at < datetime.utcnow():
|
|
continue
|
|
return OTPValue(value=totp_code.code, type=totp_code.otp_type)
|
|
return None
|