add organization_id to workflow service and db query (#320)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-05-15 08:43:36 -07:00
committed by GitHub
parent 6110fa4a44
commit 164a4da03a
2 changed files with 25 additions and 14 deletions

View File

@@ -740,12 +740,13 @@ class AgentDB:
await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
async def get_workflow(self, workflow_id: str) -> Workflow | None:
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
try:
async with self.Session() as session:
if workflow := (
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
).first():
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
@@ -755,15 +756,17 @@ class AgentDB:
async def update_workflow(
self,
workflow_id: str,
organization_id: str | None = None,
title: str | None = None,
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
) -> Workflow:
try:
async with self.Session() as session:
if workflow := (
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
).first():
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
if title:
workflow.title = title
if description: