store totp_identifier to credentials with fallback for login runs (#4154)

This commit is contained in:
Marc Kelechava
2025-12-01 16:19:37 -08:00
committed by GitHub
parent acce1c869d
commit 7100b7e004
6 changed files with 49 additions and 16 deletions

View File

@@ -4515,6 +4515,7 @@ class AgentDB:
totp_type: str,
card_last4: str | None,
card_brand: str | None,
totp_identifier: str | None = None,
) -> Credential:
async with self.Session() as session:
credential = CredentialModel(
@@ -4525,6 +4526,7 @@ class AgentDB:
credential_type=credential_type,
username=username,
totp_type=totp_type,
totp_identifier=totp_identifier,
card_last4=card_last4,
card_brand=card_brand,
)

View File

@@ -94,12 +94,15 @@ async def login(
label = "login"
yaml_parameters = []
parameter_key = "credential"
resolved_totp_identifier = login_request.totp_identifier
if login_request.credential_type == CredentialType.skyvern:
if not login_request.credential_id:
raise HTTPException(status_code=400, detail="credential_id is required to login with Skyvern credential")
credential = await app.DATABASE.get_credential(login_request.credential_id, organization.organization_id)
if not credential:
raise HTTPException(status_code=404, detail=f"Credential {login_request.credential_id} not found")
if not resolved_totp_identifier:
resolved_totp_identifier = credential.totp_identifier
yaml_parameters = [
WorkflowParameterYAML(
@@ -169,7 +172,7 @@ async def login(
max_steps_per_run=10,
parameter_keys=[parameter_key],
totp_verification_url=totp_verification_url,
totp_identifier=login_request.totp_identifier,
totp_identifier=resolved_totp_identifier,
)
yaml_blocks = [login_block_yaml]
workflow_definition_yaml = WorkflowDefinitionYAML(
@@ -198,7 +201,7 @@ async def login(
legacy_workflow_request = WorkflowRequestBody(
proxy_location=login_request.proxy_location,
webhook_callback_url=webhook_url,
totp_identifier=login_request.totp_identifier,
totp_identifier=resolved_totp_identifier,
totp_verification_url=totp_verification_url,
browser_session_id=login_request.browser_session_id,
browser_profile_id=login_request.browser_profile_id,
@@ -235,7 +238,7 @@ async def login(
proxy_location=login_request.proxy_location,
webhook_url=webhook_url,
totp_url=totp_verification_url,
totp_identifier=login_request.totp_identifier,
totp_identifier=resolved_totp_identifier,
browser_session_id=login_request.browser_session_id,
browser_profile_id=login_request.browser_profile_id,
max_screenshot_scrolls=login_request.max_screenshot_scrolling_times,

View File

@@ -34,6 +34,11 @@ class PasswordCredentialResponse(BaseModel):
description="Type of 2FA method used for this credential",
examples=[TotpType.AUTHENTICATOR],
)
totp_identifier: str | None = Field(
default=None,
description="Identifier (email or phone number) used to fetch TOTP codes",
examples=["user@example.com", "+14155550123"],
)
class CreditCardCredentialResponse(BaseModel):
@@ -58,6 +63,11 @@ class PasswordCredential(BaseModel):
description="Type of 2FA method used for this credential",
examples=[TotpType.AUTHENTICATOR],
)
totp_identifier: str | None = Field(
default=None,
description="Identifier (email or phone number) used to fetch TOTP codes",
examples=["user@example.com", "+14155550123"],
)
class NonEmptyPasswordCredential(PasswordCredential):
@@ -155,6 +165,11 @@ class Credential(BaseModel):
description="Type of 2FA method used for this credential",
examples=[TotpType.AUTHENTICATOR],
)
totp_identifier: str | None = Field(
default=None,
description="Identifier (email or phone number) used to fetch TOTP codes",
examples=["user@example.com", "+14155550123"],
)
card_last4: str | None = Field(..., description="For credit_card credentials: the last four digits of the card")
card_brand: str | None = Field(..., description="For credit_card credentials: the card brand")

View File

@@ -51,6 +51,7 @@ class CredentialVaultService(ABC):
credential_type=data.credential_type,
username=data.credential.username,
totp_type=data.credential.totp_type,
totp_identifier=data.credential.totp_identifier,
card_last4=None,
card_brand=None,
)
@@ -65,6 +66,7 @@ class CredentialVaultService(ABC):
totp_type="none",
card_last4=data.credential.card_number[-4:],
card_brand=data.credential.card_brand,
totp_identifier=None,
)
else:
raise Exception(f"Unsupported credential type: {data.credential_type}")

View File

@@ -173,6 +173,7 @@ class WorkflowRunContext:
self._aws_client = aws_client
self.organization_id: str | None = None
self.include_secrets_in_templates: bool = False
self.credential_totp_identifiers: dict[str, str] = {}
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
@@ -295,6 +296,10 @@ class WorkflowRunContext:
credential_item = await credential_service.get_credential_item(db_credential)
credential = credential_item.credential
credential_totp_identifier = getattr(credential, "totp_identifier", None)
if credential_totp_identifier:
self.credential_totp_identifiers[parameter.key] = credential_totp_identifier
self.parameters[parameter.key] = parameter
self.values[parameter.key] = {
"context": "These values are placeholders. When you type this in, the real value gets inserted (For security reasons)",
@@ -319,6 +324,9 @@ class WorkflowRunContext:
self.secrets[totp_secret_value] = parse_totp_secret(credential.totp)
self.values[parameter.key]["totp"] = totp_secret_id
def get_credential_totp_identifier(self, parameter_key: str) -> str | None:
return self.credential_totp_identifiers.get(parameter_key)
async def register_secret_workflow_parameter_value(
self,
parameter: WorkflowParameter,

View File

@@ -588,19 +588,22 @@ class BaseTaskBlock(Block):
)
self.url = task_url_parameter_value
if (
self.totp_identifier
and workflow_run_context.has_parameter(self.totp_identifier)
and workflow_run_context.has_value(self.totp_identifier)
):
totp_identifier_parameter_value = workflow_run_context.get_value(self.totp_identifier)
if totp_identifier_parameter_value:
LOG.info(
"TOTP identifier is parameterized, using parameter value",
totp_identifier_parameter_value=totp_identifier_parameter_value,
totp_identifier_parameter_key=self.totp_identifier,
)
self.totp_identifier = totp_identifier_parameter_value
if self.totp_identifier:
if workflow_run_context.has_parameter(self.totp_identifier) and workflow_run_context.has_value(
self.totp_identifier
):
totp_identifier_parameter_value = workflow_run_context.get_value(self.totp_identifier)
if totp_identifier_parameter_value:
self.totp_identifier = totp_identifier_parameter_value
else:
for parameter in self.get_all_parameters(workflow_run_id):
parameter_key = getattr(parameter, "key", None)
if not parameter_key:
continue
credential_totp_identifier = workflow_run_context.get_credential_totp_identifier(parameter_key)
if credential_totp_identifier:
self.totp_identifier = credential_totp_identifier
break
if self.download_suffix and workflow_run_context.has_parameter(self.download_suffix):
download_suffix_parameter_value = workflow_run_context.get_value(self.download_suffix)