downloaded files in pbs response (#3487)
This commit is contained in:
@@ -135,6 +135,12 @@ class BaseStorage(ABC):
|
||||
) -> list[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
pass
|
||||
|
||||
@@ -215,6 +215,11 @@ class LocalStorage(BaseStorage):
|
||||
) -> list[str]:
|
||||
return []
|
||||
|
||||
async def get_shared_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
return []
|
||||
|
||||
async def list_downloading_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
|
||||
@@ -203,6 +203,36 @@ class S3Storage(BaseStorage):
|
||||
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
|
||||
async def get_shared_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
object_keys = await self.list_downloaded_files_in_browser_session(organization_id, browser_session_id)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
# Get metadata (including checksum)
|
||||
metadata = await self.async_client.get_file_metadata(key, log_exception=False)
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get presigned URL
|
||||
presigned_urls = await self.async_client.create_presigned_urls([key])
|
||||
if not presigned_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=presigned_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
return file_infos
|
||||
|
||||
async def list_downloading_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from fastapi import Depends, HTTPException, Path
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
@@ -45,7 +47,7 @@ async def create_browser_session(
|
||||
timeout_minutes=browser_session_request.timeout,
|
||||
proxy_location=browser_session_request.proxy_location,
|
||||
)
|
||||
return BrowserSessionResponse.from_browser_session(browser_session)
|
||||
return await BrowserSessionResponse.from_browser_session(browser_session)
|
||||
|
||||
|
||||
@base_router.post(
|
||||
@@ -116,7 +118,7 @@ async def get_browser_session(
|
||||
)
|
||||
if not browser_session:
|
||||
raise HTTPException(status_code=404, detail=f"Browser session {browser_session_id} not found")
|
||||
return BrowserSessionResponse.from_browser_session(browser_session)
|
||||
return await BrowserSessionResponse.from_browser_session(browser_session, app.STORAGE)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -145,4 +147,9 @@ async def get_browser_sessions(
|
||||
"""Get all active browser sessions for the organization"""
|
||||
analytics.capture("skyvern-oss-agent-browser-sessions-get")
|
||||
browser_sessions = await app.PERSISTENT_SESSIONS_MANAGER.get_active_sessions(current_org.organization_id)
|
||||
return [BrowserSessionResponse.from_browser_session(browser_session) for browser_session in browser_sessions]
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
BrowserSessionResponse.from_browser_session(browser_session, app.STORAGE)
|
||||
for browser_session in browser_sessions
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user