diff --git a/alembic/versions/2024_04_23_2153-24303f1669a7_add_domain_to_organizations_table.py b/alembic/versions/2024_04_23_2153-24303f1669a7_add_domain_to_organizations_table.py new file mode 100644 index 00000000..55497827 --- /dev/null +++ b/alembic/versions/2024_04_23_2153-24303f1669a7_add_domain_to_organizations_table.py @@ -0,0 +1,32 @@ +"""add domain to organizations table + +Revision ID: 24303f1669a7 +Revises: 8335d7fecef9 +Create Date: 2024-04-23 21:53:45.475199+00:00 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "24303f1669a7" +down_revision: Union[str, None] = "8335d7fecef9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("organizations", sa.Column("domain", sa.String(), nullable=True)) + op.create_index(op.f("ix_organizations_domain"), "organizations", ["domain"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_organizations_domain"), table_name="organizations") + op.drop_column("organizations", "domain") + # ### end Alembic commands ### diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 15e05c64..0b0422be 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -436,12 +436,19 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + async def get_organization_by_domain(self, domain: str) -> Organization | None: + async with self.Session() as session: + if organization := (await session.scalars(select(OrganizationModel).filter_by(domain=domain))).first(): + return convert_to_organization(organization) + return None + async def create_organization( self, organization_name: str, webhook_callback_url: str | None = None, max_steps_per_run: int | None = None, max_retries_per_step: int | None = None, + domain: str | None = None, ) -> Organization: async with self.Session() as session: org = OrganizationModel( @@ -449,6 +456,7 @@ class AgentDB: webhook_callback_url=webhook_callback_url, max_steps_per_run=max_steps_per_run, max_retries_per_step=max_retries_per_step, + domain=domain, ) session.add(org) await session.commit() diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index d3bf48bf..aa681ff4 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -78,6 +78,7 @@ class OrganizationModel(Base): webhook_callback_url = Column(UnicodeText) max_steps_per_run = Column(Integer, nullable=True) max_retries_per_step = Column(Integer, nullable=True) + domain = Column(String, nullable=True, index=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 5f6bfe21..d4f78637 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -105,6 +105,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization: webhook_callback_url=org_model.webhook_callback_url, max_steps_per_run=org_model.max_steps_per_run, max_retries_per_step=org_model.max_retries_per_step, + domain=org_model.domain, created_at=org_model.created_at, modified_at=org_model.modified_at, ) diff --git a/skyvern/forge/sdk/models.py b/skyvern/forge/sdk/models.py index 2d6bff41..643bc999 100644 --- a/skyvern/forge/sdk/models.py +++ b/skyvern/forge/sdk/models.py @@ -118,6 +118,7 @@ class Organization(BaseModel): webhook_callback_url: str | None = None max_steps_per_run: int | None = None max_retries_per_step: int | None = None + domain: str | None = None created_at: datetime modified_at: datetime