add API endpoint and database query for retrieving run artifacts (#2639)
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c35b05925b
commit
d7efb6c33c
@@ -986,6 +986,82 @@ class AgentDB:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_artifacts_for_run(
|
||||
self,
|
||||
run_id: str,
|
||||
organization_id: str,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
group_by_type: bool = False,
|
||||
sort_by: str = "created_at",
|
||||
) -> dict[ArtifactType, list[Artifact]] | list[Artifact]:
|
||||
"""Return artifacts associated with a run.
|
||||
|
||||
Args:
|
||||
run_id: The ID of the run to get artifacts for
|
||||
organization_id: The ID of the organization that owns the run
|
||||
artifact_types: Optional list of artifact types to filter by
|
||||
group_by_type: If True, returns a dictionary mapping artifact types to lists of artifacts.
|
||||
If False, returns a flat list of artifacts. Defaults to False.
|
||||
sort_by: Field to sort artifacts by. Must be one of: 'created_at', 'step_id', 'task_id'.
|
||||
Defaults to 'created_at'.
|
||||
|
||||
Returns:
|
||||
If group_by_type is True, returns a dictionary mapping artifact types to lists of artifacts.
|
||||
If group_by_type is False, returns a list of artifacts sorted by the specified field.
|
||||
|
||||
Raises:
|
||||
ValueError: If sort_by is not one of the allowed values
|
||||
"""
|
||||
allowed_sort_fields = {"created_at", "step_id", "task_id"}
|
||||
if sort_by not in allowed_sort_fields:
|
||||
raise ValueError(f"sort_by must be one of {allowed_sort_fields}")
|
||||
run = await self.get_run(run_id, organization_id=organization_id)
|
||||
if not run:
|
||||
return []
|
||||
|
||||
async with self.Session() as session:
|
||||
query = select(ArtifactModel).filter_by(organization_id=organization_id)
|
||||
|
||||
if run.task_run_type in [
|
||||
RunType.task_v1,
|
||||
RunType.openai_cua,
|
||||
RunType.anthropic_cua,
|
||||
]:
|
||||
query = query.filter_by(task_id=run.run_id)
|
||||
elif run.task_run_type == RunType.task_v2:
|
||||
query = query.filter_by(observer_cruise_id=run.run_id)
|
||||
elif run.task_run_type == RunType.workflow_run:
|
||||
query = query.filter_by(workflow_run_id=run.run_id)
|
||||
else:
|
||||
return []
|
||||
|
||||
if artifact_types:
|
||||
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == "created_at":
|
||||
query = query.order_by(ArtifactModel.created_at)
|
||||
elif sort_by == "step_id":
|
||||
query = query.order_by(ArtifactModel.step_id, ArtifactModel.created_at)
|
||||
elif sort_by == "task_id":
|
||||
query = query.order_by(ArtifactModel.task_id, ArtifactModel.created_at)
|
||||
|
||||
# Execute query and convert to Artifact objects
|
||||
artifacts = [
|
||||
convert_to_artifact(artifact, self.debug_enabled) for artifact in (await session.scalars(query)).all()
|
||||
]
|
||||
|
||||
# Group artifacts by type if requested
|
||||
if group_by_type:
|
||||
result: dict[ArtifactType, list[Artifact]] = {}
|
||||
for artifact in artifacts:
|
||||
if artifact.artifact_type not in result:
|
||||
result[artifact.artifact_type] = []
|
||||
result[artifact.artifact_type].append(artifact)
|
||||
return result
|
||||
|
||||
return artifacts
|
||||
|
||||
async def get_artifact_by_id(
|
||||
self,
|
||||
artifact_id: str,
|
||||
|
||||
@@ -13,7 +13,7 @@ from skyvern.exceptions import MissingBrowserAddressError
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.artifact.models import Artifact
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
@@ -721,6 +721,56 @@ async def get_artifact(
|
||||
return artifact
|
||||
|
||||
|
||||
@base_router.get(
|
||||
"/runs/{run_id}/artifacts",
|
||||
tags=["Artifacts"],
|
||||
response_model=list[Artifact],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "artifacts",
|
||||
"x-fern-sdk-method-name": "get_run_artifacts",
|
||||
},
|
||||
description="Get artifacts for a run",
|
||||
summary="Get artifacts for a run",
|
||||
)
|
||||
@base_router.get("/runs/{run_id}/artifacts/", response_model=list[Artifact], include_in_schema=False)
|
||||
async def get_run_artifacts(
|
||||
run_id: str = Path(..., description="The id of the task run or the workflow run."),
|
||||
artifact_type: Annotated[list[ArtifactType] | None, Query()] = None,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> Response:
|
||||
analytics.capture("skyvern-oss-run-artifacts-get")
|
||||
# Get artifacts as a list (not grouped by type)
|
||||
artifacts = await app.DATABASE.get_artifacts_for_run(
|
||||
run_id=run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
artifact_types=artifact_type,
|
||||
group_by_type=False, # This ensures we get a list, not a dict
|
||||
)
|
||||
|
||||
if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS:
|
||||
# Ensure we have a list of artifacts
|
||||
artifacts_list = artifacts if isinstance(artifacts, list) else []
|
||||
|
||||
# Get signed URLs for all artifacts
|
||||
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts_list)
|
||||
|
||||
if signed_urls and len(signed_urls) == len(artifacts_list):
|
||||
for i, artifact in enumerate(artifacts_list):
|
||||
if hasattr(artifact, "signed_url"):
|
||||
artifact.signed_url = signed_urls[i]
|
||||
elif signed_urls:
|
||||
LOG.warning(
|
||||
"Mismatch between artifacts and signed URLs count",
|
||||
artifacts_count=len(artifacts_list),
|
||||
urls_count=len(signed_urls),
|
||||
run_id=run_id,
|
||||
)
|
||||
else:
|
||||
LOG.warning("Failed to get signed urls for artifacts", run_id=run_id)
|
||||
|
||||
return artifacts
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/runs/{run_id}/retry_webhook",
|
||||
tags=["Agent"],
|
||||
|
||||
Reference in New Issue
Block a user