Implement upload_file endpoint (#547)

This commit is contained in:
Kerem Yilmaz
2024-07-03 17:54:31 -07:00
committed by GitHub
parent 03a1b6d92c
commit 21b9eea446
3 changed files with 71 additions and 2 deletions

View File

@@ -48,6 +48,7 @@ class Settings(BaseSettings):
# S3 bucket settings # S3 bucket settings
AWS_REGION: str = "us-east-1" AWS_REGION: str = "us-east-1"
AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads" AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads"
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
SKYVERN_TELEMETRY: bool = True SKYVERN_TELEMETRY: bool = True
ANALYTICS_ID: str = "anonymous" ANALYTICS_ID: str = "anonymous"

View File

@@ -1,5 +1,5 @@
from enum import StrEnum from enum import StrEnum
from typing import Any, Callable from typing import IO, Any, Callable
from urllib.parse import urlparse from urllib.parse import urlparse
import aioboto3 import aioboto3
@@ -55,6 +55,17 @@ class AsyncAWSClient:
LOG.exception("S3 upload failed.", uri=uri) LOG.exception("S3 upload failed.", uri=uri)
return None return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file_stream(self, uri: str, file_obj: IO[bytes], client: AioBaseClient = None) -> str | None:
try:
parsed_uri = S3Uri(uri)
await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key)
LOG.debug("Upload file stream success", uri=uri)
return uri
except Exception:
LOG.exception("S3 upload stream failed.", uri=uri)
return None
@execute_with_async_client(client_type=AWSClientType.S3) @execute_with_async_client(client_type=AWSClientType.S3)
async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None: async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None:
try: try:
@@ -137,3 +148,6 @@ class S3Uri(object):
@property @property
def uri(self) -> str: def uri(self) -> str:
return self._parsed.geturl() return self._parsed.geturl()
aws_client = AsyncAWSClient()

View File

@@ -1,8 +1,21 @@
import datetime
import uuid
from typing import Annotated, Any from typing import Annotated, Any
import structlog import structlog
import yaml import yaml
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
Header,
HTTPException,
Query,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from pydantic import BaseModel from pydantic import BaseModel
@@ -10,6 +23,7 @@ from skyvern import analytics
from skyvern.exceptions import StepNotFound from skyvern.exceptions import StepNotFound
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.aws import aws_client
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
@@ -736,3 +750,43 @@ async def update_organization(
max_steps_per_run=org_update.max_steps_per_run, max_steps_per_run=org_update.max_steps_per_run,
max_retries_per_step=org_update.max_retries_per_step, max_retries_per_step=org_update.max_retries_per_step,
) )
async def validate_file_size(file: UploadFile) -> UploadFile:
# Check the file size
if file.size > app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE:
raise HTTPException(
status_code=413,
detail=f"File size exceeds the maximum allowed size ({app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE} bytes)",
)
return file
@base_router.post("/upload_file/", include_in_schema=False)
@base_router.post("/upload_file")
async def upload_file(
file: UploadFile = Depends(validate_file_size),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Response:
bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
s3_uri = (
f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
)
# Stream the file to S3
uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
if not uploaded_s3_uri:
raise HTTPException(status_code=500, detail="Failed to upload file to S3.")
# Generate a presigned URL for the uploaded file
presigned_urls = await aws_client.create_presigned_urls([uploaded_s3_uri])
if not presigned_urls:
raise HTTPException(status_code=500, detail="Failed to generate presigned URL.")
presigned_url = presigned_urls[0]
return ORJSONResponse(
content={"s3_uri": uploaded_s3_uri, "presigned_url": presigned_url},
status_code=200,
media_type="application/json",
)