workflow apis (#326)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-05-16 10:51:22 -07:00
committed by GitHub
parent 50026f33c2
commit 72d25cd37d
9 changed files with 364 additions and 19 deletions

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Any, Sequence
import structlog
from sqlalchemy import and_, delete, select
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
@@ -723,10 +723,14 @@ class AgentDB:
async def create_workflow(
self,
organization_id: str,
title: str,
workflow_definition: dict[str, Any],
organization_id: str | None = None,
description: str | None = None,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> Workflow:
async with self.Session() as session:
workflow = WorkflowModel(
@@ -734,7 +738,13 @@ class AgentDB:
title=title,
description=description,
workflow_definition=workflow_definition,
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
)
if workflow_permanent_id:
workflow.workflow_permanent_id = workflow_permanent_id
if version:
workflow.version = version
session.add(workflow)
await session.commit()
await session.refresh(workflow)
@@ -743,7 +753,9 @@ class AgentDB:
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
try:
async with self.Session() as session:
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
get_workflow_query = (
select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
@@ -753,6 +765,74 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
) -> Workflow | None:
try:
get_workflow_query = (
select(WorkflowModel)
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if version:
get_workflow_query = get_workflow_query.filter_by(version=version)
get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
async with self.Session() as session:
if workflow := (await session.scalars(get_workflow_query)).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflows_by_organization_id(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
db_page = page - 1
try:
async with self.Session() as session:
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(WorkflowModel.organization_id, WorkflowModel.workflow_permanent_id)
.subquery()
)
main_query = (
select(WorkflowModel)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.order_by(WorkflowModel.created_at.desc()) # Example ordering by creation date
.limit(page_size)
.offset(db_page * page_size)
)
workflows = (await session.scalars(main_query)).all()
return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow(
self,
workflow_id: str,
@@ -760,10 +840,13 @@ class AgentDB:
title: str | None = None,
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
version: int | None = None,
) -> Workflow:
try:
async with self.Session() as session:
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
get_workflow_query = (
select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
@@ -773,6 +856,8 @@ class AgentDB:
workflow.description = description
if workflow_definition:
workflow.workflow_definition = workflow_definition
if version:
workflow.version = version
await session.commit()
await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
@@ -789,8 +874,29 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def soft_delete_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> None:
async with self.Session() as session:
# soft delete the workflow by setting the deleted_at field
update_deleted_at_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
update_deleted_at_query = update_deleted_at_query.filter_by(organization_id=organization_id)
update_deleted_at_query = update_deleted_at_query.values(deleted_at=datetime.utcnow())
await session.execute(update_deleted_at_query)
await session.commit()
async def create_workflow_run(
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
self,
workflow_id: str,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
) -> WorkflowRun:
try:
async with self.Session() as session: