From 72d25cd37d76b28fd121ff8d7d8d2899f289001d Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Thu, 16 May 2024 10:51:22 -0700 Subject: [PATCH] workflow apis (#326) Co-authored-by: Shuchang Zheng --- ...dd_proxy_location_and_webhook_callback_.py | 49 ++++++++ skyvern/exceptions.py | 18 ++- skyvern/forge/sdk/db/client.py | 116 +++++++++++++++++- skyvern/forge/sdk/db/models.py | 2 + skyvern/forge/sdk/db/utils.py | 4 + skyvern/forge/sdk/routes/agent_protocol.py | 85 +++++++++++++ skyvern/forge/sdk/workflow/models/workflow.py | 4 + skyvern/forge/sdk/workflow/models/yaml.py | 3 + skyvern/forge/sdk/workflow/service.py | 102 +++++++++++++-- 9 files changed, 364 insertions(+), 19 deletions(-) create mode 100644 alembic/versions/2024_05_16_1729-04bf06540db6_add_proxy_location_and_webhook_callback_.py diff --git a/alembic/versions/2024_05_16_1729-04bf06540db6_add_proxy_location_and_webhook_callback_.py b/alembic/versions/2024_05_16_1729-04bf06540db6_add_proxy_location_and_webhook_callback_.py new file mode 100644 index 00000000..ea15052a --- /dev/null +++ b/alembic/versions/2024_05_16_1729-04bf06540db6_add_proxy_location_and_webhook_callback_.py @@ -0,0 +1,49 @@ +"""add proxy_location and webhook_callback_url to workflows table + +Revision ID: 04bf06540db6 +Revises: baec12642d77 +Create Date: 2024-05-16 17:29:55.083124+00:00 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "04bf06540db6" +down_revision: Union[str, None] = "baec12642d77" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "workflows", + sa.Column( + "proxy_location", + sa.Enum( + "US_CA", + "US_NY", + "US_TX", + "US_FL", + "US_WA", + "RESIDENTIAL", + "RESIDENTIAL_ES", + "NONE", + name="proxylocation", + ), + nullable=True, + ), + ) + op.add_column("workflows", sa.Column("webhook_callback_url", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("workflows", "webhook_callback_url") + op.drop_column("workflows", "proxy_location") + # ### end Alembic commands ### diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 431951eb..0d08d6c4 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -107,8 +107,22 @@ class UnknownBlockType(SkyvernException): class WorkflowNotFound(SkyvernHTTPException): - def __init__(self, workflow_id: str) -> None: - super().__init__(f"Workflow {workflow_id} not found", status_code=status.HTTP_404_NOT_FOUND) + def __init__( + self, + workflow_id: str | None = None, + workflow_permanent_id: str | None = None, + version: int | None = None, + ) -> None: + workflow_repr = "" + if workflow_id: + workflow_repr = f"workflow_id={workflow_id}" + if workflow_permanent_id: + if version: + workflow_repr = f"workflow_permanent_id={workflow_permanent_id}, version={version}" + else: + workflow_repr = f"workflow_permanent_id={workflow_permanent_id}" + + super().__init__(f"Workflow not found. {workflow_repr}", status_code=status.HTTP_404_NOT_FOUND) class WorkflowRunNotFound(SkyvernException): diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 50ed1083..a16e0611 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Any, Sequence import structlog -from sqlalchemy import and_, delete, select +from sqlalchemy import and_, delete, func, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -723,10 +723,14 @@ class AgentDB: async def create_workflow( self, - organization_id: str, title: str, workflow_definition: dict[str, Any], + organization_id: str | None = None, description: str | None = None, + proxy_location: ProxyLocation | None = None, + webhook_callback_url: str | None = None, + workflow_permanent_id: str | None = None, + version: int | None = None, ) -> Workflow: async with self.Session() as session: workflow = WorkflowModel( @@ -734,7 +738,13 @@ class AgentDB: title=title, description=description, workflow_definition=workflow_definition, + proxy_location=proxy_location, + webhook_callback_url=webhook_callback_url, ) + if workflow_permanent_id: + workflow.workflow_permanent_id = workflow_permanent_id + if version: + workflow.version = version session.add(workflow) await session.commit() await session.refresh(workflow) @@ -743,7 +753,9 @@ class AgentDB: async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None: try: async with self.Session() as session: - get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id) + get_workflow_query = ( + select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None)) + ) if organization_id: get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) if workflow := (await session.scalars(get_workflow_query)).first(): @@ -753,6 +765,74 @@ class AgentDB: LOG.error("SQLAlchemyError", exc_info=True) raise + async def get_workflow_by_permanent_id( + self, + workflow_permanent_id: str, + organization_id: str | None = None, + version: int | None = None, + ) -> Workflow | None: + try: + get_workflow_query = ( + select(WorkflowModel) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter(WorkflowModel.deleted_at.is_(None)) + ) + if organization_id: + get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) + if version: + get_workflow_query = get_workflow_query.filter_by(version=version) + get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc()) + async with self.Session() as session: + if workflow := (await session.scalars(get_workflow_query)).first(): + return convert_to_workflow(workflow, self.debug_enabled) + return None + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + + async def get_workflows_by_organization_id( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + ) -> list[Workflow]: + """ + Get all workflows with the latest version for the organization. + """ + if page < 1: + raise ValueError(f"Page must be greater than 0, got {page}") + db_page = page - 1 + try: + async with self.Session() as session: + subquery = ( + select( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by(WorkflowModel.organization_id, WorkflowModel.workflow_permanent_id) + .subquery() + ) + main_query = ( + select(WorkflowModel) + .join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + .order_by(WorkflowModel.created_at.desc()) # Example ordering by creation date + .limit(page_size) + .offset(db_page * page_size) + ) + workflows = (await session.scalars(main_query)).all() + return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def update_workflow( self, workflow_id: str, @@ -760,10 +840,13 @@ class AgentDB: title: str | None = None, description: str | None = None, workflow_definition: dict[str, Any] | None = None, + version: int | None = None, ) -> Workflow: try: async with self.Session() as session: - get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id) + get_workflow_query = ( + select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None)) + ) if organization_id: get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) if workflow := (await session.scalars(get_workflow_query)).first(): @@ -773,6 +856,8 @@ class AgentDB: workflow.description = description if workflow_definition: workflow.workflow_definition = workflow_definition + if version: + workflow.version = version await session.commit() await session.refresh(workflow) return convert_to_workflow(workflow, self.debug_enabled) @@ -789,8 +874,29 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + async def soft_delete_workflow_by_permanent_id( + self, + workflow_permanent_id: str, + organization_id: str | None = None, + ) -> None: + async with self.Session() as session: + # soft delete the workflow by setting the deleted_at field + update_deleted_at_query = ( + update(WorkflowModel) + .where(WorkflowModel.workflow_permanent_id == workflow_permanent_id) + .where(WorkflowModel.deleted_at.is_(None)) + ) + if organization_id: + update_deleted_at_query = update_deleted_at_query.filter_by(organization_id=organization_id) + update_deleted_at_query = update_deleted_at_query.values(deleted_at=datetime.utcnow()) + await session.execute(update_deleted_at_query) + await session.commit() + async def create_workflow_run( - self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None + self, + workflow_id: str, + proxy_location: ProxyLocation | None = None, + webhook_callback_url: str | None = None, ) -> WorkflowRun: try: async with self.Session() as session: diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 655c9491..5d80c6c6 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -145,6 +145,8 @@ class WorkflowModel(Base): title = Column(String, nullable=False) description = Column(String, nullable=True) workflow_definition = Column(JSON, nullable=False) + proxy_location = Column(Enum(ProxyLocation)) + webhook_callback_url = Column(String) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 23925fd3..85546c28 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -148,6 +148,10 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal workflow_id=workflow_model.workflow_id, organization_id=workflow_model.organization_id, title=workflow_model.title, + workflow_permanent_id=workflow_model.workflow_permanent_id, + webhook_callback_url=workflow_model.webhook_callback_url, + proxy_location=ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None, + version=workflow_model.version, description=workflow_model.description, workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition), created_at=workflow_model.created_at, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 26d91b35..3cf43505 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -532,3 +532,88 @@ async def create_workflow( return await app.WORKFLOW_SERVICE.create_workflow_from_request( organization_id=current_org.organization_id, request=workflow_create_request ) + + +@base_router.put( + "/workflows/{workflow_permanent_id}", + openapi_extra={ + "requestBody": { + "content": {"application/x-yaml": {"schema": WorkflowCreateYAMLRequest.model_json_schema()}}, + "required": True, + }, + }, + response_model=Workflow, +) +@base_router.put( + "/workflows/{workflow_permanent_id}/", + openapi_extra={ + "requestBody": { + "content": {"application/x-yaml": {"schema": WorkflowCreateYAMLRequest.model_json_schema()}}, + "required": True, + }, + }, + response_model=Workflow, + include_in_schema=False, +) +async def update_workflow( + workflow_permanent_id: str, + request: Request, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> Workflow: + analytics.capture("skyvern-oss-agent-workflow-update") + # validate the workflow + raw_yaml = await request.body() + try: + workflow_yaml = yaml.safe_load(raw_yaml) + except yaml.YAMLError: + raise HTTPException(status_code=422, detail="Invalid YAML") + + workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml) + return await app.WORKFLOW_SERVICE.create_workflow_from_request( + organization_id=current_org.organization_id, + request=workflow_create_request, + workflow_permanent_id=workflow_permanent_id, + ) + + +@base_router.delete("/workflows/{workflow_permanent_id}") +@base_router.delete("/workflows/{workflow_permanent_id}/", include_in_schema=False) +async def delete_workflow( + workflow_permanent_id: str, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> None: + analytics.capture("skyvern-oss-agent-workflow-delete") + await app.WORKFLOW_SERVICE.delete_workflow_by_permanent_id(workflow_permanent_id, current_org.organization_id) + + +@base_router.get("/workflows", response_model=list[Workflow]) +@base_router.get("/workflows/", response_model=list[Workflow]) +async def get_workflows( + page: int = Query(1, ge=1), + page_size: int = Query(10, ge=1), + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> list[Workflow]: + """ + Get all workflows with the latest version for the organization. + """ + analytics.capture("skyvern-oss-agent-workflows-get") + return await app.WORKFLOW_SERVICE.get_workflows_by_organization_id( + organization_id=current_org.organization_id, + page=page, + page_size=page_size, + ) + + +@base_router.get("/workflows/{workflow_permanent_id}", response_model=Workflow) +@base_router.get("/workflows/{workflow_permanent_id}/", response_model=Workflow) +async def get_workflow( + workflow_permanent_id: str, + version: int | None = None, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> Workflow: + analytics.capture("skyvern-oss-agent-workflows-get") + return await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( + workflow_permanent_id=workflow_permanent_id, + organization_id=current_org.organization_id, + version=version, + ) diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index bba4c09f..3aa9dd77 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -42,8 +42,12 @@ class Workflow(BaseModel): workflow_id: str organization_id: str title: str + workflow_permanent_id: str + version: int description: str | None = None workflow_definition: WorkflowDefinition + proxy_location: ProxyLocation | None = None + webhook_callback_url: str | None = None created_at: datetime modified_at: datetime diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index 1bec459b..fedfbc0a 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -3,6 +3,7 @@ from typing import Annotated, Any, Literal from pydantic import BaseModel, Field +from skyvern.forge.sdk.schemas.tasks import ProxyLocation from skyvern.forge.sdk.workflow.models.block import BlockType from skyvern.forge.sdk.workflow.models.parameter import ParameterType, WorkflowParameterType @@ -187,4 +188,6 @@ class WorkflowDefinitionYAML(BaseModel): class WorkflowCreateYAMLRequest(BaseModel): title: str description: str | None = None + proxy_location: ProxyLocation | None = None + webhook_callback_url: str | None = None workflow_definition: WorkflowDefinitionYAML diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index df16dd98..ad5d0b52 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -19,7 +19,7 @@ from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.models import Step -from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus +from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus from skyvern.forge.sdk.workflow.exceptions import ( ContextParameterSourceNotDefined, WorkflowDefinitionHasDuplicateParameterKeys, @@ -89,6 +89,10 @@ class WorkflowService: if workflow.organization_id != organization_id: LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}") raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id) + if workflow_request.proxy_location is None and workflow.proxy_location is not None: + workflow_request.proxy_location = workflow.proxy_location + if workflow_request.webhook_callback_url is None and workflow.webhook_callback_url is not None: + workflow_request.webhook_callback_url = workflow.webhook_callback_url # Create the workflow run and set skyvern context workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id) LOG.info( @@ -97,6 +101,7 @@ class WorkflowService: workflow_run_id=workflow_run.workflow_run_id, workflow_id=workflow.workflow_id, proxy_location=workflow_request.proxy_location, + webhook_callback_url=workflow_request.webhook_callback_url, ) skyvern_context.set( SkyvernContext( @@ -266,20 +271,58 @@ class WorkflowService: title: str, workflow_definition: WorkflowDefinition, description: str | None = None, + proxy_location: ProxyLocation | None = None, + webhook_callback_url: str | None = None, + workflow_permanent_id: str | None = None, + version: int | None = None, ) -> Workflow: return await app.DATABASE.create_workflow( - organization_id=organization_id, title=title, - description=description, workflow_definition=workflow_definition.model_dump(), + organization_id=organization_id, + description=description, + proxy_location=proxy_location, + webhook_callback_url=webhook_callback_url, + workflow_permanent_id=workflow_permanent_id, + version=version, ) async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow: workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id) if not workflow: - raise WorkflowNotFound(workflow_id) + raise WorkflowNotFound(workflow_id=workflow_id) return workflow + async def get_workflow_by_permanent_id( + self, + workflow_permanent_id: str, + organization_id: str | None = None, + version: int | None = None, + ) -> Workflow: + workflow = await app.DATABASE.get_workflow_by_permanent_id( + workflow_permanent_id, + organization_id=organization_id, + version=version, + ) + if not workflow: + raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version) + return workflow + + async def get_workflows_by_organization_id( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + ) -> list[Workflow]: + """ + Get all workflows with the latest version for the organization. + """ + return await app.DATABASE.get_workflows_by_organization_id( + organization_id=organization_id, + page=page, + page_size=page_size, + ) + async def update_workflow( self, workflow_id: str, @@ -290,14 +333,25 @@ class WorkflowService: ) -> Workflow: if workflow_definition: workflow_definition.validate() + return await app.DATABASE.update_workflow( workflow_id=workflow_id, - organization_id=organization_id, title=title, + organization_id=organization_id, description=description, workflow_definition=workflow_definition.model_dump() if workflow_definition else None, ) + async def delete_workflow_by_permanent_id( + self, + workflow_permanent_id: str, + organization_id: str | None = None, + ) -> None: + await app.DATABASE.soft_delete_workflow_by_permanent_id( + workflow_permanent_id=workflow_permanent_id, + organization_id=organization_id, + ) + async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun: return await app.DATABASE.create_workflow_run( workflow_id=workflow_id, @@ -669,15 +723,39 @@ class WorkflowService: await self.persist_har_data(browser_state, last_step, workflow, workflow_run) await self.persist_tracing_data(browser_state, last_step, workflow_run) - async def create_workflow_from_request(self, organization_id: str, request: WorkflowCreateYAMLRequest) -> Workflow: + async def create_workflow_from_request( + self, + organization_id: str, + request: WorkflowCreateYAMLRequest, + workflow_permanent_id: str | None = None, + ) -> Workflow: LOG.info("Creating workflow from request", organization_id=organization_id, title=request.title) try: - workflow = await self.create_workflow( - organization_id=organization_id, - title=request.title, - description=request.description, - workflow_definition=WorkflowDefinition(parameters=[], blocks=[]), - ) + if workflow_permanent_id: + existing_latest_workflow = await self.get_workflow_by_permanent_id( + workflow_permanent_id=workflow_permanent_id, + organization_id=organization_id, + ) + existing_version = existing_latest_workflow.version + workflow = await self.create_workflow( + title=request.title, + workflow_definition=WorkflowDefinition(parameters=[], blocks=[]), + description=request.description, + organization_id=organization_id, + proxy_location=request.proxy_location, + webhook_callback_url=request.webhook_callback_url, + workflow_permanent_id=workflow_permanent_id, + version=existing_version + 1, + ) + else: + workflow = await self.create_workflow( + title=request.title, + workflow_definition=WorkflowDefinition(parameters=[], blocks=[]), + description=request.description, + organization_id=organization_id, + proxy_location=request.proxy_location, + webhook_callback_url=request.webhook_callback_url, + ) # Create parameters from the request parameters: dict[str, PARAMETER_TYPE] = {} duplicate_parameter_keys = set()