From 10612f02fd37b1e58fbd04323a896647e2679284 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Sun, 16 Jun 2024 19:42:20 -0700 Subject: [PATCH] update organization API (#480) --- skyvern/forge/api_app.py | 5 +++++ skyvern/forge/sdk/db/client.py | 26 ++++++++++++++++++++++ skyvern/forge/sdk/db/models.py | 4 ++-- skyvern/forge/sdk/models.py | 4 +++- skyvern/forge/sdk/routes/agent_protocol.py | 16 +++++++++++++ skyvern/forge/sdk/schemas/organizations.py | 8 +++++++ 6 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 skyvern/forge/sdk/schemas/organizations.py diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py index 062644b5..c51dbc21 100644 --- a/skyvern/forge/api_app.py +++ b/skyvern/forge/api_app.py @@ -15,6 +15,7 @@ from skyvern.exceptions import SkyvernHTTPException from skyvern.forge import app as forge_app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext +from skyvern.forge.sdk.db.exceptions import NotFoundError from skyvern.forge.sdk.routes.agent_protocol import base_router from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.scheduler import SCHEDULER @@ -75,6 +76,10 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI: LOG.info("Server startup complete. Skyvern is now online") + @app.exception_handler(NotFoundError) + async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response: + return Response(status_code=status.HTTP_404_NOT_FOUND) + @app.exception_handler(SkyvernHTTPException) async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse: return JSONResponse(status_code=exc.status_code, content={"detail": exc.message}) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index a2463788..0fc00c43 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -485,6 +485,32 @@ class AgentDB: return convert_to_organization(org) + async def update_organization( + self, + organization_id: str, + organization_name: str | None = None, + webhook_callback_url: str | None = None, + max_steps_per_run: int | None = None, + max_retries_per_step: int | None = None, + ) -> Organization: + async with self.Session() as session: + organization = ( + await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id)) + ).first() + if not organization: + raise NotFoundError + if organization_name: + organization.organization_name = organization_name + if webhook_callback_url: + organization.webhook_callback_url = webhook_callback_url + if max_steps_per_run: + organization.max_steps_per_run = max_steps_per_run + if max_retries_per_step: + organization.max_retries_per_step = max_retries_per_step + await session.commit() + await session.refresh(organization) + return Organization.model_validate(organization) + async def get_valid_org_auth_token( self, organization_id: str, diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index a3b02e4a..47e5d97d 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -109,7 +109,7 @@ class OrganizationModel(Base): modified_at = Column( DateTime, default=datetime.datetime.utcnow, - onupdate=datetime.datetime, + onupdate=datetime.datetime.utcnow, nullable=False, ) @@ -133,7 +133,7 @@ class OrganizationAuthTokenModel(Base): modified_at = Column( DateTime, default=datetime.datetime.utcnow, - onupdate=datetime.datetime, + onupdate=datetime.datetime.utcnow, nullable=False, ) deleted_at = Column(DateTime, nullable=True) diff --git a/skyvern/forge/sdk/models.py b/skyvern/forge/sdk/models.py index 75d307f4..a7feca1e 100644 --- a/skyvern/forge/sdk/models.py +++ b/skyvern/forge/sdk/models.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime from enum import StrEnum -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.webeye.actions.actions import ActionType @@ -118,6 +118,8 @@ class Step(BaseModel): class Organization(BaseModel): + model_config = ConfigDict(from_attributes=True) + organization_id: str organization_name: str webhook_callback_url: str | None = None diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 307036b7..be2e897e 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -17,6 +17,7 @@ from skyvern.forge.sdk.core.permissions.permission_checker_factory import Permis from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.models import Organization, Step +from skyvern.forge.sdk.schemas.organizations import OrganizationUpdate from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.tasks import ( CreateTaskResponse, @@ -693,3 +694,18 @@ async def generate_task( except LLMProviderError: LOG.error("Failed to generate task", exc_info=True) raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.") + + +@base_router.put("/organizations/", include_in_schema=False) +@base_router.put("/organizations") +async def update_organization( + org_update: OrganizationUpdate, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> Organization: + return await app.DATABASE.update_organization( + current_org.organization_id, + organization_name=org_update.organization_name, + webhook_callback_url=org_update.webhook_callback_url, + max_steps_per_run=org_update.max_steps_per_run, + max_retries_per_step=org_update.max_retries_per_step, + ) diff --git a/skyvern/forge/sdk/schemas/organizations.py b/skyvern/forge/sdk/schemas/organizations.py new file mode 100644 index 00000000..ae67a53f --- /dev/null +++ b/skyvern/forge/sdk/schemas/organizations.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class OrganizationUpdate(BaseModel): + organization_name: str | None = None + webhook_callback_url: str | None = None + max_steps_per_run: int | None = None + max_retries_per_step: int | None = None