TOTP code db + agent support for fetching totp_code from db (#784)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-09-08 15:07:03 -07:00
committed by GitHub
parent d878ee5a0d
commit b9f5e33876
14 changed files with 243 additions and 26 deletions

View File

@@ -22,6 +22,7 @@ from skyvern.forge.sdk.db.models import (
StepModel,
TaskGenerationModel,
TaskModel,
TOTPCodeModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
@@ -48,6 +49,7 @@ from skyvern.forge.sdk.db.utils import (
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.schemas.totp_codes import TOTPCode
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
BitwardenLoginCredentialParameter,
@@ -84,6 +86,7 @@ class AgentDB:
navigation_payload: dict[str, Any] | list | str | None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocation | None = None,
extracted_information_schema: dict[str, Any] | list | str | None = None,
@@ -101,6 +104,7 @@ class AgentDB:
title=title,
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
navigation_goal=navigation_goal,
data_extraction_goal=data_extraction_goal,
navigation_payload=navigation_payload,
@@ -819,6 +823,7 @@ class AgentDB:
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
persist_browser_session: bool = False,
workflow_permanent_id: str | None = None,
version: int | None = None,
@@ -833,6 +838,7 @@ class AgentDB:
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
persist_browser_session=persist_browser_session,
is_saved_task=is_saved_task,
)
@@ -1001,6 +1007,7 @@ class AgentDB:
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
) -> WorkflowRun:
try:
async with self.Session() as session:
@@ -1012,6 +1019,7 @@ class AgentDB:
status="created",
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
session.add(workflow_run)
await session.commit()
@@ -1439,3 +1447,27 @@ class AgentDB:
if not task_generation:
return None
return TaskGeneration.model_validate(task_generation)
async def get_totp_codes(
self,
organization_id: str,
totp_identifier: str,
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
) -> list[TOTPCode]:
"""
1. filter by:
- organization_id
- totp_identifier
2. make sure created_at is within the valid lifespan
3. sort by created_at desc
"""
async with self.Session() as session:
query = (
select(TOTPCodeModel)
.filter_by(organization_id=organization_id)
.filter_by(totp_identifier=totp_identifier)
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
.order_by(TOTPCodeModel.created_at.desc())
)
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]

View File

@@ -119,6 +119,11 @@ def generate_task_generation_id() -> str:
return f"{TASK_GENERATION_PREFIX}_{int_id}"
def generate_totp_code_id() -> str:
int_id = generate_id()
return f"totp_{int_id}"
def generate_id() -> int:
"""
generate a 64-bit int ID

View File

@@ -29,6 +29,7 @@ from skyvern.forge.sdk.db.id import (
generate_step_id,
generate_task_generation_id,
generate_task_id,
generate_totp_code_id,
generate_workflow_id,
generate_workflow_parameter_id,
generate_workflow_permanent_id,
@@ -49,6 +50,7 @@ class TaskModel(Base):
status = Column(String, index=True)
webhook_callback_url = Column(String)
totp_verification_url = Column(String)
totp_identifier = Column(String)
title = Column(String)
url = Column(String)
navigation_goal = Column(String)
@@ -180,6 +182,7 @@ class WorkflowModel(Base):
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
totp_verification_url = Column(String)
totp_identifier = Column(String)
persist_browser_session = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
@@ -207,6 +210,7 @@ class WorkflowRunModel(Base):
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
totp_verification_url = Column(String)
totp_identifier = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
@@ -392,3 +396,19 @@ class TaskGenerationModel(Base):
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class TOTPCodeModel(Base):
__tablename__ = "totp_codes"
totp_code_id = Column(String, primary_key=True, default=generate_totp_code_id)
totp_identifier = Column(String, nullable=False, index=True)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
task_id = Column(String, ForeignKey("tasks.task_id"))
workflow_id = Column(String, ForeignKey("workflows.workflow_id"))
content = Column(String, nullable=False)
code = Column(String, nullable=False)
source = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
expired_at = Column(DateTime, index=True)

View File

@@ -64,6 +64,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
url=task_obj.url,
webhook_callback_url=task_obj.webhook_callback_url,
totp_verification_url=task_obj.totp_verification_url,
totp_identifier=task_obj.totp_identifier,
navigation_goal=task_obj.navigation_goal,
data_extraction_goal=task_obj.data_extraction_goal,
navigation_payload=task_obj.navigation_payload,
@@ -162,6 +163,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
workflow_permanent_id=workflow_model.workflow_permanent_id,
webhook_callback_url=workflow_model.webhook_callback_url,
totp_verification_url=workflow_model.totp_verification_url,
totp_identifier=workflow_model.totp_identifier,
persist_browser_session=workflow_model.persist_browser_session,
proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None),
version=workflow_model.version,
@@ -192,6 +194,7 @@ def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled:
),
webhook_callback_url=workflow_run_model.webhook_callback_url,
totp_verification_url=workflow_run_model.totp_verification_url,
totp_identifier=workflow_run_model.totp_identifier,
created_at=workflow_run_model.created_at,
modified_at=workflow_run_model.modified_at,
)