Add DOWNLOAD_FILE action support for cached scripts (#SKY-7656) (#4569)
This commit is contained in:
@@ -181,6 +181,7 @@ ACTION_MAP = {
|
||||
"wait": "wait",
|
||||
"extract": "extract",
|
||||
"complete": "complete",
|
||||
"download_file": "download_file",
|
||||
}
|
||||
ACTIONS_WITH_XPATH = [
|
||||
"click",
|
||||
@@ -656,6 +657,28 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
|
||||
),
|
||||
)
|
||||
)
|
||||
elif method == "download_file":
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("file_name"),
|
||||
value=_value(act.get("file_name", "")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
if act.get("download_url"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("download_url"),
|
||||
value=_value(act["download_url"]),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
elif method == "extract":
|
||||
args.append(
|
||||
cst.Arg(
|
||||
@@ -779,7 +802,11 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
|
||||
|
||||
for act in actions:
|
||||
if act["action_type"] in [ActionType.COMPLETE, ActionType.TERMINATE, ActionType.NULL_ACTION]:
|
||||
if act["action_type"] in [
|
||||
ActionType.COMPLETE,
|
||||
ActionType.TERMINATE,
|
||||
ActionType.NULL_ACTION,
|
||||
]:
|
||||
continue
|
||||
|
||||
# For extraction blocks, assign extract action results to output variable
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, overload
|
||||
|
||||
@@ -10,7 +11,7 @@ from playwright.async_api import Locator, Page
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
|
||||
from skyvern.forge.sdk.api.files import download_file
|
||||
from skyvern.forge.sdk.api.files import download_file as download_file_from_url
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.library.ai_locator import AILocator
|
||||
from skyvern.webeye.actions import handler_utils
|
||||
@@ -551,7 +552,7 @@ class SkyvernPage(Page):
|
||||
error_to_raise = None
|
||||
if selector and files:
|
||||
try:
|
||||
file_path = await download_file(files)
|
||||
file_path = await download_file_from_url(files)
|
||||
locator = self.page.locator(selector)
|
||||
await locator.set_input_files(file_path, **kwargs)
|
||||
except Exception as e:
|
||||
@@ -586,7 +587,7 @@ class SkyvernPage(Page):
|
||||
if not files:
|
||||
raise ValueError("Parameter 'files' is required but was not provided")
|
||||
|
||||
file_path = await download_file(files)
|
||||
file_path = await download_file_from_url(files)
|
||||
locator = self.page.locator(selector)
|
||||
await locator.set_input_files(file_path, timeout=timeout, **kwargs)
|
||||
return files
|
||||
@@ -732,6 +733,31 @@ class SkyvernPage(Page):
|
||||
async def complete(self, prompt: str | None = None) -> None:
|
||||
"""Stub for complete. Override in subclasses for specific behavior."""
|
||||
|
||||
@action_wrap(ActionType.DOWNLOAD_FILE)
|
||||
async def download_file(
|
||||
self,
|
||||
file_name: str | None = None,
|
||||
download_url: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Download a file from a URL and save it locally during cached script replay.
|
||||
|
||||
Args:
|
||||
file_name: The original file name (for logging/reference). Defaults to UUID if empty.
|
||||
download_url: The URL to download the file from.
|
||||
|
||||
Returns:
|
||||
The local file path where the file was saved.
|
||||
"""
|
||||
if not download_url:
|
||||
raise ValueError("download_url is required for download_file action in cached scripts")
|
||||
|
||||
# Use uuid as fallback for empty file_name, matching handler.py behavior
|
||||
file_name = file_name or str(uuid.uuid4())
|
||||
|
||||
file_path = await download_file_from_url(download_url, filename=file_name)
|
||||
return file_path
|
||||
|
||||
@action_wrap(ActionType.RELOAD_PAGE)
|
||||
async def reload_page(self, **kwargs: Any) -> None:
|
||||
await self.page.reload(**kwargs)
|
||||
|
||||
@@ -233,6 +233,7 @@ async def personalize_action(
|
||||
ActionType.WAIT,
|
||||
ActionType.SOLVE_CAPTCHA,
|
||||
ActionType.NULL_ACTION,
|
||||
ActionType.DOWNLOAD_FILE,
|
||||
]:
|
||||
return [action]
|
||||
elif action.action_type == ActionType.TERMINATE:
|
||||
@@ -246,7 +247,13 @@ async def personalize_action(
|
||||
|
||||
|
||||
def check_for_unsupported_actions(actions: list[Action]) -> None:
|
||||
supported_actions = [ActionType.INPUT_TEXT, ActionType.WAIT, ActionType.CLICK, ActionType.COMPLETE]
|
||||
supported_actions = [
|
||||
ActionType.INPUT_TEXT,
|
||||
ActionType.WAIT,
|
||||
ActionType.CLICK,
|
||||
ActionType.COMPLETE,
|
||||
ActionType.DOWNLOAD_FILE,
|
||||
]
|
||||
supported_actions_with_query = [ActionType.INPUT_TEXT]
|
||||
for action in actions:
|
||||
query = action.intention
|
||||
|
||||
Reference in New Issue
Block a user