From d0c87e1caf6cdb02c3d41b1037eec835f03d7d48 Mon Sep 17 00:00:00 2001 From: Prakash Maheshwaran <73785492+Prakashmaheshwaran@users.noreply.github.com> Date: Thu, 19 Jun 2025 00:24:11 -0400 Subject: [PATCH] handle null organization_id in artifact queries for backward compatibility (#2748) --- skyvern/forge/sdk/db/client.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 6e32c262..ccf17839 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import Any, List, Sequence import structlog -from sqlalchemy import and_, delete, distinct, func, pool, select, tuple_, update +from sqlalchemy import and_, delete, distinct, func, or_, pool, select, tuple_, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine @@ -1089,7 +1089,7 @@ class AgentDB: async def get_artifacts_by_entity_id( self, *, - organization_id: str, + organization_id: str | None, artifact_type: ArtifactType | None = None, task_id: str | None = None, step_id: str | None = None, @@ -1100,6 +1100,7 @@ class AgentDB: ) -> list[Artifact]: try: async with self.Session() as session: + # Build base query query = select(ArtifactModel) if artifact_type is not None: @@ -1116,14 +1117,17 @@ class AgentDB: query = query.filter_by(observer_thought_id=thought_id) if task_v2_id is not None: query = query.filter_by(observer_cruise_id=task_v2_id) + # Handle backward compatibility where old artifact rows were stored with organization_id NULL if organization_id is not None: - query = query.filter_by(organization_id=organization_id) + query = query.filter( + or_(ArtifactModel.organization_id == organization_id, ArtifactModel.organization_id.is_(None)) + ) query = query.order_by(ArtifactModel.created_at.desc()) - if artifacts := (await session.scalars(query)).all(): - return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts] - else: - return [] + + artifacts = (await session.scalars(query)).all() + LOG.debug("Artifacts fetched", count=len(artifacts)) + return [convert_to_artifact(a, self.debug_enabled) for a in artifacts] except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) raise