store totp_identifier to credentials with fallback for login runs (#4154)
This commit is contained in:
@@ -4515,6 +4515,7 @@ class AgentDB:
|
||||
totp_type: str,
|
||||
card_last4: str | None,
|
||||
card_brand: str | None,
|
||||
totp_identifier: str | None = None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = CredentialModel(
|
||||
@@ -4525,6 +4526,7 @@ class AgentDB:
|
||||
credential_type=credential_type,
|
||||
username=username,
|
||||
totp_type=totp_type,
|
||||
totp_identifier=totp_identifier,
|
||||
card_last4=card_last4,
|
||||
card_brand=card_brand,
|
||||
)
|
||||
|
||||
@@ -94,12 +94,15 @@ async def login(
|
||||
label = "login"
|
||||
yaml_parameters = []
|
||||
parameter_key = "credential"
|
||||
resolved_totp_identifier = login_request.totp_identifier
|
||||
if login_request.credential_type == CredentialType.skyvern:
|
||||
if not login_request.credential_id:
|
||||
raise HTTPException(status_code=400, detail="credential_id is required to login with Skyvern credential")
|
||||
credential = await app.DATABASE.get_credential(login_request.credential_id, organization.organization_id)
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail=f"Credential {login_request.credential_id} not found")
|
||||
if not resolved_totp_identifier:
|
||||
resolved_totp_identifier = credential.totp_identifier
|
||||
|
||||
yaml_parameters = [
|
||||
WorkflowParameterYAML(
|
||||
@@ -169,7 +172,7 @@ async def login(
|
||||
max_steps_per_run=10,
|
||||
parameter_keys=[parameter_key],
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=login_request.totp_identifier,
|
||||
totp_identifier=resolved_totp_identifier,
|
||||
)
|
||||
yaml_blocks = [login_block_yaml]
|
||||
workflow_definition_yaml = WorkflowDefinitionYAML(
|
||||
@@ -198,7 +201,7 @@ async def login(
|
||||
legacy_workflow_request = WorkflowRequestBody(
|
||||
proxy_location=login_request.proxy_location,
|
||||
webhook_callback_url=webhook_url,
|
||||
totp_identifier=login_request.totp_identifier,
|
||||
totp_identifier=resolved_totp_identifier,
|
||||
totp_verification_url=totp_verification_url,
|
||||
browser_session_id=login_request.browser_session_id,
|
||||
browser_profile_id=login_request.browser_profile_id,
|
||||
@@ -235,7 +238,7 @@ async def login(
|
||||
proxy_location=login_request.proxy_location,
|
||||
webhook_url=webhook_url,
|
||||
totp_url=totp_verification_url,
|
||||
totp_identifier=login_request.totp_identifier,
|
||||
totp_identifier=resolved_totp_identifier,
|
||||
browser_session_id=login_request.browser_session_id,
|
||||
browser_profile_id=login_request.browser_profile_id,
|
||||
max_screenshot_scrolls=login_request.max_screenshot_scrolling_times,
|
||||
|
||||
@@ -34,6 +34,11 @@ class PasswordCredentialResponse(BaseModel):
|
||||
description="Type of 2FA method used for this credential",
|
||||
examples=[TotpType.AUTHENTICATOR],
|
||||
)
|
||||
totp_identifier: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier (email or phone number) used to fetch TOTP codes",
|
||||
examples=["user@example.com", "+14155550123"],
|
||||
)
|
||||
|
||||
|
||||
class CreditCardCredentialResponse(BaseModel):
|
||||
@@ -58,6 +63,11 @@ class PasswordCredential(BaseModel):
|
||||
description="Type of 2FA method used for this credential",
|
||||
examples=[TotpType.AUTHENTICATOR],
|
||||
)
|
||||
totp_identifier: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier (email or phone number) used to fetch TOTP codes",
|
||||
examples=["user@example.com", "+14155550123"],
|
||||
)
|
||||
|
||||
|
||||
class NonEmptyPasswordCredential(PasswordCredential):
|
||||
@@ -155,6 +165,11 @@ class Credential(BaseModel):
|
||||
description="Type of 2FA method used for this credential",
|
||||
examples=[TotpType.AUTHENTICATOR],
|
||||
)
|
||||
totp_identifier: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier (email or phone number) used to fetch TOTP codes",
|
||||
examples=["user@example.com", "+14155550123"],
|
||||
)
|
||||
card_last4: str | None = Field(..., description="For credit_card credentials: the last four digits of the card")
|
||||
card_brand: str | None = Field(..., description="For credit_card credentials: the card brand")
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ class CredentialVaultService(ABC):
|
||||
credential_type=data.credential_type,
|
||||
username=data.credential.username,
|
||||
totp_type=data.credential.totp_type,
|
||||
totp_identifier=data.credential.totp_identifier,
|
||||
card_last4=None,
|
||||
card_brand=None,
|
||||
)
|
||||
@@ -65,6 +66,7 @@ class CredentialVaultService(ABC):
|
||||
totp_type="none",
|
||||
card_last4=data.credential.card_number[-4:],
|
||||
card_brand=data.credential.card_brand,
|
||||
totp_identifier=None,
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unsupported credential type: {data.credential_type}")
|
||||
|
||||
@@ -173,6 +173,7 @@ class WorkflowRunContext:
|
||||
self._aws_client = aws_client
|
||||
self.organization_id: str | None = None
|
||||
self.include_secrets_in_templates: bool = False
|
||||
self.credential_totp_identifiers: dict[str, str] = {}
|
||||
|
||||
def get_parameter(self, key: str) -> Parameter:
|
||||
return self.parameters[key]
|
||||
@@ -295,6 +296,10 @@ class WorkflowRunContext:
|
||||
credential_item = await credential_service.get_credential_item(db_credential)
|
||||
credential = credential_item.credential
|
||||
|
||||
credential_totp_identifier = getattr(credential, "totp_identifier", None)
|
||||
if credential_totp_identifier:
|
||||
self.credential_totp_identifiers[parameter.key] = credential_totp_identifier
|
||||
|
||||
self.parameters[parameter.key] = parameter
|
||||
self.values[parameter.key] = {
|
||||
"context": "These values are placeholders. When you type this in, the real value gets inserted (For security reasons)",
|
||||
@@ -319,6 +324,9 @@ class WorkflowRunContext:
|
||||
self.secrets[totp_secret_value] = parse_totp_secret(credential.totp)
|
||||
self.values[parameter.key]["totp"] = totp_secret_id
|
||||
|
||||
def get_credential_totp_identifier(self, parameter_key: str) -> str | None:
|
||||
return self.credential_totp_identifiers.get(parameter_key)
|
||||
|
||||
async def register_secret_workflow_parameter_value(
|
||||
self,
|
||||
parameter: WorkflowParameter,
|
||||
|
||||
@@ -588,19 +588,22 @@ class BaseTaskBlock(Block):
|
||||
)
|
||||
self.url = task_url_parameter_value
|
||||
|
||||
if (
|
||||
self.totp_identifier
|
||||
and workflow_run_context.has_parameter(self.totp_identifier)
|
||||
and workflow_run_context.has_value(self.totp_identifier)
|
||||
):
|
||||
totp_identifier_parameter_value = workflow_run_context.get_value(self.totp_identifier)
|
||||
if totp_identifier_parameter_value:
|
||||
LOG.info(
|
||||
"TOTP identifier is parameterized, using parameter value",
|
||||
totp_identifier_parameter_value=totp_identifier_parameter_value,
|
||||
totp_identifier_parameter_key=self.totp_identifier,
|
||||
)
|
||||
self.totp_identifier = totp_identifier_parameter_value
|
||||
if self.totp_identifier:
|
||||
if workflow_run_context.has_parameter(self.totp_identifier) and workflow_run_context.has_value(
|
||||
self.totp_identifier
|
||||
):
|
||||
totp_identifier_parameter_value = workflow_run_context.get_value(self.totp_identifier)
|
||||
if totp_identifier_parameter_value:
|
||||
self.totp_identifier = totp_identifier_parameter_value
|
||||
else:
|
||||
for parameter in self.get_all_parameters(workflow_run_id):
|
||||
parameter_key = getattr(parameter, "key", None)
|
||||
if not parameter_key:
|
||||
continue
|
||||
credential_totp_identifier = workflow_run_context.get_credential_totp_identifier(parameter_key)
|
||||
if credential_totp_identifier:
|
||||
self.totp_identifier = credential_totp_identifier
|
||||
break
|
||||
|
||||
if self.download_suffix and workflow_run_context.has_parameter(self.download_suffix):
|
||||
download_suffix_parameter_value = workflow_run_context.get_value(self.download_suffix)
|
||||
|
||||
Reference in New Issue
Block a user