Add Checksums to downloaded files for Axis so they can validate it in the webhook (#1848)

This commit is contained in:
Shuchang Zheng
2025-02-26 17:19:05 -08:00
committed by GitHub
parent c73ad6ed68
commit 995d9461b5
10 changed files with 136 additions and 31 deletions

View File

@@ -61,6 +61,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
@@ -1794,7 +1795,7 @@ class ForgeAgent:
recording_url = None
browser_console_log_url: str | None = None
latest_action_screenshot_urls: list[str] | None = None
downloaded_file_urls: list[str] | None = None
downloaded_files: list[FileInfo] | None = None
# get the artifact of the screenshot and get the screenshot_url
screenshot_artifact = await app.DATABASE.get_artifact(
@@ -1832,7 +1833,7 @@ class ForgeAgent:
if task.organization_id:
try:
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
downloaded_files = await app.STORAGE.get_downloaded_files(
organization_id=task.organization_id, task_id=task.task_id, workflow_run_id=task.workflow_run_id
)
except asyncio.TimeoutError:
@@ -1869,8 +1870,8 @@ class ForgeAgent:
action_screenshot_urls=latest_action_screenshot_urls,
screenshot_url=screenshot_url,
recording_url=recording_url,
downloaded_file_urls=downloaded_file_urls,
browser_console_log_url=browser_console_log_url,
downloaded_files=downloaded_files,
failure_reason=failure_reason,
)

View File

@@ -90,10 +90,21 @@ class AsyncAWSClient:
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
async def upload_file_from_path(
self, uri: str, file_path: str, client: AioBaseClient = None, metadata: dict | None = None
) -> None:
try:
parsed_uri = S3Uri(uri)
await client.upload_file(file_path, parsed_uri.bucket, parsed_uri.key)
params: dict[str, Any] = {
"Filename": file_path,
"Bucket": parsed_uri.bucket,
"Key": parsed_uri.key,
}
if metadata:
params["ExtraArgs"] = {"Metadata": metadata}
await client.upload_file(**params)
except Exception:
LOG.exception("S3 upload failed.", uri=uri)
@@ -101,6 +112,8 @@ class AsyncAWSClient:
async def download_file(self, uri: str, client: AioBaseClient = None, log_exception: bool = True) -> bytes | None:
try:
parsed_uri = S3Uri(uri)
# Get full object including body
response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
return await response["Body"].read()
except Exception:
@@ -108,6 +121,32 @@ class AsyncAWSClient:
LOG.exception("S3 download failed", uri=uri)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def get_file_metadata(
    self, uri: str, client: AioBaseClient = None, log_exception: bool = True
) -> dict | None:
    """
    Retrieve only the metadata of an S3 object without downloading its content.

    Issues a HeadObject request, so the object body is never transferred —
    cheap even for large files. Used to read the ``sha256_checksum`` /
    ``original_filename`` metadata attached at upload time.

    Args:
        uri: The S3 URI of the file (``s3://bucket/key``).
        client: S3 client injected by the ``execute_with_async_client`` decorator.
        log_exception: Whether to log exceptions when the request fails.

    Returns:
        The user-metadata dictionary (``{}`` when the object has no metadata),
        or ``None`` if the request fails.
    """
    try:
        parsed_uri = S3Uri(uri)
        # Only get object metadata without the body (HEAD, not GET)
        response = await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
        return response.get("Metadata", {})
    except Exception:
        if log_exception:
            LOG.exception("S3 metadata retrieval failed", uri=uri)
        # Best-effort: callers treat None as "metadata unavailable"
        return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def create_presigned_urls(self, uris: list[str], client: AioBaseClient = None) -> list[str] | None:
presigned_urls = []

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask, ObserverThought
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
@@ -115,5 +116,5 @@ class BaseStorage(ABC):
@abstractmethod
async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> list[str]:
) -> list[FileInfo]:
pass

View File

