diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 429b01ee..50ed1083 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -740,12 +740,13 @@ class AgentDB: await session.refresh(workflow) return convert_to_workflow(workflow, self.debug_enabled) - async def get_workflow(self, workflow_id: str) -> Workflow | None: + async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None: try: async with self.Session() as session: - if workflow := ( - await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id)) - ).first(): + get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id) + if organization_id: + get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) + if workflow := (await session.scalars(get_workflow_query)).first(): return convert_to_workflow(workflow, self.debug_enabled) return None except SQLAlchemyError: @@ -755,15 +756,17 @@ class AgentDB: async def update_workflow( self, workflow_id: str, + organization_id: str | None = None, title: str | None = None, description: str | None = None, workflow_definition: dict[str, Any] | None = None, ) -> Workflow: try: async with self.Session() as session: - if workflow := ( - await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id)) - ).first(): + get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id) + if organization_id: + get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) + if workflow := (await session.scalars(get_workflow_query)).first(): if title: workflow.title = title if description: diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 85f7c474..df16dd98 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -82,7 +82,7 @@ class WorkflowService: :return: The created workflow run. """ # Validate the workflow and the organization - workflow = await self.get_workflow(workflow_id=workflow_id) + workflow = await self.get_workflow(workflow_id=workflow_id, organization_id=organization_id) if workflow is None: LOG.error(f"Workflow {workflow_id} not found") raise WorkflowNotFound(workflow_id=workflow_id) @@ -141,10 +141,11 @@ class WorkflowService: self, workflow_run_id: str, api_key: str, + organization_id: str | None = None, ) -> WorkflowRun: """Execute a workflow.""" workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) - workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id) + workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id) # Set workflow run status to running, create workflow run parameters await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id) @@ -270,11 +271,11 @@ class WorkflowService: organization_id=organization_id, title=title, description=description, - workflow_definition=workflow_definition.model_dump() if workflow_definition else None, + workflow_definition=workflow_definition.model_dump(), ) - async def get_workflow(self, workflow_id: str) -> Workflow: - workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id) + async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow: + workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id) if not workflow: raise WorkflowNotFound(workflow_id) return workflow @@ -282,6 +283,7 @@ class WorkflowService: async def update_workflow( self, workflow_id: str, + organization_id: str | None = None, title: str | None = None, description: str | None = None, workflow_definition: WorkflowDefinition | None = None, @@ -290,6 +292,7 @@ class WorkflowService: workflow_definition.validate() return await app.DATABASE.update_workflow( workflow_id=workflow_id, + organization_id=organization_id, title=title, description=description, workflow_definition=workflow_definition.model_dump() if workflow_definition else None, @@ -449,9 +452,13 @@ class WorkflowService: return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id) async def build_workflow_run_status_response( - self, workflow_id: str, workflow_run_id: str, last_block_result: BlockResult | None, organization_id: str + self, + workflow_id: str, + workflow_run_id: str, + last_block_result: BlockResult | None, + organization_id: str, ) -> WorkflowRunStatusResponse: - workflow = await self.get_workflow(workflow_id=workflow_id) + workflow = await self.get_workflow(workflow_id=workflow_id, organization_id=organization_id) if workflow is None: LOG.error(f"Workflow {workflow_id} not found") raise WorkflowNotFound(workflow_id=workflow_id) @@ -756,6 +763,7 @@ class WorkflowService: workflow_definition = WorkflowDefinition(parameters=parameters.values(), blocks=blocks) workflow = await self.update_workflow( workflow_id=workflow.workflow_id, + organization_id=organization_id, workflow_definition=workflow_definition, ) LOG.info(