From b9f5e33876f5097d0e90cea7d3a471a436101366 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Sun, 8 Sep 2024 15:07:03 -0700 Subject: [PATCH] TOTP code db + agent support for fetching totp_code from db (#784) Co-authored-by: Shuchang Zheng --- ...1_create_totp_codes_table_and_add_task_.py | 69 +++++++++++++++++++ skyvern/config.py | 3 + skyvern/forge/agent.py | 15 ++-- skyvern/forge/sdk/db/client.py | 32 +++++++++ skyvern/forge/sdk/db/id.py | 5 ++ skyvern/forge/sdk/db/models.py | 20 ++++++ skyvern/forge/sdk/db/utils.py | 3 + skyvern/forge/sdk/schemas/tasks.py | 1 + skyvern/forge/sdk/schemas/totp_codes.py | 29 ++++++++ skyvern/forge/sdk/workflow/models/block.py | 2 + skyvern/forge/sdk/workflow/models/workflow.py | 4 ++ skyvern/forge/sdk/workflow/models/yaml.py | 7 +- skyvern/forge/sdk/workflow/service.py | 11 ++- skyvern/webeye/actions/handler.py | 68 +++++++++++++----- 14 files changed, 243 insertions(+), 26 deletions(-) create mode 100644 alembic/versions/2024_09_08_2159-c5848cc524b1_create_totp_codes_table_and_add_task_.py create mode 100644 skyvern/forge/sdk/schemas/totp_codes.py diff --git a/alembic/versions/2024_09_08_2159-c5848cc524b1_create_totp_codes_table_and_add_task_.py b/alembic/versions/2024_09_08_2159-c5848cc524b1_create_totp_codes_table_and_add_task_.py new file mode 100644 index 00000000..fcd218dc --- /dev/null +++ b/alembic/versions/2024_09_08_2159-c5848cc524b1_create_totp_codes_table_and_add_task_.py @@ -0,0 +1,69 @@ +"""create totp_codes table and add task.totp_identifier + +Revision ID: c5848cc524b1 +Revises: c50f0aa0ef24 +Create Date: 2024-09-08 21:59:56.666276+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c5848cc524b1" +down_revision: Union[str, None] = "c50f0aa0ef24" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "totp_codes", + sa.Column("totp_code_id", sa.String(), nullable=False), + sa.Column("totp_identifier", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=True), + sa.Column("task_id", sa.String(), nullable=True), + sa.Column("workflow_id", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=False), + sa.Column("code", sa.String(), nullable=False), + sa.Column("source", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("modified_at", sa.DateTime(), nullable=False), + sa.Column("expired_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.organization_id"], + ), + sa.ForeignKeyConstraint( + ["task_id"], + ["tasks.task_id"], + ), + sa.ForeignKeyConstraint( + ["workflow_id"], + ["workflows.workflow_id"], + ), + sa.PrimaryKeyConstraint("totp_code_id"), + ) + op.create_index(op.f("ix_totp_codes_created_at"), "totp_codes", ["created_at"], unique=False) + op.create_index(op.f("ix_totp_codes_expired_at"), "totp_codes", ["expired_at"], unique=False) + op.create_index(op.f("ix_totp_codes_totp_identifier"), "totp_codes", ["totp_identifier"], unique=False) + op.add_column("tasks", sa.Column("totp_identifier", sa.String(), nullable=True)) + op.add_column("workflow_runs", sa.Column("totp_identifier", sa.String(), nullable=True)) + op.add_column("workflows", sa.Column("totp_identifier", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("workflows", "totp_identifier") + op.drop_column("workflow_runs", "totp_identifier") + op.drop_column("tasks", "totp_identifier") + op.drop_index(op.f("ix_totp_codes_totp_identifier"), table_name="totp_codes") + op.drop_index(op.f("ix_totp_codes_expired_at"), table_name="totp_codes") + op.drop_index(op.f("ix_totp_codes_created_at"), table_name="totp_codes") + op.drop_table("totp_codes") + # ### end Alembic commands ### diff --git a/skyvern/config.py b/skyvern/config.py index d5049699..c6bf36bc 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -117,6 +117,9 @@ class Settings(BaseSettings): AZURE_GPT4O_MINI_API_BASE: str | None = None AZURE_GPT4O_MINI_API_VERSION: str | None = None + # TOTP Settings + TOTP_LIFESPAN_MINUTES: int = 10 + def is_cloud_environment(self) -> bool: """ :return: True if env is not local, else False diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index a1e063d4..5fd5cfba 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -122,7 +122,8 @@ class ForgeAgent: url=task_url, title=task_block.title, webhook_callback_url=None, - totp_verification_url=None, + totp_verification_url=task_block.totp_verification_url, + totp_identifier=task_block.totp_identifier, navigation_goal=task_block.navigation_goal, data_extraction_goal=task_block.data_extraction_goal, navigation_payload=navigation_payload, @@ -178,6 +179,7 @@ class ForgeAgent: title=task_request.title, webhook_callback_url=task_request.webhook_callback_url, totp_verification_url=task_request.totp_verification_url, + totp_identifier=task_request.totp_identifier, navigation_goal=task_request.navigation_goal, data_extraction_goal=task_request.data_extraction_goal, navigation_payload=task_request.navigation_payload, @@ -983,7 +985,7 @@ class ForgeAgent: task, browser_state, element_tree_in_prompt, - verification_code_check=bool(task.totp_verification_url), + verification_code_check=bool(task.totp_verification_url or task.totp_identifier), expire_verification_code=True, ) @@ -1055,7 +1057,7 @@ class ForgeAgent: final_navigation_payload = task.navigation_payload current_context = skyvern_context.ensure_context() verification_code = current_context.totp_codes.get(task.task_id) - if task.totp_verification_url and verification_code: + if (task.totp_verification_url or task.totp_identifier) and verification_code: if ( isinstance(final_navigation_payload, dict) and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload @@ -1598,10 +1600,13 @@ class ForgeAgent: json_response: dict[str, Any], ) -> dict[str, Any]: need_verification_code = json_response.get("need_verification_code") - if need_verification_code and task.totp_verification_url and task.organization_id: + if need_verification_code and (task.totp_verification_url or task.totp_identifier) and task.organization_id: LOG.info("Need verification code", step_id=step.step_id) verification_code = await poll_verification_code( - task.task_id, task.organization_id, url=task.totp_verification_url + task.task_id, + task.organization_id, + totp_verification_url=task.totp_verification_url, + totp_identifier=task.totp_identifier, ) current_context = skyvern_context.ensure_context() current_context.totp_codes[task.task_id] = verification_code diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 8c4c12ef..9656beb9 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -22,6 +22,7 @@ from skyvern.forge.sdk.db.models import ( StepModel, TaskGenerationModel, TaskModel, + TOTPCodeModel, WorkflowModel, WorkflowParameterModel, WorkflowRunModel, @@ -48,6 +49,7 @@ from skyvern.forge.sdk.db.utils import ( from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus +from skyvern.forge.sdk.schemas.totp_codes import TOTPCode from skyvern.forge.sdk.workflow.models.parameter import ( AWSSecretParameter, BitwardenLoginCredentialParameter, @@ -84,6 +86,7 @@ class AgentDB: navigation_payload: dict[str, Any] | list | str | None, webhook_callback_url: str | None = None, totp_verification_url: str | None = None, + totp_identifier: str | None = None, organization_id: str | None = None, proxy_location: ProxyLocation | None = None, extracted_information_schema: dict[str, Any] | list | str | None = None, @@ -101,6 +104,7 @@ class AgentDB: title=title, webhook_callback_url=webhook_callback_url, totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, navigation_goal=navigation_goal, data_extraction_goal=data_extraction_goal, navigation_payload=navigation_payload, @@ -819,6 +823,7 @@ class AgentDB: proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None, totp_verification_url: str | None = None, + totp_identifier: str | None = None, persist_browser_session: bool = False, workflow_permanent_id: str | None = None, version: int | None = None, @@ -833,6 +838,7 @@ class AgentDB: proxy_location=proxy_location, webhook_callback_url=webhook_callback_url, totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, persist_browser_session=persist_browser_session, is_saved_task=is_saved_task, ) @@ -1001,6 +1007,7 @@ class AgentDB: proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None, totp_verification_url: str | None = None, + totp_identifier: str | None = None, ) -> WorkflowRun: try: async with self.Session() as session: @@ -1012,6 +1019,7 @@ class AgentDB: status="created", webhook_callback_url=webhook_callback_url, totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, ) session.add(workflow_run) await session.commit() @@ -1439,3 +1447,27 @@ class AgentDB: if not task_generation: return None return TaskGeneration.model_validate(task_generation) + + async def get_totp_codes( + self, + organization_id: str, + totp_identifier: str, + valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES, + ) -> list[TOTPCode]: + """ + 1. filter by: + - organization_id + - totp_identifier + 2. make sure created_at is within the valid lifespan + 3. sort by created_at desc + """ + async with self.Session() as session: + query = ( + select(TOTPCodeModel) + .filter_by(organization_id=organization_id) + .filter_by(totp_identifier=totp_identifier) + .filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes)) + .order_by(TOTPCodeModel.created_at.desc()) + ) + totp_code = (await session.scalars(query)).all() + return [TOTPCode.model_validate(totp_code) for totp_code in totp_code] diff --git a/skyvern/forge/sdk/db/id.py b/skyvern/forge/sdk/db/id.py index d31ad69c..ade5bc21 100644 --- a/skyvern/forge/sdk/db/id.py +++ b/skyvern/forge/sdk/db/id.py @@ -119,6 +119,11 @@ def generate_task_generation_id() -> str: return f"{TASK_GENERATION_PREFIX}_{int_id}" +def generate_totp_code_id() -> str: + int_id = generate_id() + return f"totp_{int_id}" + + def generate_id() -> int: """ generate a 64-bit int ID diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index a79ee48d..04c4d771 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -29,6 +29,7 @@ from skyvern.forge.sdk.db.id import ( generate_step_id, generate_task_generation_id, generate_task_id, + generate_totp_code_id, generate_workflow_id, generate_workflow_parameter_id, generate_workflow_permanent_id, @@ -49,6 +50,7 @@ class TaskModel(Base): status = Column(String, index=True) webhook_callback_url = Column(String) totp_verification_url = Column(String) + totp_identifier = Column(String) title = Column(String) url = Column(String) navigation_goal = Column(String) @@ -180,6 +182,7 @@ class WorkflowModel(Base): proxy_location = Column(Enum(ProxyLocation)) webhook_callback_url = Column(String) totp_verification_url = Column(String) + totp_identifier = Column(String) persist_browser_session = Column(Boolean, default=False, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) @@ -207,6 +210,7 @@ class WorkflowRunModel(Base): proxy_location = Column(Enum(ProxyLocation)) webhook_callback_url = Column(String) totp_verification_url = Column(String) + totp_identifier = Column(String) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column( @@ -392,3 +396,19 @@ class TaskGenerationModel(Base): created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + + +class TOTPCodeModel(Base): + __tablename__ = "totp_codes" + + totp_code_id = Column(String, primary_key=True, default=generate_totp_code_id) + totp_identifier = Column(String, nullable=False, index=True) + organization_id = Column(String, ForeignKey("organizations.organization_id")) + task_id = Column(String, ForeignKey("tasks.task_id")) + workflow_id = Column(String, ForeignKey("workflows.workflow_id")) + content = Column(String, nullable=False) + code = Column(String, nullable=False) + source = Column(String) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True) + modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + expired_at = Column(DateTime, index=True) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index bf841f81..ce662df2 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -64,6 +64,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: url=task_obj.url, webhook_callback_url=task_obj.webhook_callback_url, totp_verification_url=task_obj.totp_verification_url, + totp_identifier=task_obj.totp_identifier, navigation_goal=task_obj.navigation_goal, data_extraction_goal=task_obj.data_extraction_goal, navigation_payload=task_obj.navigation_payload, @@ -162,6 +163,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal workflow_permanent_id=workflow_model.workflow_permanent_id, webhook_callback_url=workflow_model.webhook_callback_url, totp_verification_url=workflow_model.totp_verification_url, + totp_identifier=workflow_model.totp_identifier, persist_browser_session=workflow_model.persist_browser_session, proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None), version=workflow_model.version, @@ -192,6 +194,7 @@ def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: ), webhook_callback_url=workflow_run_model.webhook_callback_url, totp_verification_url=workflow_run_model.totp_verification_url, + totp_identifier=workflow_run_model.totp_identifier, created_at=workflow_run_model.created_at, modified_at=workflow_run_model.modified_at, ) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index fa9afb4d..a0d74b27 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -39,6 +39,7 @@ class TaskRequest(BaseModel): examples=["https://my-webhook.com"], ) totp_verification_url: str | None = None + totp_identifier: str | None = None navigation_goal: str | None = Field( default=None, description="The user's goal for the task.", diff --git a/skyvern/forge/sdk/schemas/totp_codes.py b/skyvern/forge/sdk/schemas/totp_codes.py new file mode 100644 index 00000000..3ca1dd36 --- /dev/null +++ b/skyvern/forge/sdk/schemas/totp_codes.py @@ -0,0 +1,29 @@ +from datetime import datetime + +from pydantic import BaseModel, ConfigDict + + +class TOTPCodeBase(BaseModel): + model_config = ConfigDict(from_attributes=True) + + totp_identifier: str | None = None + organization_id: str | None = None + task_id: str | None = None + workflow_id: str | None = None + source: str | None = None + content: str | None = None + + expired_at: datetime | None = None + + +class TOTPCodeCreate(TOTPCodeBase): + totp_identifier: str + organization_id: str + content: str + + +class TOTPCode(TOTPCodeCreate): + totp_code_id: str + code: str + created_at: datetime + modified_at: datetime diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index d46ed680..d31b9697 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -176,6 +176,8 @@ class TaskBlock(Block): max_steps_per_run: int | None = None parameters: list[PARAMETER_TYPE] = [] complete_on_download: bool = False + totp_verification_url: str | None = None + totp_identifier: str | None = None def get_all_parameters( self, diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index a30c0359..06c77970 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -15,6 +15,7 @@ class WorkflowRequestBody(BaseModel): proxy_location: ProxyLocation | None = None webhook_callback_url: str | None = None totp_verification_url: str | None = None + totp_identifier: str | None = None class RunWorkflowResponse(BaseModel): @@ -51,6 +52,7 @@ class Workflow(BaseModel): proxy_location: ProxyLocation | None = None webhook_callback_url: str | None = None totp_verification_url: str | None = None + totp_identifier: str | None = None persist_browser_session: bool = False created_at: datetime @@ -75,6 +77,7 @@ class WorkflowRun(BaseModel): proxy_location: ProxyLocation | None = None webhook_callback_url: str | None = None totp_verification_url: str | None = None + totp_identifier: str | None = None created_at: datetime modified_at: datetime @@ -101,6 +104,7 @@ class WorkflowRunStatusResponse(BaseModel): proxy_location: ProxyLocation | None = None webhook_callback_url: str | None = None totp_verification_url: str | None = None + totp_identifier: str | None = None created_at: datetime modified_at: datetime parameters: dict[str, Any] diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index 31500168..b98f198a 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -46,9 +46,7 @@ class BitwardenSensitiveInformationParameterYAML(ParameterYAML): # Parameter 1 of Literal[...] cannot be of type "Any" # This pattern already works in block.py but since the ParameterType is not defined in this file, mypy is not able # to infer the type of the parameter_type attribute. - parameter_type: Literal[ParameterType.BITWARDEN_SENSITIVE_INFORMATION] = ( - ParameterType.BITWARDEN_SENSITIVE_INFORMATION - ) # type: ignore + parameter_type: Literal["bitwarden_sensitive_information"] = ParameterType.BITWARDEN_SENSITIVE_INFORMATION # type: ignore # bitwarden cli required fields bitwarden_client_id_aws_secret_key: str @@ -113,6 +111,8 @@ class TaskBlockYAML(BlockYAML): max_steps_per_run: int | None = None parameter_keys: list[str] | None = None complete_on_download: bool = False + totp_verification_url: str | None = None + totp_identifier: str | None = None class ForLoopBlockYAML(BlockYAML): @@ -225,6 +225,7 @@ class WorkflowCreateYAMLRequest(BaseModel): proxy_location: ProxyLocation | None = None webhook_callback_url: str | None = None totp_verification_url: str | None = None + totp_identifier: str | None = None persist_browser_session: bool = False workflow_definition: WorkflowDefinitionYAML is_saved_task: bool = False diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 79750890..cf0d1405 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -286,6 +286,7 @@ class WorkflowService: proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None, totp_verification_url: str | None = None, + totp_identifier: str | None = None, persist_browser_session: bool = False, workflow_permanent_id: str | None = None, version: int | None = None, @@ -299,6 +300,7 @@ class WorkflowService: proxy_location=proxy_location, webhook_callback_url=webhook_callback_url, totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, persist_browser_session=persist_browser_session, workflow_permanent_id=workflow_permanent_id, version=version, @@ -397,6 +399,7 @@ class WorkflowService: proxy_location=workflow_request.proxy_location, webhook_callback_url=workflow_request.webhook_callback_url, totp_verification_url=workflow_request.totp_verification_url, + totp_identifier=workflow_request.totp_identifier, ) async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None: @@ -640,6 +643,7 @@ class WorkflowService: proxy_location=workflow_run.proxy_location, webhook_callback_url=workflow_run.webhook_callback_url, totp_verification_url=workflow_run.totp_verification_url, + totp_identifier=workflow_run.totp_identifier, created_at=workflow_run.created_at, modified_at=workflow_run.modified_at, parameters=parameters_with_value, @@ -835,6 +839,7 @@ class WorkflowService: proxy_location=request.proxy_location, webhook_callback_url=request.webhook_callback_url, totp_verification_url=request.totp_verification_url, + totp_identifier=request.totp_identifier, persist_browser_session=request.persist_browser_session, workflow_permanent_id=workflow_permanent_id, version=existing_version + 1, @@ -849,6 +854,7 @@ class WorkflowService: proxy_location=request.proxy_location, webhook_callback_url=request.webhook_callback_url, totp_verification_url=request.totp_verification_url, + totp_identifier=request.totp_identifier, persist_browser_session=request.persist_browser_session, is_saved_task=request.is_saved_task, ) @@ -912,7 +918,8 @@ class WorkflowService: bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key, bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key, bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key, - bitwarden_collection_id=parameter.bitwarden_collection_id, + # TODO: remove "# type: ignore" after ensuring bitwarden_collection_id is always set + bitwarden_collection_id=parameter.bitwarden_collection_id, # type: ignore bitwarden_identity_key=parameter.bitwarden_identity_key, bitwarden_identity_fields=parameter.bitwarden_identity_fields, key=parameter.key, @@ -1046,6 +1053,8 @@ class WorkflowService: max_retries=block_yaml.max_retries, complete_on_download=block_yaml.complete_on_download, continue_on_failure=block_yaml.continue_on_failure, + totp_verification_url=block_yaml.totp_verification_url, + totp_identifier=block_yaml.totp_identifier, ) elif block_yaml.block_type == BlockType.FOR_LOOP: loop_blocks = [ diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index 18668579..1c538625 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -1931,7 +1931,13 @@ async def get_input_value(tag_name: str, locator: Locator) -> str | None: return await locator.inner_text() -async def poll_verification_code(task_id: str, organization_id: str, url: str) -> str | None: +async def poll_verification_code( + task_id: str, + organization_id: str, + workflow_id: str | None = None, + totp_verification_url: str | None = None, + totp_identifier: str | None = None, +) -> str | None: timeout = timedelta(minutes=VERIFICATION_CODE_POLLING_TIMEOUT_MINS) start_datetime = datetime.utcnow() timeout_datetime = start_datetime + timeout @@ -1943,24 +1949,52 @@ async def poll_verification_code(task_id: str, organization_id: str, url: str) - # check timeout if datetime.utcnow() > timeout_datetime: return None - request_data = { - "task_id": task_id, - } - payload = json.dumps(request_data) - signature = generate_skyvern_signature( - payload=payload, - api_key=org_token.token, - ) - timestamp = str(int(datetime.utcnow().timestamp())) - headers = { - "x-skyvern-timestamp": timestamp, - "x-skyvern-signature": signature, - "Content-Type": "application/json", - } - json_resp = await aiohttp_post(url=url, data=request_data, headers=headers, raise_exception=False) - verification_code = json_resp.get("verification_code", None) + verification_code = None + if totp_verification_url: + verification_code = await _get_verification_code_from_url(task_id, totp_verification_url, org_token.token) + elif totp_identifier: + verification_code = await _get_verification_code_from_db( + task_id, organization_id, totp_identifier, workflow_id=workflow_id + ) if verification_code: LOG.info("Got verification code", verification_code=verification_code) return verification_code await asyncio.sleep(10) + + +async def _get_verification_code_from_url(task_id: str, url: str, api_key: str) -> str | None: + request_data = { + "task_id": task_id, + } + payload = json.dumps(request_data) + signature = generate_skyvern_signature( + payload=payload, + api_key=api_key, + ) + timestamp = str(int(datetime.utcnow().timestamp())) + headers = { + "x-skyvern-timestamp": timestamp, + "x-skyvern-signature": signature, + "Content-Type": "application/json", + } + json_resp = await aiohttp_post(url=url, data=request_data, headers=headers, raise_exception=False) + return json_resp.get("verification_code", None) + + +async def _get_verification_code_from_db( + task_id: str, + organization_id: str, + totp_identifier: str, + workflow_id: str | None = None, +) -> str | None: + totp_codes = await app.DATABASE.get_totp_codes(organization_id=organization_id, totp_identifier=totp_identifier) + for totp_code in totp_codes: + if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id: + continue + if totp_code.task_id and totp_code.task_id != task_id: + continue + if totp_code.expired_at and totp_code.expired_at < datetime.utcnow(): + continue + return totp_code.code + return None