download with file extension (#1142)

This commit is contained in:
LawyZheng
2024-11-06 10:15:47 +08:00
committed by GitHub
parent 1b9f45b908
commit 087275492f
2 changed files with 65 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ REPO_ROOT_DIR = SKYVERN_DIR.parent
INPUT_TEXT_TIMEOUT = 120000 # 2 minutes
PAGE_CONTENT_TIMEOUT = 300 # 5 mins
BROWSER_CLOSE_TIMEOUT = 180 # 3 minute
BROWSER_DOWNLOAD_TIMEOUT = 600 # 10 minute
# reserved fields for navigation payload
SPECIAL_FIELD_VERIFICATION_CODE = "verification_code"

View File

@@ -6,15 +6,16 @@ import tempfile
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Awaitable, Callable, Protocol
import aiofiles
import structlog
from playwright.async_api import BrowserContext, ConsoleMessage, Error, Page, Playwright
from playwright.async_api import BrowserContext, ConsoleMessage, Download, Error, Page, Playwright
from pydantic import BaseModel, PrivateAttr
from skyvern.config import settings
from skyvern.constants import BROWSER_CLOSE_TIMEOUT, REPO_ROOT_DIR
from skyvern.constants import BROWSER_CLOSE_TIMEOUT, BROWSER_DOWNLOAD_TIMEOUT, REPO_ROOT_DIR
from skyvern.exceptions import (
FailedToNavigateToUrl,
FailedToReloadPage,
@@ -68,6 +69,66 @@ def set_browser_console_log(browser_context: BrowserContext, browser_artifacts:
browser_context.on("console", browser_console_log)
def set_download_file_listener(browser_context: BrowserContext, **kwargs: Any) -> None:
async def listen_to_download(download: Download) -> None:
try:
workflow_run_id = kwargs.get("workflow_run_id")
task_id = kwargs.get("task_id")
async with asyncio.timeout(BROWSER_DOWNLOAD_TIMEOUT):
file_path = await download.path()
if file_path.suffix:
return
LOG.info(
"No file extensions, going to add file extension automatically",
workflow_run_id=workflow_run_id,
task_id=task_id,
suggested_filename=download.suggested_filename,
url=download.url,
)
suffix = Path(download.suggested_filename).suffix
if suffix:
LOG.info(
"Add extension according to suggested filename",
workflow_run_id=workflow_run_id,
task_id=task_id,
filepath=str(file_path) + suffix,
)
file_path.rename(str(file_path) + suffix)
return
suffix = Path(download.url).suffix
if suffix:
LOG.info(
"Add extension according to download url",
workflow_run_id=workflow_run_id,
task_id=task_id,
filepath=str(file_path) + suffix,
)
file_path.rename(str(file_path) + suffix)
return
# TODO: maybe should try to parse it from URL response
except asyncio.TimeoutError:
LOG.error(
"timeout to download file, going to cancel the download",
workflow_run_id=workflow_run_id,
task_id=task_id,
)
await download.cancel()
except Exception:
LOG.exception(
"Failed to add file extension name to downloaded file",
workflow_run_id=workflow_run_id,
task_id=task_id,
)
def listen_to_new_page(page: Page) -> None:
page.on("download", listen_to_download)
browser_context.on("page", listen_to_new_page)
class BrowserContextCreator(Protocol):
def __call__(
self, playwright: Playwright, **kwargs: dict[str, Any]
@@ -145,6 +206,7 @@ class BrowserContextFactory:
raise UnknownBrowserType(browser_type)
browser_context, browser_artifacts, cleanup_func = await creator(playwright, **kwargs)
set_browser_console_log(browser_context=browser_context, browser_artifacts=browser_artifacts)
set_download_file_listener(browser_context=browser_context, **kwargs)
return browser_context, browser_artifacts, cleanup_func
except Exception as e:
if browser_context is not None: