Implement upload_file endpoint (#547)
This commit is contained in:
@@ -48,6 +48,7 @@ class Settings(BaseSettings):
|
||||
# S3 bucket settings
|
||||
AWS_REGION: str = "us-east-1"
|
||||
AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads"
|
||||
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
SKYVERN_TELEMETRY: bool = True
|
||||
ANALYTICS_ID: str = "anonymous"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Callable
|
||||
from typing import IO, Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aioboto3
|
||||
@@ -55,6 +55,17 @@ class AsyncAWSClient:
|
||||
LOG.exception("S3 upload failed.", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def upload_file_stream(self, uri: str, file_obj: IO[bytes], client: AioBaseClient = None) -> str | None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key)
|
||||
LOG.debug("Upload file stream success", uri=uri)
|
||||
return uri
|
||||
except Exception:
|
||||
LOG.exception("S3 upload stream failed.", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
|
||||
try:
|
||||
@@ -137,3 +148,6 @@ class S3Uri(object):
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._parsed.geturl()
|
||||
|
||||
|
||||
aws_client = AsyncAWSClient()
|
||||
|
||||
@@ -1,8 +1,21 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Annotated, Any
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Header,
|
||||
HTTPException,
|
||||
Query,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -10,6 +23,7 @@ from skyvern import analytics
|
||||
from skyvern.exceptions import StepNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.aws import aws_client
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
@@ -736,3 +750,43 @@ async def update_organization(
|
||||
max_steps_per_run=org_update.max_steps_per_run,
|
||||
max_retries_per_step=org_update.max_retries_per_step,
|
||||
)
|
||||
|
||||
|
||||
async def validate_file_size(file: UploadFile) -> UploadFile:
|
||||
# Check the file size
|
||||
if file.size > app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File size exceeds the maximum allowed size ({app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE} bytes)",
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
@base_router.post("/upload_file/", include_in_schema=False)
|
||||
@base_router.post("/upload_file")
|
||||
async def upload_file(
|
||||
file: UploadFile = Depends(validate_file_size),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> Response:
|
||||
bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS
|
||||
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
|
||||
s3_uri = (
|
||||
f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
||||
)
|
||||
# Stream the file to S3
|
||||
uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
|
||||
if not uploaded_s3_uri:
|
||||
raise HTTPException(status_code=500, detail="Failed to upload file to S3.")
|
||||
|
||||
# Generate a presigned URL for the uploaded file
|
||||
presigned_urls = await aws_client.create_presigned_urls([uploaded_s3_uri])
|
||||
if not presigned_urls:
|
||||
raise HTTPException(status_code=500, detail="Failed to generate presigned URL.")
|
||||
|
||||
presigned_url = presigned_urls[0]
|
||||
return ORJSONResponse(
|
||||
content={"s3_uri": uploaded_s3_uri, "presigned_url": presigned_url},
|
||||
status_code=200,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user