From d7efb6c33cfcc7f95f1a701e42ff0d54b743efbf Mon Sep 17 00:00:00 2001 From: Prakash Maheshwaran <73785492+Prakashmaheshwaran@users.noreply.github.com> Date: Mon, 16 Jun 2025 20:00:19 -0400 Subject: [PATCH] add API endpoint and database query for retrieving run artifacts (#2639) Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- skyvern/forge/sdk/db/client.py | 76 ++++++++++++++++++++++ skyvern/forge/sdk/routes/agent_protocol.py | 52 ++++++++++++++- 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index a3697c6a..c384bc51 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -986,6 +986,82 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + async def get_artifacts_for_run( + self, + run_id: str, + organization_id: str, + artifact_types: list[ArtifactType] | None = None, + group_by_type: bool = False, + sort_by: str = "created_at", + ) -> dict[ArtifactType, list[Artifact]] | list[Artifact]: + """Return artifacts associated with a run. + + Args: + run_id: The ID of the run to get artifacts for + organization_id: The ID of the organization that owns the run + artifact_types: Optional list of artifact types to filter by + group_by_type: If True, returns a dictionary mapping artifact types to lists of artifacts. + If False, returns a flat list of artifacts. Defaults to False. + sort_by: Field to sort artifacts by. Must be one of: 'created_at', 'step_id', 'task_id'. + Defaults to 'created_at'. + + Returns: + If group_by_type is True, returns a dictionary mapping artifact types to lists of artifacts. + If group_by_type is False, returns a list of artifacts sorted by the specified field. + + Raises: + ValueError: If sort_by is not one of the allowed values + """ + allowed_sort_fields = {"created_at", "step_id", "task_id"} + if sort_by not in allowed_sort_fields: + raise ValueError(f"sort_by must be one of {allowed_sort_fields}") + run = await self.get_run(run_id, organization_id=organization_id) + if not run: + return [] + + async with self.Session() as session: + query = select(ArtifactModel).filter_by(organization_id=organization_id) + + if run.task_run_type in [ + RunType.task_v1, + RunType.openai_cua, + RunType.anthropic_cua, + ]: + query = query.filter_by(task_id=run.run_id) + elif run.task_run_type == RunType.task_v2: + query = query.filter_by(observer_cruise_id=run.run_id) + elif run.task_run_type == RunType.workflow_run: + query = query.filter_by(workflow_run_id=run.run_id) + else: + return [] + + if artifact_types: + query = query.filter(ArtifactModel.artifact_type.in_(artifact_types)) + + # Apply sorting + if sort_by == "created_at": + query = query.order_by(ArtifactModel.created_at) + elif sort_by == "step_id": + query = query.order_by(ArtifactModel.step_id, ArtifactModel.created_at) + elif sort_by == "task_id": + query = query.order_by(ArtifactModel.task_id, ArtifactModel.created_at) + + # Execute query and convert to Artifact objects + artifacts = [ + convert_to_artifact(artifact, self.debug_enabled) for artifact in (await session.scalars(query)).all() + ] + + # Group artifacts by type if requested + if group_by_type: + result: dict[ArtifactType, list[Artifact]] = {} + for artifact in artifacts: + if artifact.artifact_type not in result: + result[artifact.artifact_type] = [] + result[artifact.artifact_type].append(artifact) + return result + + return artifacts + async def get_artifact_by_id( self, artifact_id: str, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 99a92282..50eacc1a 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -13,7 +13,7 @@ from skyvern.exceptions import MissingBrowserAddressError from skyvern.forge import app from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError -from skyvern.forge.sdk.artifact.models import Artifact +from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory from skyvern.forge.sdk.core.security import generate_skyvern_signature @@ -721,6 +721,56 @@ async def get_artifact( return artifact +@base_router.get( + "/runs/{run_id}/artifacts", + tags=["Artifacts"], + response_model=list[Artifact], + openapi_extra={ + "x-fern-sdk-group-name": "artifacts", + "x-fern-sdk-method-name": "get_run_artifacts", + }, + description="Get artifacts for a run", + summary="Get artifacts for a run", +) +@base_router.get("/runs/{run_id}/artifacts/", response_model=list[Artifact], include_in_schema=False) +async def get_run_artifacts( + run_id: str = Path(..., description="The id of the task run or the workflow run."), + artifact_type: Annotated[list[ArtifactType] | None, Query()] = None, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> Response: + analytics.capture("skyvern-oss-run-artifacts-get") + # Get artifacts as a list (not grouped by type) + artifacts = await app.DATABASE.get_artifacts_for_run( + run_id=run_id, + organization_id=current_org.organization_id, + artifact_types=artifact_type, + group_by_type=False, # This ensures we get a list, not a dict + ) + + if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS: + # Ensure we have a list of artifacts + artifacts_list = artifacts if isinstance(artifacts, list) else [] + + # Get signed URLs for all artifacts + signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts_list) + + if signed_urls and len(signed_urls) == len(artifacts_list): + for i, artifact in enumerate(artifacts_list): + if hasattr(artifact, "signed_url"): + artifact.signed_url = signed_urls[i] + elif signed_urls: + LOG.warning( + "Mismatch between artifacts and signed URLs count", + artifacts_count=len(artifacts_list), + urls_count=len(signed_urls), + run_id=run_id, + ) + else: + LOG.warning("Failed to get signed urls for artifacts", run_id=run_id) + + return artifacts + + @base_router.post( "/runs/{run_id}/retry_webhook", tags=["Agent"],