@@ -6,11 +6,17 @@ from pathlib import Path
import structlog
from skyvern.config import settings
from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir, parse_uri_to_path
from skyvern.forge.sdk.api.files import (
calculate_sha256_for_file,
get_download_dir,
get_skyvern_temp_dir,
parse_uri_to_path,
)
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask, ObserverThought
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
@@ -157,15 +163,18 @@ class LocalStorage(BaseStorage):
async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> list[str]:
) -> list[FileInfo]:
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
files: list[str] = []
file_infos: list[FileInfo] = []
files_and_folders = os.listdir(download_dir)
for file_or_folder in files_and_folders:
path = os.path.join(download_dir, file_or_folder)
if os.path.isfile(path):
files.append(f"file://{path}")
return files
# Calculate checksum for the file
checksum = calculate_sha256_for_file(path)
file_info = FileInfo(url=f"file://{path}", checksum=checksum, filename=file_or_folder)
file_infos.append(file_info)
return file_infos
@staticmethod
def _create_directories_if_not_exists(path_including_file_name: Path) -> None:

View File

@@ -2,10 +2,13 @@ import os
import shutil
from datetime import datetime
import structlog
from skyvern.config import settings
from skyvern.constants import DOWNLOAD_FILE_PREFIX
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.files import (
calculate_sha256_for_file,
create_named_temporary_file,
get_download_dir,
get_skyvern_temp_dir,
@@ -16,9 +19,12 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityT
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask, ObserverThought
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
LOG = structlog.get_logger()
class S3Storage(BaseStorage):
def __init__(self, bucket: str | None = None) -> None:
@@ -117,21 +123,45 @@ class S3Storage(BaseStorage):
fpath = os.path.join(download_dir, file)
if os.path.isfile(fpath):
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
# TODO: use coroutine to speed up uploading if too many files
await self.async_client.upload_file_from_path(uri, fpath)
# Calculate SHA-256 checksum
checksum = calculate_sha256_for_file(fpath)
LOG.info("Calculated checksum for file", file=file, checksum=checksum)
# Upload file with checksum metadata
await self.async_client.upload_file_from_path(
uri=uri, file_path=fpath, metadata={"sha256_checksum": checksum, "original_filename": file}
)
async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> list[str]:
) -> list[FileInfo]:
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}"
object_keys = await self.async_client.list_files(uri=uri)
if len(object_keys) == 0:
return []
object_uris: list[str] = []
file_infos: list[FileInfo] = []
for key in object_keys:
object_uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{key}"
object_uris.append(object_uri)
presigned_urils = await self.async_client.create_presigned_urls(object_uris)
if presigned_urils is None:
return []
return presigned_urils
# Get metadata (including checksum)
metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False)
# Create FileInfo object
filename = os.path.basename(key)
checksum = metadata.get("sha256_checksum") if metadata else None
# Get presigned URL
presigned_urls = await self.async_client.create_presigned_urls([object_uri])
if not presigned_urls:
continue
file_info = FileInfo(
url=presigned_urls[0],
checksum=checksum,
filename=metadata.get("original_filename", filename) if metadata else filename,
)
file_infos.append(file_info)
return file_infos

View File

