From 2e37542218d4ec7922b6b464972a271ebe0f7c49 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 22 Dec 2024 17:49:33 -0800 Subject: [PATCH] add organization_id filter for get_workflow and get_workflow_run (#1422) --- skyvern/forge/agent.py | 5 ++++- skyvern/forge/sdk/db/client.py | 9 ++++---- skyvern/forge/sdk/routes/agent_protocol.py | 5 ++++- skyvern/forge/sdk/routes/streaming.py | 5 ++++- .../forge/sdk/services/observer_service.py | 7 +++++-- skyvern/forge/sdk/workflow/models/block.py | 10 +++++++-- skyvern/forge/sdk/workflow/service.py | 21 ++++++++++++------- 7 files changed, 44 insertions(+), 18 deletions(-) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index bcd4fcef..80411177 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -248,7 +248,10 @@ class ForgeAgent: ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: workflow_run: WorkflowRun | None = None if task.workflow_run_id: - workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=task.workflow_run_id) + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=task.workflow_run_id, + organization_id=organization.organization_id, + ) if workflow_run and workflow_run.status == WorkflowRunStatus.canceled: LOG.info( "Workflow run is canceled, stopping execution inside task", diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index c8bb5c1f..8176a442 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1283,12 +1283,13 @@ class AgentDB: ) return None - async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None: + async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun | None: try: async with self.Session() as session: - if workflow_run := ( - await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)) - ).first(): + get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id) + if organization_id: + get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id) + if workflow_run := (await session.scalars(get_workflow_run_query)).first(): return convert_to_workflow_run(workflow_run) return None except SQLAlchemyError: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 0d42b16d..be3eac35 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -311,7 +311,10 @@ async def cancel_workflow_run( current_org: Organization = Depends(org_auth_service.get_current_org), x_api_key: Annotated[str | None, Header()] = None, ) -> None: - workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id) + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=current_org.organization_id, + ) if not workflow_run: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/skyvern/forge/sdk/routes/streaming.py b/skyvern/forge/sdk/routes/streaming.py index e11a5428..746e815a 100644 --- a/skyvern/forge/sdk/routes/streaming.py +++ b/skyvern/forge/sdk/routes/streaming.py @@ -171,7 +171,10 @@ async def workflow_run_streaming( ) return - workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id) + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) if not workflow_run or workflow_run.organization_id != organization_id: LOG.info( "WofklowRun Streaming: Workflow not found", diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index e6f5b515..ecd0f3bf 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -161,7 +161,7 @@ async def run_observer_cruise( workflow_run_id = observer_cruise.workflow_run_id - workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id) + workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id) if not workflow_run: LOG.error("Workflow run not found", workflow_run_id=workflow_run_id) return None @@ -483,7 +483,10 @@ async def handle_block_result( workflow_run=workflow_run, ) # refresh workflow run model - return await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id) + return await app.WORKFLOW_SERVICE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=workflow.organization_id, + ) async def _set_up_workflow_context(workflow_id: str, workflow_run_id: str) -> None: diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 72b8f86e..0dba39be 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -351,8 +351,14 @@ class BaseTaskBlock(Block): # initial value for will_retry is True, so that the loop runs at least once will_retry = True current_running_task: Task | None = None - workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id) - workflow = await app.WORKFLOW_SERVICE.get_workflow(workflow_id=workflow_run.workflow_id) + workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + workflow = await app.WORKFLOW_SERVICE.get_workflow( + workflow_id=workflow_run.workflow_id, + organization_id=organization_id, + ) # if the task url is parameterized, we need to get the value from the workflow run context if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url): task_url_parameter_value = workflow_run_context.get_value(self.url) diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index e84f4adc..a851caa9 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -191,7 +191,7 @@ class WorkflowService: ) -> WorkflowRun: """Execute a workflow.""" organization_id = organization.organization_id - workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) + workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id) # Set workflow run status to running, create workflow run parameters @@ -219,7 +219,8 @@ class WorkflowService: for block_idx, block in enumerate(blocks): try: refreshed_workflow_run = await app.DATABASE.get_workflow_run( - workflow_run_id=workflow_run.workflow_run_id + workflow_run_id=workflow_run.workflow_run_id, + organization_id=organization_id, ) if refreshed_workflow_run and refreshed_workflow_run.status == WorkflowRunStatus.canceled: LOG.info( @@ -358,7 +359,10 @@ class WorkflowService: await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key) return workflow_run - refreshed_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run.workflow_run_id) + refreshed_workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run.workflow_run_id, + organization_id=organization_id, + ) if refreshed_workflow_run and refreshed_workflow_run.status not in ( WorkflowRunStatus.canceled, WorkflowRunStatus.failed, @@ -570,8 +574,11 @@ class WorkflowService: status=WorkflowRunStatus.canceled, ) - async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id) + async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun: + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) if not workflow_run: raise WorkflowRunNotFound(workflow_run_id) return workflow_run @@ -734,7 +741,7 @@ class WorkflowService: workflow_run_id: str, organization_id: str, ) -> WorkflowRunStatusResponse: - workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) + workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) if workflow_run is None: LOG.error(f"Workflow run {workflow_run_id} not found") raise WorkflowRunNotFound(workflow_run_id=workflow_run_id) @@ -756,7 +763,7 @@ class WorkflowService: LOG.error(f"Workflow {workflow_permanent_id} not found") raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id) - workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) + workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id) screenshot_artifacts = [] screenshot_urls: list[str] | None = None