Implement upload_file endpoint (#547)

This commit is contained in:
Kerem Yilmaz
2024-07-03 17:54:31 -07:00
committed by GitHub
parent 03a1b6d92c
commit 21b9eea446
3 changed files with 71 additions and 2 deletions

View File

@@ -48,6 +48,7 @@ class Settings(BaseSettings):
# S3 bucket settings
AWS_REGION: str = "us-east-1"
AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads"
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
SKYVERN_TELEMETRY: bool = True
ANALYTICS_ID: str = "anonymous"

View File

@@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Callable
from typing import IO, Any, Callable
from urllib.parse import urlparse
import aioboto3
@@ -55,6 +55,17 @@ class AsyncAWSClient:
LOG.exception("S3 upload failed.", uri=uri)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file_stream(self, uri: str, file_obj: IO[bytes], client: AioBaseClient = None) -> str | None:
try:
parsed_uri = S3Uri(uri)
await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key)
LOG.debug("Upload file stream success", uri=uri)
return uri
except Exception:
LOG.exception("S3 upload stream failed.", uri=uri)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
try:
@@ -137,3 +148,6 @@ class S3Uri(object):
@property
def uri(self) -> str:
return self._parsed.geturl()
aws_client = AsyncAWSClient()

View File

@@ -1,8 +1,21 @@
import datetime
import uuid
from typing import Annotated, Any
import structlog
import yaml
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
Header,
HTTPException,
Query,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import ORJSONResponse
from pydantic import BaseModel
@@ -10,6 +23,7 @@ from skyvern import analytics
from skyvern.exceptions import StepNotFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.aws import aws_client
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.core import skyvern_context
@@ -736,3 +750,43 @@ async def update_organization(
max_steps_per_run=org_update.max_steps_per_run,
max_retries_per_step=org_update.max_retries_per_step,
)
async def validate_file_size(file: UploadFile) -> UploadFile:
# Check the file size
if file.size > app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE:
raise HTTPException(
status_code=413,
detail=f"File size exceeds the maximum allowed size ({app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE} bytes)",
)
return file
@base_router.post("/upload_file/", include_in_schema=False)
@base_router.post("/upload_file")
async def upload_file(
file: UploadFile = Depends(validate_file_size),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
s3_uri = (
f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
)
# Stream the file to S3
uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
if not uploaded_s3_uri:
raise HTTPException(status_code=500, detail="Failed to upload file to S3.")
# Generate a presigned URL for the uploaded file
presigned_urls = await aws_client.create_presigned_urls([uploaded_s3_uri])
if not presigned_urls:
raise HTTPException(status_code=500, detail="Failed to generate presigned URL.")
presigned_url = presigned_urls[0]
return ORJSONResponse(
content={"s3_uri": uploaded_s3_uri, "presigned_url": presigned_url},
status_code=200,
media_type="application/json",
)