add organization_id filter for get_workflow and get_workflow_run (#1422)
This commit is contained in:
@@ -248,7 +248,10 @@ class ForgeAgent:
|
|||||||
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
||||||
workflow_run: WorkflowRun | None = None
|
workflow_run: WorkflowRun | None = None
|
||||||
if task.workflow_run_id:
|
if task.workflow_run_id:
|
||||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=task.workflow_run_id)
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=task.workflow_run_id,
|
||||||
|
organization_id=organization.organization_id,
|
||||||
|
)
|
||||||
if workflow_run and workflow_run.status == WorkflowRunStatus.canceled:
|
if workflow_run and workflow_run.status == WorkflowRunStatus.canceled:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Workflow run is canceled, stopping execution inside task",
|
"Workflow run is canceled, stopping execution inside task",
|
||||||
|
|||||||
@@ -1283,12 +1283,13 @@ class AgentDB:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
|
async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun | None:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
if workflow_run := (
|
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
|
||||||
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
|
if organization_id:
|
||||||
).first():
|
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
|
||||||
|
if workflow_run := (await session.scalars(get_workflow_run_query)).first():
|
||||||
return convert_to_workflow_run(workflow_run)
|
return convert_to_workflow_run(workflow_run)
|
||||||
return None
|
return None
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
|
|||||||
@@ -311,7 +311,10 @@ async def cancel_workflow_run(
|
|||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
x_api_key: Annotated[str | None, Header()] = None,
|
x_api_key: Annotated[str | None, Header()] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
|||||||
@@ -171,7 +171,10 @@ async def workflow_run_streaming(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
if not workflow_run or workflow_run.organization_id != organization_id:
|
if not workflow_run or workflow_run.organization_id != organization_id:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"WofklowRun Streaming: Workflow not found",
|
"WofklowRun Streaming: Workflow not found",
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ async def run_observer_cruise(
|
|||||||
|
|
||||||
workflow_run_id = observer_cruise.workflow_run_id
|
workflow_run_id = observer_cruise.workflow_run_id
|
||||||
|
|
||||||
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id)
|
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
LOG.error("Workflow run not found", workflow_run_id=workflow_run_id)
|
LOG.error("Workflow run not found", workflow_run_id=workflow_run_id)
|
||||||
return None
|
return None
|
||||||
@@ -483,7 +483,10 @@ async def handle_block_result(
|
|||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
)
|
)
|
||||||
# refresh workflow run model
|
# refresh workflow run model
|
||||||
return await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id)
|
return await app.WORKFLOW_SERVICE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=workflow.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _set_up_workflow_context(workflow_id: str, workflow_run_id: str) -> None:
|
async def _set_up_workflow_context(workflow_id: str, workflow_run_id: str) -> None:
|
||||||
|
|||||||
@@ -351,8 +351,14 @@ class BaseTaskBlock(Block):
|
|||||||
# initial value for will_retry is True, so that the loop runs at least once
|
# initial value for will_retry is True, so that the loop runs at least once
|
||||||
will_retry = True
|
will_retry = True
|
||||||
current_running_task: Task | None = None
|
current_running_task: Task | None = None
|
||||||
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(
|
||||||
workflow = await app.WORKFLOW_SERVICE.get_workflow(workflow_id=workflow_run.workflow_id)
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
workflow = await app.WORKFLOW_SERVICE.get_workflow(
|
||||||
|
workflow_id=workflow_run.workflow_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
# if the task url is parameterized, we need to get the value from the workflow run context
|
# if the task url is parameterized, we need to get the value from the workflow run context
|
||||||
if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
|
if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
|
||||||
task_url_parameter_value = workflow_run_context.get_value(self.url)
|
task_url_parameter_value = workflow_run_context.get_value(self.url)
|
||||||
|
|||||||
@@ -191,7 +191,7 @@ class WorkflowService:
|
|||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
"""Execute a workflow."""
|
"""Execute a workflow."""
|
||||||
organization_id = organization.organization_id
|
organization_id = organization.organization_id
|
||||||
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
|
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
|
||||||
|
|
||||||
# Set workflow run status to running, create workflow run parameters
|
# Set workflow run status to running, create workflow run parameters
|
||||||
@@ -219,7 +219,8 @@ class WorkflowService:
|
|||||||
for block_idx, block in enumerate(blocks):
|
for block_idx, block in enumerate(blocks):
|
||||||
try:
|
try:
|
||||||
refreshed_workflow_run = await app.DATABASE.get_workflow_run(
|
refreshed_workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
workflow_run_id=workflow_run.workflow_run_id
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
if refreshed_workflow_run and refreshed_workflow_run.status == WorkflowRunStatus.canceled:
|
if refreshed_workflow_run and refreshed_workflow_run.status == WorkflowRunStatus.canceled:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
@@ -358,7 +359,10 @@ class WorkflowService:
|
|||||||
await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
|
await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
refreshed_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
|
refreshed_workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
if refreshed_workflow_run and refreshed_workflow_run.status not in (
|
if refreshed_workflow_run and refreshed_workflow_run.status not in (
|
||||||
WorkflowRunStatus.canceled,
|
WorkflowRunStatus.canceled,
|
||||||
WorkflowRunStatus.failed,
|
WorkflowRunStatus.failed,
|
||||||
@@ -570,8 +574,11 @@ class WorkflowService:
|
|||||||
status=WorkflowRunStatus.canceled,
|
status=WorkflowRunStatus.canceled,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun:
|
||||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise WorkflowRunNotFound(workflow_run_id)
|
raise WorkflowRunNotFound(workflow_run_id)
|
||||||
return workflow_run
|
return workflow_run
|
||||||
@@ -734,7 +741,7 @@ class WorkflowService:
|
|||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
) -> WorkflowRunStatusResponse:
|
) -> WorkflowRunStatusResponse:
|
||||||
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
if workflow_run is None:
|
if workflow_run is None:
|
||||||
LOG.error(f"Workflow run {workflow_run_id} not found")
|
LOG.error(f"Workflow run {workflow_run_id} not found")
|
||||||
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
|
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
|
||||||
@@ -756,7 +763,7 @@ class WorkflowService:
|
|||||||
LOG.error(f"Workflow {workflow_permanent_id} not found")
|
LOG.error(f"Workflow {workflow_permanent_id} not found")
|
||||||
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id)
|
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id)
|
||||||
|
|
||||||
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
|
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
||||||
screenshot_artifacts = []
|
screenshot_artifacts = []
|
||||||
screenshot_urls: list[str] | None = None
|
screenshot_urls: list[str] | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user