@@ -0,0 +1,9 @@
from pydantic import BaseModel, Field
class FileInfo(BaseModel):
    """Information about a downloaded file, including URL and checksum.

    Returned by the storage backends' ``get_downloaded_files`` and included
    in task/workflow responses so webhook receivers can verify file integrity.
    """

    # Access URL — a presigned S3 URL or a file:// URI depending on the storage backend
    url: str = Field(..., description="URL to access the file")
    # SHA-256 digest computed at upload time; None when unavailable
    checksum: str | None = Field(None, description="SHA-256 checksum of the file")
    # Filename as saved at download time; None if unknown
    filename: str | None = Field(None, description="Original filename")

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, field_validator
from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled, TaskAlreadyTimeout
from skyvern.forge.sdk.core.validators import validate_url
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.files import FileInfo
class ProxyLocation(StrEnum):
@@ -310,7 +311,7 @@ class Task(TaskBase):
screenshot_url: str | None = None,
recording_url: str | None = None,
browser_console_log_url: str | None = None,
downloaded_file_urls: list[str] | None = None,
downloaded_files: list[FileInfo] | None = None,
failure_reason: str | None = None,
) -> TaskResponse:
return TaskResponse(
@@ -325,7 +326,8 @@ class Task(TaskBase):
screenshot_url=screenshot_url,
recording_url=recording_url,
browser_console_log_url=browser_console_log_url,
downloaded_file_urls=downloaded_file_urls,
downloaded_files=downloaded_files,
downloaded_file_urls=[file.url for file in downloaded_files] if downloaded_files else None,
errors=self.errors,
max_steps_per_run=self.max_steps_per_run,
workflow_run_id=self.workflow_run_id,
@@ -343,6 +345,7 @@ class TaskResponse(BaseModel):
screenshot_url: str | None = None
recording_url: str | None = None
browser_console_log_url: str | None = None
downloaded_files: list[FileInfo] | None = None
downloaded_file_urls: list[str] | None = None
failure_reason: str | None = None
errors: list[dict[str, Any]] = []
@@ -356,16 +359,21 @@ class TaskOutput(BaseModel):
extracted_information: list | dict[str, Any] | str | None = None
failure_reason: str | None = None
errors: list[dict[str, Any]] = []
downloaded_file_urls: list[str] | None = None
downloaded_files: list[FileInfo] | None = None
downloaded_file_urls: list[str] | None = None # For backward compatibility
@staticmethod
def from_task(task: Task, downloaded_file_urls: list[str] | None = None) -> TaskOutput:
def from_task(task: Task, downloaded_files: list[FileInfo] | None = None) -> TaskOutput:
# For backward compatibility, extract just the URLs from FileInfo objects
downloaded_file_urls = [file_info.url for file_info in downloaded_files] if downloaded_files else None
return TaskOutput(
task_id=task.task_id,
status=task.status,
extracted_information=task.extracted_information,
failure_reason=task.failure_reason,
errors=task.errors,
downloaded_files=downloaded_files,
downloaded_file_urls=downloaded_file_urls,
)

View File

@@ -49,6 +49,7 @@ from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core.validators import prepend_scheme_and_validate_url
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import ObserverTaskStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
@@ -634,17 +635,18 @@ class BaseTaskBlock(Block):
)
success = updated_task.status == TaskStatus.completed
downloaded_file_urls = []
downloaded_files: list[FileInfo] = []
try:
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
downloaded_files = await app.STORAGE.get_downloaded_files(
organization_id=workflow_run.organization_id,
task_id=updated_task.task_id,
workflow_run_id=workflow_run_id,
)
except asyncio.TimeoutError:
LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
task_output = TaskOutput.from_task(updated_task, downloaded_files)
output_parameter_value = task_output.model_dump()
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value)
return await self.build_block_result(
@@ -693,10 +695,9 @@ class BaseTaskBlock(Block):
current_retry += 1
will_retry = current_retry <= self.max_retries
retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
downloaded_file_urls = []
try:
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
downloaded_files = await app.STORAGE.get_downloaded_files(
organization_id=workflow_run.organization_id,
task_id=updated_task.task_id,
workflow_run_id=workflow_run_id,
@@ -705,7 +706,7 @@ class BaseTaskBlock(Block):
except asyncio.TimeoutError:
LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
task_output = TaskOutput.from_task(updated_task, downloaded_files)
LOG.warning(
f"Task failed with status {updated_task.status}{retry_message}",
task_id=updated_task.task_id,

View File

@@ -5,6 +5,7 @@ from typing import Any, List
from pydantic import BaseModel, field_validator
from skyvern.forge.sdk.core.validators import validate_url
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
@@ -143,6 +144,7 @@ class WorkflowRunStatusResponse(BaseModel):
parameters: dict[str, Any]
screenshot_urls: list[str] | None = None
recording_url: str | None = None
downloaded_files: list[FileInfo] | None = None
downloaded_file_urls: list[str] | None = None
outputs: dict[str, Any] | None = None
total_steps: int | None = None

View File

@@ -23,6 +23,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunTimeline, WorkflowRunTimelineType
@@ -1001,14 +1002,17 @@ class WorkflowService:
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
downloaded_files: list[FileInfo] | None = None
downloaded_file_urls: list[str] | None = None
try:
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
downloaded_files = await app.STORAGE.get_downloaded_files(
organization_id=workflow_run.organization_id,
task_id=None,
workflow_run_id=workflow_run.workflow_run_id,
)
if downloaded_files:
downloaded_file_urls = [file_info.url for file_info in downloaded_files]
except asyncio.TimeoutError:
LOG.warning(
"Timeout to get downloaded files",
@@ -1072,6 +1076,7 @@ class WorkflowService:
parameters=parameters_with_value,
screenshot_urls=screenshot_urls,
recording_url=recording_url,
downloaded_files=downloaded_files,
downloaded_file_urls=downloaded_file_urls,
outputs=outputs,
total_steps=total_steps,