From 63adc860efdf3635a58b88b4e1c4462c81069cce Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Thu, 27 Jun 2024 12:53:08 -0700 Subject: [PATCH] is_saved_task parameter for workflows (#526) --- ...67adef01_add_is_saved_task_to_workflows.py | 31 +++++++++++++++++++ skyvern/forge/sdk/db/client.py | 25 +++++++++------ skyvern/forge/sdk/db/models.py | 1 + skyvern/forge/sdk/db/utils.py | 1 + skyvern/forge/sdk/routes/agent_protocol.py | 11 +++++++ skyvern/forge/sdk/workflow/models/workflow.py | 1 + skyvern/forge/sdk/workflow/models/yaml.py | 1 + skyvern/forge/sdk/workflow/service.py | 8 +++++ 8 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 alembic/versions/2024_06_27_1949-485667adef01_add_is_saved_task_to_workflows.py diff --git a/alembic/versions/2024_06_27_1949-485667adef01_add_is_saved_task_to_workflows.py b/alembic/versions/2024_06_27_1949-485667adef01_add_is_saved_task_to_workflows.py new file mode 100644 index 00000000..e6de82c2 --- /dev/null +++ b/alembic/versions/2024_06_27_1949-485667adef01_add_is_saved_task_to_workflows.py @@ -0,0 +1,31 @@ +"""Add is_saved_task to workflows + +Revision ID: 485667adef01 +Revises: 2c163e606a3d +Create Date: 2024-06-27 19:49:41.506447+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "485667adef01" +down_revision: Union[str, None] = "2c163e606a3d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("workflows", sa.Column("is_saved_task", sa.Boolean(), server_default=sa.false(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("workflows", "is_saved_task") + # ### end Alembic commands ### diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 0fc00c43..70140a79 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -774,6 +774,7 @@ class AgentDB: webhook_callback_url: str | None = None, workflow_permanent_id: str | None = None, version: int | None = None, + is_saved_task: bool = False, ) -> Workflow: async with self.Session() as session: workflow = WorkflowModel( @@ -783,6 +784,7 @@ class AgentDB: workflow_definition=workflow_definition, proxy_location=proxy_location, webhook_callback_url=webhook_callback_url, + is_saved_task=is_saved_task, ) if workflow_permanent_id: workflow.workflow_permanent_id = workflow_permanent_id @@ -838,6 +840,8 @@ class AgentDB: organization_id: str, page: int = 1, page_size: int = 10, + only_saved_tasks: bool = False, + only_workflows: bool = False, ) -> list[Workflow]: """ Get all workflows with the latest version for the organization. @@ -861,17 +865,18 @@ class AgentDB: ) .subquery() ) + main_query = select(WorkflowModel).join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + if only_saved_tasks: + main_query = main_query.where(WorkflowModel.is_saved_task.is_(True)) + elif only_workflows: + main_query = main_query.where(WorkflowModel.is_saved_task.is_(False)) main_query = ( - select(WorkflowModel) - .join( - subquery, - (WorkflowModel.organization_id == subquery.c.organization_id) - & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) - & (WorkflowModel.version == subquery.c.max_version), - ) - .order_by(WorkflowModel.created_at.desc()) # Example ordering by creation date - .limit(page_size) - .offset(db_page * page_size) + main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size) ) workflows = (await session.scalars(main_query)).all() return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows] diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 47e5d97d..218af1a3 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -189,6 +189,7 @@ class WorkflowModel(Base): workflow_permanent_id = Column(String, nullable=False, default=generate_workflow_permanent_id, index=True) version = Column(Integer, default=1, nullable=False) + is_saved_task = Column(Boolean, default=False, nullable=False) class WorkflowRunModel(Base): diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 582eb8a6..6bd00b74 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -160,6 +160,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal webhook_callback_url=workflow_model.webhook_callback_url, proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None), version=workflow_model.version, + is_saved_task=workflow_model.is_saved_task, description=workflow_model.description, workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition), created_at=workflow_model.created_at, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index cab8126f..de560261 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -654,16 +654,27 @@ async def delete_workflow( async def get_workflows( page: int = Query(1, ge=1), page_size: int = Query(10, ge=1), + only_saved_tasks: bool = Query(False), + only_workflows: bool = Query(False), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> list[Workflow]: """ Get all workflows with the latest version for the organization. """ analytics.capture("skyvern-oss-agent-workflows-get") + + if only_saved_tasks and only_workflows: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="only_saved_tasks and only_workflows cannot be used together", + ) + return await app.WORKFLOW_SERVICE.get_workflows_by_organization_id( organization_id=current_org.organization_id, page=page, page_size=page_size, + only_saved_tasks=only_saved_tasks, + only_workflows=only_workflows, ) diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 25c0b8ec..c0371383 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -44,6 +44,7 @@ class Workflow(BaseModel): title: str workflow_permanent_id: str version: int + is_saved_task: bool description: str | None = None workflow_definition: WorkflowDefinition proxy_location: ProxyLocation | None = None diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index 9bd0a113..b1769485 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -194,3 +194,4 @@ class WorkflowCreateYAMLRequest(BaseModel): proxy_location: ProxyLocation | None = None webhook_callback_url: str | None = None workflow_definition: WorkflowDefinitionYAML + is_saved_task: bool = False diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index aaebe99b..87d145c2 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -281,6 +281,7 @@ class WorkflowService: webhook_callback_url: str | None = None, workflow_permanent_id: str | None = None, version: int | None = None, + is_saved_task: bool = False, ) -> Workflow: return await app.DATABASE.create_workflow( title=title, @@ -291,6 +292,7 @@ class WorkflowService: webhook_callback_url=webhook_callback_url, workflow_permanent_id=workflow_permanent_id, version=version, + is_saved_task=is_saved_task, ) async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow: @@ -319,6 +321,8 @@ class WorkflowService: organization_id: str, page: int = 1, page_size: int = 10, + only_saved_tasks: bool = False, + only_workflows: bool = False, ) -> list[Workflow]: """ Get all workflows with the latest version for the organization. @@ -327,6 +331,8 @@ class WorkflowService: organization_id=organization_id, page=page, page_size=page_size, + only_saved_tasks=only_saved_tasks, + only_workflows=only_workflows, ) async def update_workflow( @@ -773,6 +779,7 @@ class WorkflowService: webhook_callback_url=request.webhook_callback_url, workflow_permanent_id=workflow_permanent_id, version=existing_version + 1, + is_saved_task=request.is_saved_task, ) else: workflow = await self.create_workflow( @@ -782,6 +789,7 @@ class WorkflowService: organization_id=organization_id, proxy_location=request.proxy_location, webhook_callback_url=request.webhook_callback_url, + is_saved_task=request.is_saved_task, ) # Create parameters from the request parameters: dict[str, PARAMETER_TYPE] = {}