Add checksums to downloaded files so Axis can validate them in the webhook (#1848)
This commit is contained in:
@@ -61,6 +61,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
@@ -1794,7 +1795,7 @@ class ForgeAgent:
|
||||
recording_url = None
|
||||
browser_console_log_url: str | None = None
|
||||
latest_action_screenshot_urls: list[str] | None = None
|
||||
downloaded_file_urls: list[str] | None = None
|
||||
downloaded_files: list[FileInfo] | None = None
|
||||
|
||||
# get the artifact of the screenshot and get the screenshot_url
|
||||
screenshot_artifact = await app.DATABASE.get_artifact(
|
||||
@@ -1832,7 +1833,7 @@ class ForgeAgent:
|
||||
if task.organization_id:
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=task.organization_id, task_id=task.task_id, workflow_run_id=task.workflow_run_id
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
@@ -1869,8 +1870,8 @@ class ForgeAgent:
|
||||
action_screenshot_urls=latest_action_screenshot_urls,
|
||||
screenshot_url=screenshot_url,
|
||||
recording_url=recording_url,
|
||||
downloaded_file_urls=downloaded_file_urls,
|
||||
browser_console_log_url=browser_console_log_url,
|
||||
downloaded_files=downloaded_files,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,10 +90,21 @@ class AsyncAWSClient:
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
|
||||
async def upload_file_from_path(
|
||||
self, uri: str, file_path: str, client: AioBaseClient = None, metadata: dict | None = None
|
||||
) -> None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
await client.upload_file(file_path, parsed_uri.bucket, parsed_uri.key)
|
||||
params: dict[str, Any] = {
|
||||
"Filename": file_path,
|
||||
"Bucket": parsed_uri.bucket,
|
||||
"Key": parsed_uri.key,
|
||||
}
|
||||
|
||||
if metadata:
|
||||
params["ExtraArgs"] = {"Metadata": metadata}
|
||||
|
||||
await client.upload_file(**params)
|
||||
except Exception:
|
||||
LOG.exception("S3 upload failed.", uri=uri)
|
||||
|
||||
@@ -101,6 +112,8 @@ class AsyncAWSClient:
|
||||
async def download_file(self, uri: str, client: AioBaseClient = None, log_exception: bool = True) -> bytes | None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
|
||||
# Get full object including body
|
||||
response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return await response["Body"].read()
|
||||
except Exception:
|
||||
@@ -108,6 +121,32 @@ class AsyncAWSClient:
|
||||
LOG.exception("S3 download failed", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
async def get_file_metadata(
    self, uri: str, client: AioBaseClient = None, log_exception: bool = True
) -> dict | None:
    """
    Retrieve only the metadata of an S3 object, without downloading its content.

    Args:
        uri: The S3 URI of the file.
        client: S3 client used to issue the request.
            NOTE(review): presumably injected by ``execute_with_async_client``
            when the caller omits it -- confirm against the decorator.
        log_exception: Whether to log the exception when the request fails.

    Returns:
        The ``"Metadata"`` mapping of the HEAD response (an empty dict when the
        response carries no ``"Metadata"`` entry), or ``None`` if the request
        fails for any reason.
    """
    try:
        parsed_uri = S3Uri(uri)

        # HEAD request: returns the object's headers/metadata but never its body.
        response = await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
        return response.get("Metadata", {})
    except Exception:
        # Best-effort lookup: failures are reported as None instead of raising,
        # and callers can silence the log via log_exception=False.
        if log_exception:
            LOG.exception("S3 metadata retrieval failed", uri=uri)
        return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def create_presigned_urls(self, uris: list[str], client: AioBaseClient = None) -> list[str] | None:
|
||||
presigned_urls = []
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask, ObserverThought
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
|
||||
@@ -115,5 +116,5 @@ class BaseStorage(ABC):
|
||||
@abstractmethod
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> list[str]:
|
||||
) -> list[FileInfo]:
|
||||
pass
|
||||
|
||||
@@ -6,11 +6,17 @@ from pathlib import Path
|
||||
import structlog
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir, parse_uri_to_path
|
||||
from skyvern.forge.sdk.api.files import (
|
||||
calculate_sha256_for_file,
|
||||
get_download_dir,
|
||||
get_skyvern_temp_dir,
|
||||
parse_uri_to_path,
|
||||
)
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask, ObserverThought
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
|
||||
@@ -157,15 +163,18 @@ class LocalStorage(BaseStorage):
|
||||
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> list[str]:
|
||||
) -> list[FileInfo]:
|
||||
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
|
||||
files: list[str] = []
|
||||
file_infos: list[FileInfo] = []
|
||||
files_and_folders = os.listdir(download_dir)
|
||||
for file_or_folder in files_and_folders:
|
||||
path = os.path.join(download_dir, file_or_folder)
|
||||
if os.path.isfile(path):
|
||||
files.append(f"file://{path}")
|
||||
return files
|
||||
# Calculate checksum for the file
|
||||
checksum = calculate_sha256_for_file(path)
|
||||
file_info = FileInfo(url=f"file://{path}", checksum=checksum, filename=file_or_folder)
|
||||
file_infos.append(file_info)
|
||||
return file_infos
|
||||
|
||||
@staticmethod
|
||||
def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
|
||||
|
||||
@@ -2,10 +2,13 @@ import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import DOWNLOAD_FILE_PREFIX
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.api.files import (
|
||||
calculate_sha256_for_file,
|
||||
create_named_temporary_file,
|
||||
get_download_dir,
|
||||
get_skyvern_temp_dir,
|
||||
@@ -16,9 +19,12 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityT
|
||||
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask, ObserverThought
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class S3Storage(BaseStorage):
|
||||
def __init__(self, bucket: str | None = None) -> None:
|
||||
@@ -117,21 +123,45 @@ class S3Storage(BaseStorage):
|
||||
fpath = os.path.join(download_dir, file)
|
||||
if os.path.isfile(fpath):
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
|
||||
# TODO: use coroutine to speed up uploading if too many files
|
||||
await self.async_client.upload_file_from_path(uri, fpath)
|
||||
|
||||
# Calculate SHA-256 checksum
|
||||
checksum = calculate_sha256_for_file(fpath)
|
||||
LOG.info("Calculated checksum for file", file=file, checksum=checksum)
|
||||
|
||||
# Upload file with checksum metadata
|
||||
await self.async_client.upload_file_from_path(
|
||||
uri=uri, file_path=fpath, metadata={"sha256_checksum": checksum, "original_filename": file}
|
||||
)
|
||||
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> list[str]:
|
||||
) -> list[FileInfo]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}"
|
||||
object_keys = await self.async_client.list_files(uri=uri)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
object_uris: list[str] = []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
object_uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{key}"
|
||||
object_uris.append(object_uri)
|
||||
presigned_urils = await self.async_client.create_presigned_urls(object_uris)
|
||||
if presigned_urils is None:
|
||||
return []
|
||||
return presigned_urils
|
||||
|
||||
# Get metadata (including checksum)
|
||||
metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False)
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get presigned URL
|
||||
presigned_urls = await self.async_client.create_presigned_urls([object_uri])
|
||||
if not presigned_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=presigned_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
return file_infos
|
||||
|
||||
9
skyvern/forge/sdk/schemas/files.py
Normal file
9
skyvern/forge/sdk/schemas/files.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
    """Information about a downloaded file, including URL and checksum."""

    # Location the file can be fetched from; the only required field.
    url: str = Field(..., description="URL to access the file")
    # Optional: None when no checksum is available for the file.
    checksum: str | None = Field(None, description="SHA-256 checksum of the file")
    # Optional: None when the original name is not known.
    filename: str | None = Field(None, description="Original filename")
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled, TaskAlreadyTimeout
|
||||
from skyvern.forge.sdk.core.validators import validate_url
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
|
||||
@@ -310,7 +311,7 @@ class Task(TaskBase):
|
||||
screenshot_url: str | None = None,
|
||||
recording_url: str | None = None,
|
||||
browser_console_log_url: str | None = None,
|
||||
downloaded_file_urls: list[str] | None = None,
|
||||
downloaded_files: list[FileInfo] | None = None,
|
||||
failure_reason: str | None = None,
|
||||
) -> TaskResponse:
|
||||
return TaskResponse(
|
||||
@@ -325,7 +326,8 @@ class Task(TaskBase):
|
||||
screenshot_url=screenshot_url,
|
||||
recording_url=recording_url,
|
||||
browser_console_log_url=browser_console_log_url,
|
||||
downloaded_file_urls=downloaded_file_urls,
|
||||
downloaded_files=downloaded_files,
|
||||
downloaded_file_urls=[file.url for file in downloaded_files] if downloaded_files else None,
|
||||
errors=self.errors,
|
||||
max_steps_per_run=self.max_steps_per_run,
|
||||
workflow_run_id=self.workflow_run_id,
|
||||
@@ -343,6 +345,7 @@ class TaskResponse(BaseModel):
|
||||
screenshot_url: str | None = None
|
||||
recording_url: str | None = None
|
||||
browser_console_log_url: str | None = None
|
||||
downloaded_files: list[FileInfo] | None = None
|
||||
downloaded_file_urls: list[str] | None = None
|
||||
failure_reason: str | None = None
|
||||
errors: list[dict[str, Any]] = []
|
||||
@@ -356,16 +359,21 @@ class TaskOutput(BaseModel):
|
||||
extracted_information: list | dict[str, Any] | str | None = None
|
||||
failure_reason: str | None = None
|
||||
errors: list[dict[str, Any]] = []
|
||||
downloaded_file_urls: list[str] | None = None
|
||||
downloaded_files: list[FileInfo] | None = None
|
||||
downloaded_file_urls: list[str] | None = None # For backward compatibility
|
||||
|
||||
@staticmethod
|
||||
def from_task(task: Task, downloaded_file_urls: list[str] | None = None) -> TaskOutput:
|
||||
def from_task(task: Task, downloaded_files: list[FileInfo] | None = None) -> TaskOutput:
|
||||
# For backward compatibility, extract just the URLs from FileInfo objects
|
||||
downloaded_file_urls = [file_info.url for file_info in downloaded_files] if downloaded_files else None
|
||||
|
||||
return TaskOutput(
|
||||
task_id=task.task_id,
|
||||
status=task.status,
|
||||
extracted_information=task.extracted_information,
|
||||
failure_reason=task.failure_reason,
|
||||
errors=task.errors,
|
||||
downloaded_files=downloaded_files,
|
||||
downloaded_file_urls=downloaded_file_urls,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core.validators import prepend_scheme_and_validate_url
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import ObserverTaskStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
|
||||
@@ -634,17 +635,18 @@ class BaseTaskBlock(Block):
|
||||
)
|
||||
success = updated_task.status == TaskStatus.completed
|
||||
|
||||
downloaded_file_urls = []
|
||||
downloaded_files: list[FileInfo] = []
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
task_id=updated_task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
|
||||
task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
|
||||
|
||||
task_output = TaskOutput.from_task(updated_task, downloaded_files)
|
||||
output_parameter_value = task_output.model_dump()
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value)
|
||||
return await self.build_block_result(
|
||||
@@ -693,10 +695,9 @@ class BaseTaskBlock(Block):
|
||||
current_retry += 1
|
||||
will_retry = current_retry <= self.max_retries
|
||||
retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
|
||||
# Initialise before the timed lookup: if get_downloaded_files() times out, the
# name would otherwise be unbound when the TaskOutput is built afterwards.
downloaded_files = []
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
task_id=updated_task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
@@ -705,7 +706,7 @@ class BaseTaskBlock(Block):
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
|
||||
|
||||
task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
|
||||
task_output = TaskOutput.from_task(updated_task, downloaded_files)
|
||||
LOG.warning(
|
||||
f"Task failed with status {updated_task.status}{retry_message}",
|
||||
task_id=updated_task.task_id,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, List
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from skyvern.forge.sdk.core.validators import validate_url
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import ObserverTask
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
|
||||
@@ -143,6 +144,7 @@ class WorkflowRunStatusResponse(BaseModel):
|
||||
parameters: dict[str, Any]
|
||||
screenshot_urls: list[str] | None = None
|
||||
recording_url: str | None = None
|
||||
downloaded_files: list[FileInfo] | None = None
|
||||
downloaded_file_urls: list[str] | None = None
|
||||
outputs: dict[str, Any] | None = None
|
||||
total_steps: int | None = None
|
||||
|
||||
@@ -23,6 +23,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunTimeline, WorkflowRunTimelineType
|
||||
@@ -1001,14 +1002,17 @@ class WorkflowService:
|
||||
if recording_artifact:
|
||||
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
|
||||
|
||||
downloaded_files: list[FileInfo] | None = None
|
||||
downloaded_file_urls: list[str] | None = None
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
task_id=None,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
if downloaded_files:
|
||||
downloaded_file_urls = [file_info.url for file_info in downloaded_files]
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning(
|
||||
"Timeout to get downloaded files",
|
||||
@@ -1072,6 +1076,7 @@ class WorkflowService:
|
||||
parameters=parameters_with_value,
|
||||
screenshot_urls=screenshot_urls,
|
||||
recording_url=recording_url,
|
||||
downloaded_files=downloaded_files,
|
||||
downloaded_file_urls=downloaded_file_urls,
|
||||
outputs=outputs,
|
||||
total_steps=total_steps,
|
||||
|
||||
Reference in New Issue
Block a user