Remove setup.sh in favor of skyvern CLI (#4737)

This commit is contained in:
Shuchang Zheng
2026-02-12 20:43:27 -08:00
committed by GitHub
parent 08d3b04d14
commit 155c07f8be
77 changed files with 12358 additions and 10 deletions

0
tests/unit/__init__.py Normal file
View File

30
tests/unit/conftest.py Normal file
View File

@@ -0,0 +1,30 @@
# -- begin speed up unit tests
import pytest
from tests.unit.force_stub_app import start_forge_stub_app
# NOTE(jdo): uncomment below to run tests faster, if you're targeting something
# that does not need the full app context
# import sys
# from unittest.mock import MagicMock
# mock_modules = [
# "skyvern.forge.app",
# "skyvern.library",
# "skyvern.core.script_generations.skyvern_page",
# "skyvern.core.script_generations.run_initializer",
# "skyvern.core.script_generations.workflow_wrappers",
# "skyvern.services.script_service",
# ]
# for module in mock_modules:
# sys.modules[module] = MagicMock()
# -- end speed up unit tests
@pytest.fixture(scope="module", autouse=True)
def setup_forge_stub_app():
    """Autouse, module-scoped fixture: install the forge stub app before any
    test in the module runs, so imports of skyvern internals resolve against
    stubbed collaborators instead of a fully initialised ForgeApp."""
    start_forge_stub_app()
    yield

View File

@@ -0,0 +1,60 @@
from unittest.mock import AsyncMock
from skyvern.forge import set_force_app_instance
from skyvern.forge.forge_app import ForgeApp
def create_forge_stub_app() -> ForgeApp:
    """Build a ForgeApp whose collaborators are lazily-created AsyncMocks.

    Any attribute accessed on one of the stub namespaces is materialised as an
    AsyncMock on first use and cached via ``setattr``, so tests can exercise
    arbitrary async APIs without wiring each one up explicitly.

    Returns:
        A ForgeApp instance with all major subsystems replaced by stubs.
    """

    class _LazyNamespace:
        # Create-and-cache an AsyncMock for any attribute on first access.
        def __getattr__(self, name):
            value = AsyncMock()
            setattr(self, name, value)
            return value

    fake_app_module = ForgeApp()
    fake_app_module.DATABASE = _LazyNamespace()
    fake_app_module.WORKFLOW_CONTEXT_MANAGER = _LazyNamespace()
    fake_app_module.WORKFLOW_SERVICE = _LazyNamespace()
    fake_app_module.BROWSER_MANAGER = _LazyNamespace()
    fake_app_module.PERSISTENT_SESSIONS_MANAGER = _LazyNamespace()
    fake_app_module.ARTIFACT_MANAGER = _LazyNamespace()
    fake_app_module.AGENT_FUNCTION = _LazyNamespace()
    fake_app_module.AGENT_FUNCTION.validate_block_execution = AsyncMock()
    fake_app_module.AGENT_FUNCTION.validate_code_block = AsyncMock()
    fake_app_module.agent = _LazyNamespace()
    # Pre-seed the DATABASE methods tests touch most often.  The lazy
    # namespace would create them on demand anyway; seeding keeps each mock
    # stable regardless of attribute-access order.  (The original assigned
    # update_workflow_run_block, create_workflow_run_block and
    # create_or_update_workflow_run_output_parameter twice; the redundant
    # second assignments have been removed.)
    fake_app_module.DATABASE.update_workflow_run_block = AsyncMock()
    fake_app_module.DATABASE.create_workflow_run_block = AsyncMock()
    fake_app_module.DATABASE.create_or_update_workflow_run_output_parameter = AsyncMock()
    fake_app_module.DATABASE.get_last_task_for_workflow_run = AsyncMock()
    fake_app_module.DATABASE.get_workflow_run = AsyncMock()
    fake_app_module.DATABASE.get_workflow_run_block = AsyncMock()
    fake_app_module.DATABASE.get_task = AsyncMock()
    fake_app_module.DATABASE.update_task = AsyncMock()
    fake_app_module.DATABASE.update_task_v2 = AsyncMock()
    fake_app_module.DATABASE.get_organization = AsyncMock()
    fake_app_module.DATABASE.get_workflow = AsyncMock()
    fake_app_module.DATABASE.update_workflow_run = AsyncMock()
    # One AsyncMock per LLM-handler slot (AUTO_COMPLETION_LLM_API_HANDLER was
    # assigned twice in the original; it is assigned once here).
    fake_app_module.LLM_API_HANDLER = AsyncMock()
    fake_app_module.SECONDARY_LLM_API_HANDLER = AsyncMock()
    fake_app_module.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
    fake_app_module.CUSTOM_SELECT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.NORMAL_SELECT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.SELECT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.SINGLE_INPUT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.EXTRACTION_LLM_API_HANDLER = AsyncMock()
    fake_app_module.CHECK_USER_GOAL_LLM_API_HANDLER = AsyncMock()
    fake_app_module.EXPERIMENTATION_PROVIDER = _LazyNamespace()
    fake_app_module.STORAGE = _LazyNamespace()
    return fake_app_module
def start_forge_stub_app() -> ForgeApp:
    """Create the stub forge app, install it globally, and hand it back."""
    stub = create_forge_stub_app()
    set_force_app_instance(stub)
    return stub

300
tests/unit/helpers.py Normal file
View File

@@ -0,0 +1,300 @@
from __future__ import annotations
import json
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from types import SimpleNamespace
from typing import Any, Iterator, Sequence
from unittest.mock import AsyncMock, MagicMock
from pytest import MonkeyPatch # type: ignore[import-not-found]
from skyvern.forge import app
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.api.llm import api_handler_factory
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig, LLMRouterModelConfig
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
class FakeLLMResponse:
    """Minimal stand-in for a litellm completion response.

    Exposes the attributes handler code reads (``model``, ``choices``,
    ``usage``) plus a ``model_dump_json`` that serialises the same structure
    a real response dump would contain.
    """

    def __init__(self, model: str) -> None:
        self.model = model
        self._content = '{"actions": []}'
        message = SimpleNamespace(content=self._content)
        self.choices = [SimpleNamespace(message=message)]
        # Zeroed token accounting so cost/usage bookkeeping paths run cleanly.
        self.usage = SimpleNamespace(
            prompt_tokens=0,
            completion_tokens=0,
            completion_tokens_details=SimpleNamespace(reasoning_tokens=0),
            prompt_tokens_details=SimpleNamespace(cached_tokens=0),
            cache_read_input_tokens=0,
        )

    def model_dump_json(self, indent: int = 2) -> str:
        """Serialise the fake response as pretty-printed JSON."""
        payload = {
            "model": self.model,
            "choices": [{"message": {"content": self._content}}],
        }
        return json.dumps(payload, indent=indent)
class DummyLogger:
    """Structlog-style logger stub that records ``info`` events for assertions.

    Only ``info`` calls are captured; ``warning``, ``exception`` and ``debug``
    are accepted and discarded so production code may log freely in tests.
    """

    def __init__(self) -> None:
        # (event, kwargs) tuples captured from info() calls, in call order.
        self.events: list[tuple[str, dict[str, Any]]] = []

    def info(self, event: str, **kwargs: Any) -> None:
        # Annotation fixed: with **kwargs the annotation describes each VALUE,
        # so it is Any, not dict[str, Any] as the original claimed.
        self.events.append((event, kwargs))

    def warning(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        pass

    def exception(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        pass

    def debug(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        pass
@dataclass
class RouterTestContext:
    """Bundle yielded by router_test_context: registry key, the router
    config that was registered, and the logger stub capturing events."""

    # Key the router config was registered under in LLMConfigRegistry.
    llm_key: str
    # The two-group router configuration injected for the test.
    router_config: LLMRouterConfig
    # Records (event, kwargs) tuples from info() calls made during the test.
    logger: DummyLogger
@contextmanager
def router_test_context(
    monkeypatch: MonkeyPatch,
    *,
    llm_key: str,
    primary_group: str,
    fallback_group: str,
    routing_strategy: str = "simple-shuffle",
) -> Iterator[RouterTestContext]:
    """Temporarily register a two-group LLM router config under *llm_key*.

    Also patches api_handler_factory so router code runs without a real
    message builder, an active skyvern context, or litellm cost lookups.
    The registry entry is removed again on exit, even if the body raises.
    """
    router_config = LLMRouterConfig(
        model_name="test-router",
        required_env_vars=[],
        supports_vision=False,
        add_assistant_prefix=False,
        model_list=[
            LLMRouterModelConfig(model_name=primary_group, litellm_params={"model": primary_group}),
            LLMRouterModelConfig(model_name=fallback_group, litellm_params={"model": fallback_group}),
        ],
        redis_host="localhost",
        redis_port=6379,
        redis_password="",
        main_model_group=primary_group,
        fallback_model_group=fallback_group,
        routing_strategy=routing_strategy,
        num_retries=0,
        disable_cooldowns=True,
        temperature=None,
    )
    # Drop any stale registration for this key first — the registry is module
    # global, so a previous test may have left an entry behind.
    LLMConfigRegistry._configs.pop(llm_key, None)  # type: ignore[attr-defined]
    LLMConfigRegistry.register_config(llm_key, router_config)
    logger = DummyLogger()
    monkeypatch.setattr(api_handler_factory, "LOG", logger)

    # Stand-in for the real message builder: one plain user message.
    async def fake_llm_messages_builder(prompt: str, screenshots, add_assistant_prefix: bool) -> list[dict[str, str]]:
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", fake_llm_messages_builder)
    # No active skyvern context during router tests.
    monkeypatch.setattr(api_handler_factory.skyvern_context, "current", lambda: None)
    # Cost computation is irrelevant here; pin it to zero.
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda completion_response: 0.0)
    try:
        yield RouterTestContext(llm_key=llm_key, router_config=router_config, logger=logger)
    finally:
        # Always unregister so later tests see a clean registry.
        LLMConfigRegistry._configs.pop(llm_key, None)  # type: ignore[attr-defined]
def make_organization(now: datetime) -> Organization:
    """Return a minimal Organization fixture created/modified at *now*,
    with every optional field left unset."""
    fields = dict(
        organization_id="org-123",
        organization_name="Org",
        webhook_callback_url=None,
        max_steps_per_run=None,
        max_retries_per_step=None,
        domain=None,
        bw_organization_id=None,
        bw_collection_ids=None,
        created_at=now,
        modified_at=now,
    )
    return Organization(**fields)
def make_task(now: datetime, organization: Organization, **overrides: Any) -> Task:
    """Build a running Task fixture owned by *organization*.

    Every default below can be replaced via keyword *overrides*; overrides
    win over the defaults when both are present.
    """
    defaults: dict[str, Any] = {
        "title": "Task",
        "url": "https://example.com",
        "webhook_callback_url": None,
        "webhook_failure_reason": None,
        "totp_verification_url": None,
        "totp_identifier": None,
        "navigation_goal": "Find the quote",
        "data_extraction_goal": "Extract the quote",
        "navigation_payload": None,
        "error_code_mapping": None,
        "proxy_location": None,
        "extracted_information_schema": None,
        "extra_http_headers": None,
        "complete_criterion": None,
        "terminate_criterion": None,
        "task_type": TaskType.general,
        "application": None,
        "include_action_history_in_verification": False,
        "max_screenshot_scrolls": None,
        "browser_address": None,
        "download_timeout": None,
        "created_at": now,
        "modified_at": now,
        "task_id": "task-123",
        "status": TaskStatus.running,
        "extracted_information": None,
        "failure_reason": None,
        "organization_id": organization.organization_id,
        "workflow_run_id": None,
        "workflow_permanent_id": None,
        "browser_session_id": None,
        "order": 0,
        "retry": 0,
        "max_steps_per_run": None,
        "errors": [],
        "model": None,
        "queued_at": now,
        "started_at": now,
        "finished_at": None,
    }
    return Task(**{**defaults, **overrides})
def make_step(
    now: datetime,
    task: Task,
    *,
    step_id: str,
    status: StepStatus,
    order: int,
    output,
    is_last: bool = False,
    retry_index: int = 0,
    organization_id: str | None = None,
    **overrides: Any,
) -> Step:
    """Build a Step fixture attached to *task*.

    If *organization_id* is falsy, the task's organization is used.  Extra
    Step fields may be supplied via *overrides* and win over the defaults.
    """
    payload: dict[str, Any] = {
        "created_at": now,
        "modified_at": now,
        "task_id": task.task_id,
        "step_id": step_id,
        "status": status,
        "output": output,
        "order": order,
        "is_last": is_last,
        "retry_index": retry_index,
        "organization_id": organization_id or task.organization_id,
    }
    return Step(**{**payload, **overrides})
@dataclass
class ParallelVerificationMocks:
    """Handles to every mock installed by setup_parallel_verification_mocks,
    so tests can assert on call counts and arguments individually."""

    # DATABASE.create_step replacement (returns the pre-built next step).
    create_step: AsyncMock
    # DATABASE.get_task_steps replacement (returns [step]).
    get_task_steps: AsyncMock
    # skyvern.forge.agent.asyncio.sleep replacement (no-op).
    sleep: AsyncMock
    # agent.check_user_goal_complete replacement.
    check_user_goal_complete: AsyncMock
    # ActionHandler.handle_action replacement (side_effect sequence).
    handle_action: AsyncMock
    # Only set when an extract_action was supplied; otherwise None.
    create_extract_action: AsyncMock | None
    speculate_next_step_plan: AsyncMock
    persist_speculative_metadata: AsyncMock
    cancel_speculative_step: AsyncMock
    record_artifacts_after_action: AsyncMock
    update_step: AsyncMock
    update_task: AsyncMock
def setup_parallel_verification_mocks(
    agent: ForgeAgent,
    *,
    step: Step,
    task: Task,
    monkeypatch: MonkeyPatch,
    next_step: Step | None,
    complete_action,
    handle_action_responses: Sequence[Any],
    extract_action: Any | None = None,
) -> ParallelVerificationMocks:
    """Patch every collaborator the parallel-verification path touches and
    return the installed mocks bundled for individual assertions."""
    mock_create_step = AsyncMock(return_value=next_step)
    monkeypatch.setattr(app.DATABASE, "create_step", mock_create_step)
    mock_get_task_steps = AsyncMock(return_value=[step])
    monkeypatch.setattr(app.DATABASE, "get_task_steps", mock_get_task_steps)
    mock_sleep = AsyncMock(return_value=None)
    monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", mock_sleep)
    mock_check_goal = AsyncMock(return_value=complete_action)
    monkeypatch.setattr(agent, "check_user_goal_complete", mock_check_goal)
    mock_handle_action = AsyncMock(side_effect=handle_action_responses)
    monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", mock_handle_action)
    mock_speculate = AsyncMock(return_value=None)
    monkeypatch.setattr(agent, "_speculate_next_step_plan", mock_speculate)
    mock_persist = AsyncMock()
    monkeypatch.setattr(agent, "_persist_speculative_metadata_for_discarded_plan", mock_persist)
    mock_cancel = AsyncMock()
    monkeypatch.setattr(agent, "_cancel_speculative_step", mock_cancel)
    mock_record_artifacts = AsyncMock()
    monkeypatch.setattr(agent, "record_artifacts_after_action", mock_record_artifacts)
    mock_update_step = AsyncMock(return_value=step)
    monkeypatch.setattr(agent, "update_step", mock_update_step)
    mock_update_task = AsyncMock(return_value=task)
    monkeypatch.setattr(agent, "update_task", mock_update_task)
    # create_extract_action is only patched when the test supplies an action.
    mock_create_extract_action: AsyncMock | None = None
    if extract_action is not None:
        mock_create_extract_action = AsyncMock(return_value=extract_action)
        monkeypatch.setattr(agent, "create_extract_action", mock_create_extract_action)
    return ParallelVerificationMocks(
        create_step=mock_create_step,
        get_task_steps=mock_get_task_steps,
        sleep=mock_sleep,
        check_user_goal_complete=mock_check_goal,
        handle_action=mock_handle_action,
        create_extract_action=mock_create_extract_action,
        speculate_next_step_plan=mock_speculate,
        persist_speculative_metadata=mock_persist,
        cancel_speculative_step=mock_cancel,
        record_artifacts_after_action=mock_record_artifacts,
        update_step=mock_update_step,
        update_task=mock_update_task,
    )
def make_browser_state() -> tuple[MagicMock, MagicMock, MagicMock]:
    """Return three independent MagicMocks as an opaque browser-state triple."""
    first, second, third = (MagicMock() for _ in range(3))
    return first, second, third

View File

View File

@@ -0,0 +1,94 @@
"""
Just an example unit test for now. Will expand later.
"""
import typing as t
from skyvern.services.browser_recording.service import Processor
from skyvern.services.browser_recording.types import (
ExfiltratedConsoleEvent,
)
# Fixed identifiers shared by the recording-service tests below.
ORG_ID = "org_123"  # organization id
PBS_ID = "pbs_123"  # presumably a persistent-browser-session id — confirm
WP_ID = "wpid_123"  # presumably a workflow-permanent id — confirm
def make_console_event(
    params: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    """Wrap *params* in an exfiltrated console user-interaction event.

    Caller-supplied keys win over the defaults provided below.
    """
    merged: dict[str, t.Any] = {
        "url": "https://example.com",
        "activeElement": {
            "tagName": "BUTTON",
        },
        "window": {
            "height": 800,
            "width": 1200,
            "scrollX": 0,
            "scrollY": 0,
        },
        "mousePosition": {"xp": 0.5, "yp": 0.5},
    }
    merged.update(params)
    return ExfiltratedConsoleEvent(
        kind="exfiltrated-event",
        source="console",
        event_name="user-interaction",
        params=merged,
        timestamp=timestamp,
    )
def _make_pointer_event(
    event_type: str,
    target: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    # Shared builder: the mouseenter/mouseleave payloads only differ in "type".
    params: dict[str, t.Any] = {
        "type": event_type,
        "target": target,
        "timestamp": timestamp,
    }
    return make_console_event(
        params=params,
        timestamp=timestamp,
    )


def make_mouseenter_event(
    target: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    """Console event for the pointer entering *target* at *timestamp*."""
    return _make_pointer_event("mouseenter", target, timestamp)


def make_mouseleave_event(
    target: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    """Console event for the pointer leaving *target* at *timestamp*."""
    return _make_pointer_event("mouseleave", target, timestamp)
def test_hover() -> None:
    """A mouseenter/mouseleave pair over one element yields a single action."""
    target = {"id": "button-1", "skyId": "sky-123", "text": ["Click me"]}
    enter = make_mouseenter_event(target=target, timestamp=1000.0)
    leave = make_mouseleave_event(target=target, timestamp=4000.0)
    processor = Processor(PBS_ID, ORG_ID, WP_ID)
    actions = processor.events_to_actions([enter, leave])
    assert len(actions) == 1

115
tests/unit/test_actions.py Normal file
View File

@@ -0,0 +1,115 @@
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
from skyvern.webeye.actions.actions import Action, KeypressAction, NullAction, WebAction
from skyvern.webeye.actions.parse_actions import parse_action
def _mock_scraped_page() -> MagicMock:
page = MagicMock()
page.id_to_element_hash = {}
page.id_to_element_dict = {}
return page
def test_action_parse__no_element_id() -> None:
    """A base Action validates even without an element_id."""
    parsed = Action.model_validate({"action_type": "click"})
    assert parsed.action_type == "click"
    assert parsed.element_id is None
def test_action_parse__with_element_id() -> None:
    """element_id is accepted as a string and coerced to str from an int."""
    from_str = Action.model_validate({"action_type": "click", "element_id": "element_id"})
    assert from_str.action_type == "click"
    assert from_str.element_id == "element_id"

    from_int = Action.model_validate({"action_type": "click", "element_id": 1})
    assert from_int.action_type == "click"
    assert from_int.element_id == "1"
def test_web_action_parse__no_element_id() -> None:
    """WebAction requires an element_id, unlike the base Action."""
    with pytest.raises(ValidationError):
        WebAction.model_validate({"action_type": "click"})
def test_web_action_parse__with_element_id() -> None:
    """WebAction accepts a string element_id and coerces an int to str."""
    from_str = WebAction.model_validate({"action_type": "click", "element_id": "element_id"})
    assert from_str.action_type == "click"
    assert from_str.element_id == "element_id"

    from_int = WebAction.model_validate({"action_type": "click", "element_id": 1})
    assert from_int.action_type == "click"
    assert from_int.element_id == "1"
@pytest.mark.parametrize("key", ["Enter", "Tab", "Escape", "ArrowDown", "ArrowUp"])
def test_parse_keypress_valid_keys(key: str) -> None:
    """Each accepted key parses into a single-key KeypressAction with no
    element binding."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "key": key, "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == [key]
    assert parsed.element_id is None
    assert parsed.skyvern_element_hash is None
    assert parsed.skyvern_element_data is None
def test_parse_keypress_invalid_key_returns_null_action() -> None:
    """A key the parser does not accept (e.g. Delete) degrades to NullAction."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "key": "Delete", "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, NullAction)
def test_parse_keypress_backward_compat_press_enter() -> None:
    """Legacy PRESS_ENTER action types still parse as a KeypressAction."""
    parsed = parse_action(
        action={"action_type": "PRESS_ENTER", "key": "Enter", "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == ["Enter"]
def test_parse_keypress_keys_list() -> None:
    """A plural 'keys' list is accepted in place of the singular 'key'."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "keys": ["Enter"], "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == ["Enter"]
def test_parse_keypress_no_key_defaults_to_enter() -> None:
    """With neither 'key' nor 'keys' supplied, the parser falls back to Enter."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == ["Enter"]

View File

@@ -0,0 +1,165 @@
"""
Tests for ai_click behavior when LLM returns empty actions.
This tests the fix for SKY-7577 where cached click actions were succeeding
even when the target element didn't exist on the page.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.core.script_generations.real_skyvern_page_ai import RealSkyvernPageAi
@pytest.fixture
def mock_page():
    """Playwright page stub whose locator() yields an async-clickable element."""
    locator_stub = MagicMock()
    locator_stub.click = AsyncMock()
    page = MagicMock()
    page.url = "https://example.com"
    page.locator = MagicMock(return_value=locator_stub)
    return page
@pytest.fixture
def mock_scraped_page():
    """ScrapedPage stub supporting the async generate_scraped_page() call."""
    stub = MagicMock()
    stub.build_element_tree = MagicMock(return_value="<element_tree>")
    # generate_scraped_page is async and yields the same scraped page back.
    stub.generate_scraped_page = AsyncMock(return_value=stub)
    return stub
@pytest.fixture
def mock_context():
    """Skyvern context stub carrying fixed org/task/step identifiers."""
    ctx = MagicMock()
    ctx.organization_id = "org_123"
    ctx.task_id = "task_123"
    ctx.step_id = "step_123"
    ctx.prompt = "Test prompt"
    ctx.tz_info = None
    return ctx
@pytest.fixture
def mock_app():
    """App stub whose single-click LLM handler returns no actions by default."""
    app_stub = MagicMock()
    app_stub.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})
    app_stub.DATABASE = MagicMock()
    app_stub.DATABASE.get_step = AsyncMock(return_value=MagicMock())
    return app_stub
class TestAiClickEmptyActions:
    """Test that ai_click properly fails when LLM returns no actions.

    Regression coverage for SKY-7577: cached click actions were succeeding
    even when the target element didn't exist on the page.
    """

    @pytest.mark.asyncio
    async def test_ai_click_raises_when_llm_returns_empty_actions_no_selector(
        self, mock_page, mock_scraped_page, mock_context, mock_app
    ):
        """
        When the LLM returns no actions (element doesn't exist on page) and
        there's no selector to fall back to, ai_click should raise an exception.
        """
        real_skyvern_page_ai = RealSkyvernPageAi(mock_scraped_page, mock_page)
        # LLM "finds nothing to click": empty actions list.
        mock_app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})
        with (
            patch.object(real_skyvern_page_ai, "_refresh_scraped_page", new_callable=AsyncMock),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.skyvern_context.ensure_context",
                return_value=mock_context,
            ),
            patch("skyvern.core.script_generations.real_skyvern_page_ai.app", mock_app),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.prompt_engine.load_prompt",
                return_value="mock_prompt",
            ),
        ):
            with pytest.raises(Exception) as exc_info:
                await real_skyvern_page_ai.ai_click(
                    selector=None,  # No fallback selector
                    intention="Click the download button",
                )
            # Should raise because no actions and no fallback
            assert "AI click failed" in str(exc_info.value) or "AI could not find" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_ai_click_raises_when_llm_call_fails_no_selector(
        self, mock_page, mock_scraped_page, mock_context, mock_app
    ):
        """
        When AI fails (exception) and there's no selector to fall back to,
        ai_click should raise an exception.
        """
        real_skyvern_page_ai = RealSkyvernPageAi(mock_scraped_page, mock_page)
        # LLM call itself blows up rather than returning empty actions.
        mock_app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(side_effect=Exception("LLM error"))
        with (
            patch.object(real_skyvern_page_ai, "_refresh_scraped_page", new_callable=AsyncMock),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.skyvern_context.ensure_context",
                return_value=mock_context,
            ),
            patch("skyvern.core.script_generations.real_skyvern_page_ai.app", mock_app),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.prompt_engine.load_prompt",
                return_value="mock_prompt",
            ),
        ):
            with pytest.raises(Exception) as exc_info:
                await real_skyvern_page_ai.ai_click(
                    selector=None,  # No fallback selector
                    intention="Click the download button",
                )
            assert "AI click failed" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_ai_click_falls_back_to_selector_when_llm_returns_empty(
        self, mock_page, mock_scraped_page, mock_context, mock_app
    ):
        """
        When AI returns empty actions but there IS a selector to fall back to,
        ai_click should use the selector and succeed.
        """
        # Set up the locator mock properly with AsyncMock for click
        mock_locator = MagicMock()
        mock_locator.click = AsyncMock()
        mock_page.locator = MagicMock(return_value=mock_locator)
        real_skyvern_page_ai = RealSkyvernPageAi(mock_scraped_page, mock_page)
        mock_app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})
        with (
            patch.object(real_skyvern_page_ai, "_refresh_scraped_page", new_callable=AsyncMock),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.skyvern_context.ensure_context",
                return_value=mock_context,
            ),
            patch("skyvern.core.script_generations.real_skyvern_page_ai.app", mock_app),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.prompt_engine.load_prompt",
                return_value="mock_prompt",
            ),
        ):
            # Should NOT raise because we have a fallback selector
            result = await real_skyvern_page_ai.ai_click(
                selector="xpath=//button[@id='download']",  # Has fallback
                intention="Click the download button",
            )
            # Should have used the fallback selector
            mock_page.locator.assert_called_once_with("xpath=//button[@id='download']")
            assert result == "xpath=//button[@id='download']"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,91 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import skyvern.forge as forge_module
import skyvern.forge.sdk.core.skyvern_context as skyvern_context_module
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
# Replace the forge app holder with a MagicMock so test imports don't require a fully
# initialised ForgeApp instance.
# NOTE(review): this runs at import time and mutates global forge state for
# every test collected from this module.
forge_module.set_force_app_instance(MagicMock())
forge_module.app.EXPERIMENTATION_PROVIDER = MagicMock()
@pytest.mark.asyncio
async def test_cached_content_removed_from_non_extract_prompts() -> None:
    """Vertex ``cached_content`` may only accompany extract-action prompts;
    for any other prompt name it must be absent from the litellm call, even
    if supplied explicitly via ``parameters``."""
    mock_config = MagicMock(spec=LLMConfig)
    mock_config.model_name = "gemini-2.5-pro"
    mock_config.litellm_params = {}
    mock_config.supports_vision = False
    mock_config.add_assistant_prefix = False
    mock_config.max_completion_tokens = 100
    mock_config.max_tokens = None
    mock_config.temperature = 0.0
    mock_config.reasoning_effort = None
    mock_config.disable_cooldowns = True
    mock_response = MagicMock()
    mock_response.model_dump_json.return_value = "{}"
    mock_response.choices = [MagicMock(message=MagicMock(content="test"))]
    mock_response.usage = MagicMock(
        prompt_tokens=10,
        completion_tokens=10,
        completion_tokens_details=None,
        prompt_tokens_details=None,
        cache_read_input_tokens=0,
    )
    # Ensure app dependencies referenced inside the handler resolve to async mocks.
    forge_module.app.ARTIFACT_MANAGER = MagicMock()
    forge_module.app.DATABASE = MagicMock()
    forge_module.app.DATABASE.update_step = AsyncMock()
    forge_module.app.DATABASE.update_thought = AsyncMock()
    with (
        patch("skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", return_value=mock_config),
        patch(
            "skyvern.forge.sdk.api.llm.api_handler_factory.litellm.acompletion",
            new_callable=AsyncMock,
        ) as mock_acompletion,
        patch(
            "skyvern.forge.sdk.api.llm.api_handler_factory.litellm.completion_cost",
            return_value=0.001,
        ),
        patch(
            "skyvern.forge.sdk.api.llm.api_handler_factory.llm_messages_builder", new_callable=AsyncMock
        ) as mock_builder,
        patch("skyvern.forge.sdk.api.llm.api_handler_factory.parse_api_response", return_value={}),
    ):
        mock_builder.return_value = [{"role": "user", "content": "test"}]
        mock_acompletion.return_value = mock_response
        handler = LLMAPIHandlerFactory.get_llm_api_handler("gemini-2.5-pro")
        # Build a caching-enabled context and install it on the context var;
        # the token is reset in the finally so other tests are unaffected.
        context = SkyvernContext()
        context.cached_static_prompt = "some static prompt"
        context.use_prompt_caching = True
        context.vertex_cache_name = "projects/123/locations/global/cachedContents/demo"
        token = skyvern_context_module._context.set(context)
        try:
            # Extract actions attaches cached_content.
            await handler(prompt="test", prompt_name="extract-actions")
            args, kwargs = mock_acompletion.call_args
            assert kwargs.get("cached_content") == "projects/123/locations/global/cachedContents/demo"
            # Non-extract prompt should not include cached_content.
            mock_acompletion.reset_mock()
            await handler(prompt="test", prompt_name="check-user-goal")
            _, kwargs = mock_acompletion.call_args
            assert "cached_content" not in kwargs
            # Even if user supplied cached_content manually, it must be stripped.
            mock_acompletion.reset_mock()
            await handler(prompt="test", prompt_name="check-user-goal", parameters={"cached_content": "leaked"})
            _, kwargs = mock_acompletion.call_args
            assert "cached_content" not in kwargs
        finally:
            skyvern_context_module._context.reset(token)

View File

@@ -0,0 +1,226 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest # type: ignore[import-not-found]
from skyvern.forge.sdk.api.llm import api_handler_factory
from skyvern.forge.sdk.api.llm.api_handler_factory import (
EXTRACT_ACTION_PROMPT_NAME,
LLMAPIHandlerFactory,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from tests.unit.helpers import FakeLLMResponse
@pytest.mark.asyncio
async def test_cached_content_not_added_for_non_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that cached_content is NOT added to non-Gemini models like GPT-4."""
    # A context that looks like prompt caching is fully active.
    context = MagicMock()
    context.vertex_cache_name = "projects/123/locations/us-central1/cachedContents/456"
    context.use_prompt_caching = True
    context.cached_static_prompt = "some static prompt"
    context.hashed_href_map = {}
    # Plain (non-router) GPT-4 config.
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    # Capture exactly what reaches litellm.acompletion.
    captured_kwargs: dict = {}

    async def record_call(*args, **kwargs):
        captured_kwargs.update(kwargs)
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=record_call))

    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # cached_content must never be forwarded for a non-Gemini model.
    assert "cached_content" not in captured_kwargs
    assert captured_kwargs["model"] == "gpt-4"
@pytest.mark.asyncio
async def test_cached_content_added_for_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that cached_content IS added for Gemini models."""
    # A context that looks like prompt caching is fully active.
    context = MagicMock()
    context.vertex_cache_name = "projects/123/locations/us-central1/cachedContents/456"
    context.use_prompt_caching = True
    context.cached_static_prompt = "some static prompt"
    context.hashed_href_map = {}
    # Plain (non-router) Gemini config.
    llm_config = LLMConfig(
        model_name="gemini-1.5-pro",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    # Capture exactly what reaches litellm.acompletion.
    captured_kwargs: dict = {}

    async def record_call(*args, **kwargs):
        captured_kwargs.update(kwargs)
        return FakeLLMResponse("gemini-1.5-pro")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=record_call))

    handler = LLMAPIHandlerFactory.get_llm_api_handler("gemini-1.5-pro")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # For Gemini + extract-actions the cache reference must be forwarded.
    assert "cached_content" in captured_kwargs
    assert captured_kwargs["cached_content"] == "projects/123/locations/us-central1/cachedContents/456"
    assert captured_kwargs["model"] == "gemini-1.5-pro"
@pytest.mark.asyncio
async def test_openai_caching_not_injected_for_check_user_goal(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that OpenAI context caching system message is NOT injected for check-user-goal prompts.

    This is a regression test for a bug where the extract-action-static.j2 prompt was being
    injected as a system message for ALL prompts on OpenAI models, causing the LLM to return
    CLICK actions when running check-user-goal (which should only return COMPLETE/TERMINATE).
    """
    # Context as it would look after extract-action ran with caching on.
    context = MagicMock()
    context.vertex_cache_name = None
    context.use_prompt_caching = True
    context.cached_static_prompt = "This is the extract-action-static prompt content"
    context.hashed_href_map = {}
    # Plain (non-router) GPT-4 config.
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)

    # Collect every message handed to litellm so we can inspect roles below.
    seen_messages: list = []

    async def fake_messages_builder(prompt, screenshots, add_assistant_prefix):
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", fake_messages_builder)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    async def capture_messages(*args, **kwargs):
        seen_messages.extend(kwargs.get("messages", []))
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=capture_messages))

    # Call the handler with a check-user-goal prompt (NOT extract-actions).
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="check-user-goal prompt content", prompt_name="check-user-goal")

    # The cached static prompt must not have been injected as a system message;
    # only the single user message should have been sent.
    system_messages = [m for m in seen_messages if m.get("role") == "system"]
    assert len(system_messages) == 0, (
        f"Expected no system messages with cached content for check-user-goal, but found: {system_messages}"
    )
@pytest.mark.asyncio
async def test_openai_caching_injected_for_extract_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that OpenAI context caching system message IS injected for extract-actions prompts."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = None  # no Vertex cache -> exercises the OpenAI code path
    context.use_prompt_caching = True
    context.cached_static_prompt = "This is the extract-action-static prompt content"
    context.hashed_href_map = {}
    # Setup OpenAI config (GPT-4)
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    # Capture messages passed to LLM
    captured_messages: list = []

    async def mock_llm_messages_builder(prompt, screenshots, add_assistant_prefix):
        # Base message list contains only the user prompt; any system message
        # observed downstream must therefore come from the caching injection.
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", mock_llm_messages_builder)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    async def mock_acompletion(*args, **kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))
    # Get handler and call it with extract-actions prompt
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="extract-actions prompt content", prompt_name=EXTRACT_ACTION_PROMPT_NAME)
    # Verify the cached_static_prompt WAS injected as a system message
    system_messages = [m for m in captured_messages if m.get("role") == "system"]
    assert len(system_messages) == 1, (
        f"Expected 1 system message with cached content for extract-actions, "
        f"but found {len(system_messages)}: {system_messages}"
    )
    # Check the system message contains the cached content
    system_content = system_messages[0].get("content", [])
    assert any(part.get("text") == "This is the extract-action-static prompt content" for part in system_content), (
        f"System message should contain cached_static_prompt, got: {system_content}"
    )

View File

@@ -0,0 +1,524 @@
"""Tests for the location auto-completion fast-path optimisation.
When the user types an address into a location field and exactly one autocomplete
suggestion appears, we skip the LLM call and click the suggestion directly.
"""
from __future__ import annotations
import copy
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.constants import SKYVERN_ID_ATTR
from skyvern.forge.sdk.models import StepStatus
from skyvern.webeye.actions.actions import InputOrSelectContext
from skyvern.webeye.actions.handler import (
AutoCompletionResult,
choose_auto_completion_dropdown,
input_or_auto_complete_input,
)
from skyvern.webeye.actions.responses import ActionSuccess
from tests.unit.helpers import make_organization, make_step, make_task
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Shared fixtures: one organization/task/step reused by every test in this module.
_NOW = datetime.now(UTC)
_ORG = make_organization(_NOW)
# navigation_payload carries the address the tests type into the location field.
_TASK = make_task(_NOW, _ORG, navigation_payload={"address": "123 Main St"})
_STEP = make_step(_NOW, _TASK, step_id="stp-1", status=StepStatus.created, order=0, output=None)
# Fake incremental-scrape trees: exactly one matching suggestion vs. two ambiguous ones.
SINGLE_ELEMENT = [{"id": "AAAA", "tag": "div", "text": "123 Main St, Springfield, IL"}]
MULTI_ELEMENTS = [
    {"id": "AAAA", "tag": "div", "text": "123 Main St, Springfield, IL"},
    {"id": "AAAB", "tag": "div", "text": "123 Main St, Springfield, MO"},
]
def _make_location_context(**overrides: object) -> InputOrSelectContext:
    """Build an InputOrSelectContext for a location (address) field.

    Keyword arguments override the defaults below.
    """
    params: dict = {
        "field": "Address",
        "is_location_input": True,
        "is_search_bar": False,
    }
    params.update(overrides)
    return InputOrSelectContext(**params)
def _make_non_location_context(**overrides: object) -> InputOrSelectContext:
    """Build an InputOrSelectContext for an ordinary (non-location) field.

    Keyword arguments override the defaults below.
    """
    params: dict = {
        "field": "Search",
        "is_location_input": False,
        "is_search_bar": False,
    }
    params.update(overrides)
    return InputOrSelectContext(**params)
def _mock_skyvern_element(frame: MagicMock | None = None) -> MagicMock:
    """Return a mock SkyvernElement whose helpers are async-safe."""
    element = MagicMock()
    # Synchronous accessors used by the handler.
    element.get_id.return_value = "elem-1"
    element.get_frame.return_value = frame or _mock_frame()
    element.get_frame_id.return_value = "frame-1"
    element.is_interactable.return_value = True
    # Awaited APIs must be AsyncMock so the handler can `await` them.
    element.press_fill = AsyncMock()
    element.input_clear = AsyncMock()
    element.is_visible = AsyncMock(return_value=True)
    element.get_element_handler = AsyncMock(return_value=MagicMock())
    return element
def _mock_frame(locator_count: int = 1) -> MagicMock:
    """Return a mock Playwright Frame with a configurable locator."""
    mock_locator = MagicMock()
    # count()/click() are awaited by the fast-path, hence AsyncMock.
    mock_locator.count = AsyncMock(return_value=locator_count)
    mock_locator.click = AsyncMock()
    mock_frame = MagicMock()
    mock_frame.locator.return_value = mock_locator
    return mock_frame
def _mock_incremental_scrape(elements: list[dict]) -> MagicMock:
    """Return a mock IncrementalScrapePage that yields *elements*."""
    scraper = MagicMock()
    scraper.start_listen_dom_increment = AsyncMock()
    scraper.stop_listen_dom_increment = AsyncMock()
    # Deep-copy so a test that mutates the returned tree cannot corrupt
    # the module-level fixture lists.
    scraper.get_incremental_element_tree = AsyncMock(return_value=copy.deepcopy(elements))
    scraper.build_html_tree.return_value = "<div>mocked</div>"
    return scraper
# ---------------------------------------------------------------------------
# Tests for choose_auto_completion_dropdown
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_location_single_option_skips_llm() -> None:
    """When is_location_input=True and exactly 1 option appears, the LLM must NOT be called."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)
    with (
        # Bypass real animation waiting; the handler only awaits the helper.
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
        result = await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        # The LLM should never have been called
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_not_called()
        # The locator should have been clicked
        frame.locator.assert_called_with(f'[{SKYVERN_ID_ATTR}="AAAA"]')
        frame.locator.return_value.click.assert_awaited_once()
        # Result should indicate success
        assert isinstance(result.action_result, ActionSuccess)
@pytest.mark.asyncio
async def test_location_whitespace_normalized_still_matches() -> None:
    """Input with extra whitespace should still match after normalization."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    # Option has single spaces, input will have double spaces
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
        result = await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            # FIX: the input must actually contain a double space, otherwise this
            # test matched trivially and never exercised whitespace normalization.
            text="123  Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        # LLM should NOT be called - whitespace normalization should make it match
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_not_called()
        assert isinstance(result.action_result, ActionSuccess)
@pytest.mark.asyncio
async def test_location_multiple_options_calls_llm() -> None:
    """When is_location_input=True but multiple options appear, the LLM IS called."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(MULTI_ELEMENTS)
    # Canned LLM verdict selecting the first suggestion.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.95,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "First option matches",
    }
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        # Supply a tz-aware context object for the LLM prompt path.
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)
        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        # LLM should have been called because there are 2 options
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
@pytest.mark.asyncio
async def test_non_location_single_option_calls_llm() -> None:
    """When is_location_input=False, even a single option goes through the LLM path."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)
    # Canned LLM verdict accepting the suggestion.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.95,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "Matches",
    }
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        # Supply a tz-aware context object for the LLM prompt path.
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)
        await choose_auto_completion_dropdown(
            context=_make_non_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="some search",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=False,
        )
        # LLM should be called — no fast-path for non-location inputs
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
@pytest.mark.asyncio
async def test_location_fast_path_returns_action_success() -> None:
    """The fast-path must set action_result to ActionSuccess on the result object."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
        result = await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        # Fast-path must return the wrapper type with a success action result.
        assert isinstance(result, AutoCompletionResult)
        assert isinstance(result.action_result, ActionSuccess)
@pytest.mark.asyncio
async def test_location_fast_path_element_not_in_dom_falls_through() -> None:
    """If the single element's locator has count 0, the fast-path is skipped."""
    frame = _mock_frame(locator_count=0)  # element not found in DOM
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)
    # Canned LLM verdict accepting the suggestion.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.95,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "Matches",
    }
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        # Supply a tz-aware context object for the LLM prompt path.
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)
        # Should fall through to LLM path because locator.count() == 0
        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
# ---------------------------------------------------------------------------
# Tests for input_or_auto_complete_input flag propagation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_input_or_auto_complete_passes_is_location_input() -> None:
    """input_or_auto_complete_input must forward is_location_input to choose_auto_completion_dropdown."""
    location_context = _make_location_context()
    mocked_dropdown = AsyncMock(return_value=AutoCompletionResult(action_result=ActionSuccess()))
    with patch(
        "skyvern.webeye.actions.handler.choose_auto_completion_dropdown",
        new=mocked_dropdown,
    ):
        outcome = await input_or_auto_complete_input(
            input_or_select_context=location_context,
            scraped_page=MagicMock(),
            page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=_mock_skyvern_element(),
            step=_STEP,
            task=_TASK,
        )
    assert isinstance(outcome, ActionSuccess)
    # Verify is_location_input was passed
    assert mocked_dropdown.call_args.kwargs["is_location_input"] is True
@pytest.mark.asyncio
async def test_input_or_auto_complete_passes_false_for_non_location() -> None:
    """When is_location_input is None/False, the flag should be passed as False."""
    search_context = _make_non_location_context()
    mocked_dropdown = AsyncMock(return_value=AutoCompletionResult(action_result=ActionSuccess()))
    with patch(
        "skyvern.webeye.actions.handler.choose_auto_completion_dropdown",
        new=mocked_dropdown,
    ):
        outcome = await input_or_auto_complete_input(
            input_or_select_context=search_context,
            scraped_page=MagicMock(),
            page=MagicMock(),
            dom=MagicMock(),
            text="some query",
            skyvern_element=_mock_skyvern_element(),
            step=_STEP,
            task=_TASK,
        )
    assert isinstance(outcome, ActionSuccess)
    # Verify the flag was forwarded as False for non-location fields
    assert mocked_dropdown.call_args.kwargs["is_location_input"] is False
# ---------------------------------------------------------------------------
# Integration tests: options that don't contain the input fall through to LLM
# ---------------------------------------------------------------------------
# A "no results" placeholder option and an option unrelated to the typed text.
NO_RESULT_ELEMENTS = [{"id": "AAAA", "tag": "div", "text": "No results"}]
UNRELATED_ELEMENTS = [{"id": "AAAA", "tag": "div", "text": "Something completely different"}]
@pytest.mark.asyncio
async def test_location_no_results_option_falls_through_to_llm() -> None:
    """When the single option doesn't contain the input text, fall through to LLM."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(NO_RESULT_ELEMENTS)
    # Canned LLM verdict rejecting the auto-completion attempt.
    llm_response = {
        "auto_completion_attempt": False,
        "relevance_float": 0.0,
        "id": "",
        "direct_searching": True,
        "reasoning": "No results shown",
    }
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        # Supply a tz-aware context object for the LLM prompt path.
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)
        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        # LLM should be called because "No results" doesn't contain "123 Main St"
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
@pytest.mark.asyncio
async def test_location_unrelated_option_falls_through_to_llm() -> None:
    """When the single option doesn't contain the input text, fall through to LLM."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(UNRELATED_ELEMENTS)
    # Canned LLM verdict accepting the only available option.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.5,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "Only option available",
    }
    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        # Supply a tz-aware context object for the LLM prompt path.
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)
        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )
        # LLM should be called because option doesn't contain the input
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()

View File

@@ -0,0 +1,265 @@
"""
Tests for batch action query correctness in transform_workflow_run.py.
Verifies that the transform layer produces chronologically ordered actions
per task for script generation, even though get_tasks_actions returns
descending order (for the timeline UI).
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
from skyvern.webeye.actions.actions import ClickAction, ExtractAction, InputTextAction
def _make_action(
    action_cls: type,
    action_id: str,
    task_id: str,
    element_id: str | None = None,
    **kwargs: object,
) -> ClickAction | InputTextAction | ExtractAction:
    """Create a real Action instance for use in tests.

    Args:
        action_cls: Concrete Action subclass to instantiate.
        action_id: Unique id for the action.
        task_id: Id of the task the action belongs to.
        element_id: Target element id. When omitted, a synthetic
            "elem_<action_id>" is generated — except for ExtractAction,
            which legitimately has no target element and keeps None.
        **kwargs: Extra constructor arguments; these take precedence over
            the defaults below.

    Note: the return annotation previously said ``MagicMock``, which was
    wrong — this helper builds real action instances.
    """
    if element_id is None and action_cls is not ExtractAction:
        element_id = f"elem_{action_id}"
    ctor_kwargs: dict = dict(kwargs)
    if action_cls is InputTextAction:
        # InputTextAction requires text; only default it when the caller did
        # not supply one (the old double-** merge raised TypeError on override).
        ctor_kwargs.setdefault("text", "hello")
    return action_cls(
        action_id=action_id,
        task_id=task_id,
        element_id=element_id,
        **ctor_kwargs,
    )
@pytest.mark.asyncio
async def test_batch_actions_preserve_per_task_ordering() -> None:
    """
    Regression test: transform_workflow_run must produce actions in ascending
    chronological order per task for script generation.

    get_tasks_actions returns DESC order (for timeline UI). The transform
    layer reverses to ASC. This test mocks DESC input and verifies ASC output.
    """
    # Workflow-run response returned by the (mocked) workflow service.
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test", "parameters": {}}
    )
    # Workflow definition with two task blocks: block_a and block_b.
    def_block_a = MagicMock()
    def_block_a.block_type = "task"
    def_block_a.label = "block_a"
    def_block_a.model_dump = MagicMock(return_value={"block_type": "task", "label": "block_a"})
    def_block_b = MagicMock()
    def_block_b.block_type = "task"
    def_block_b.label = "block_b"
    def_block_b.model_dump = MagicMock(return_value={"block_type": "task", "label": "block_b"})
    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_1"})
    mock_workflow.workflow_definition.blocks = [def_block_a, def_block_b]
    # Completed run blocks, one per task.
    run_block_a = MagicMock()
    run_block_a.workflow_run_block_id = "wfrb_a"
    run_block_a.parent_workflow_run_block_id = None
    run_block_a.block_type = "task"
    run_block_a.label = "block_a"
    run_block_a.task_id = "task_a"
    run_block_a.status = "completed"
    run_block_a.output = {}
    run_block_a.created_at = 1
    run_block_b = MagicMock()
    run_block_b.workflow_run_block_id = "wfrb_b"
    run_block_b.parent_workflow_run_block_id = None
    run_block_b.block_type = "task"
    run_block_b.label = "block_b"
    run_block_b.task_id = "task_b"
    run_block_b.status = "completed"
    run_block_b.output = {}
    run_block_b.created_at = 2
    mock_task_a = MagicMock()
    mock_task_a.task_id = "task_a"
    mock_task_a.model_dump = MagicMock(return_value={"task_id": "task_a"})
    mock_task_b = MagicMock()
    mock_task_b.task_id = "task_b"
    mock_task_b.model_dump = MagicMock(return_value={"task_id": "task_b"})
    # Actions in chronological order:
    #   task_a: click (t=1), input_text (t=3)
    #   task_b: click (t=2), extract (t=4)
    action_a_click = _make_action(ClickAction, action_id="a_click", task_id="task_a", element_id="el_1")
    action_b_click = _make_action(ClickAction, action_id="b_click", task_id="task_b", element_id="el_2")
    action_a_input = _make_action(InputTextAction, action_id="a_input", task_id="task_a", element_id="el_3")
    action_b_extract = _make_action(ExtractAction, action_id="b_extract", task_id="task_b", element_id=None)
    # get_tasks_actions returns DESC order (newest first) — matching real DB behavior
    all_actions_descending = [action_b_extract, action_a_input, action_b_click, action_a_click]
    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block_a, run_block_b])
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task_a, mock_task_b])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=all_actions_descending)
        result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")
        # After reverse, task_a actions must be in chronological order: click then input_text
        task_a_actions = result.actions_by_task["task_a"]
        task_a_ids = [a["action_id"] for a in task_a_actions]
        assert task_a_ids == ["a_click", "a_input"], f"task_a actions out of order: {task_a_ids}"
        assert task_a_actions[0]["action_type"] == "click"
        assert task_a_actions[1]["action_type"] == "input_text"
        # task_b actions must be in chronological order: click then extract
        task_b_actions = result.actions_by_task["task_b"]
        task_b_ids = [a["action_id"] for a in task_b_actions]
        assert task_b_ids == ["b_click", "b_extract"], f"task_b actions out of order: {task_b_ids}"
        assert task_b_actions[0]["action_type"] == "click"
        assert task_b_actions[1]["action_type"] == "extract"
        # No cross-contamination between tasks
        assert set(task_a_ids) == {"a_click", "a_input"}
        assert set(task_b_ids) == {"b_click", "b_extract"}
@pytest.mark.asyncio
async def test_batch_actions_without_reverse_would_be_wrong() -> None:
    """
    Prove that without the reverse() call, DESC input from get_tasks_actions
    would produce wrong ordering in script generation output.

    If someone removes the reverse(), this test catches it.
    """
    # Workflow-run response returned by the (mocked) workflow service.
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test", "parameters": {}}
    )
    # Single task block definition and its completed run block.
    def_block = MagicMock()
    def_block.block_type = "task"
    def_block.label = "my_block"
    def_block.model_dump = MagicMock(return_value={"block_type": "task", "label": "my_block"})
    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_1"})
    mock_workflow.workflow_definition.blocks = [def_block]
    run_block = MagicMock()
    run_block.workflow_run_block_id = "wfrb_1"
    run_block.parent_workflow_run_block_id = None
    run_block.block_type = "task"
    run_block.label = "my_block"
    run_block.task_id = "task_1"
    run_block.status = "completed"
    run_block.output = {}
    run_block.created_at = 1
    mock_task = MagicMock()
    mock_task.task_id = "task_1"
    mock_task.model_dump = MagicMock(return_value={"task_id": "task_1"})
    # Chronological order: click -> input -> extract
    # DB returns DESC: extract -> input -> click
    action_click = _make_action(ClickAction, action_id="act_1_click", task_id="task_1", element_id="el_1")
    action_input = _make_action(InputTextAction, action_id="act_2_input", task_id="task_1", element_id="el_2")
    action_extract = _make_action(ExtractAction, action_id="act_3_extract", task_id="task_1", element_id=None)
    # DESC order from DB (newest first)
    actions_descending = [action_extract, action_input, action_click]
    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=actions_descending)
        result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")
        # After reverse, output must be chronological: click, input, extract
        actions = result.actions_by_task["task_1"]
        action_ids = [a["action_id"] for a in actions]
        assert action_ids == ["act_1_click", "act_2_input", "act_3_extract"], (
            f"Actions should be in chronological order after reverse, got: {action_ids}"
        )
@pytest.mark.asyncio
async def test_batch_actions_preserve_none_element_id() -> None:
    """
    Regression test: hydrate_action must be called WITHOUT empty_element_id=True,
    so that None element_ids remain None (matching get_task_actions_hydrated behavior).

    Previously get_tasks_actions used hydrate_action(action, empty_element_id=True)
    which silently converted None element_ids to empty strings.
    """
    # Workflow-run response returned by the (mocked) workflow service.
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test", "parameters": {}}
    )
    # Single extraction block definition and its completed run block.
    def_block = MagicMock()
    def_block.block_type = "extraction"
    def_block.label = "extract_block"
    def_block.model_dump = MagicMock(return_value={"block_type": "extraction", "label": "extract_block"})
    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_1"})
    mock_workflow.workflow_definition.blocks = [def_block]
    run_block = MagicMock()
    run_block.workflow_run_block_id = "wfrb_1"
    run_block.parent_workflow_run_block_id = None
    run_block.block_type = "extraction"
    run_block.label = "extract_block"
    run_block.task_id = "task_1"
    run_block.status = "completed"
    run_block.output = {}
    run_block.created_at = 1
    mock_task = MagicMock()
    mock_task.task_id = "task_1"
    mock_task.model_dump = MagicMock(return_value={"task_id": "task_1"})
    # ExtractAction has element_id=None (extracts don't target a specific element)
    action_extract = _make_action(ExtractAction, action_id="act_extract", task_id="task_1", element_id=None)
    assert action_extract.element_id is None
    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[action_extract])
        result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")
        actions = result.actions_by_task["task_1"]
        assert len(actions) == 1
        # element_id must remain None, NOT converted to ""
        assert actions[0]["element_id"] is None, (
            f"element_id should be None but got {actions[0]['element_id']!r}. "
            "This indicates hydrate_action was called with empty_element_id=True"
        )

View File

@@ -0,0 +1,153 @@
from __future__ import annotations
import pytest
from skyvern.config import settings
from skyvern.forge.sdk.workflow.exceptions import FailedToFormatJinjaStyleParameter, MissingJinjaVariables
from skyvern.forge.sdk.workflow.models.block import BranchEvaluationContext, JinjaBranchCriteria
class FakeWorkflowRunContext:
    """Minimal stand-in for a workflow run context used by branch-criteria tests.

    Only what the tests touch is implemented: template values, secrets, the
    secrets-in-templates flag, per-block metadata lookup, and a fixed set of
    workflow identifier attributes.
    """

    def __init__(
        self,
        *,
        values: dict,
        secrets: dict | None = None,
        include_secrets_in_templates: bool = False,
        block_metadata: dict[str, dict] | None = None,
    ) -> None:
        # Copy `values` so later mutation by the caller cannot leak into the fake.
        self.values = dict(values)
        self.secrets = secrets or {}
        self.include_secrets_in_templates = include_secrets_in_templates
        self._block_metadata = block_metadata or {}
        # Minimal workflow identifiers expected by template rendering.
        self.workflow_title = "wf-title"
        self.workflow_id = "wf-id"
        self.workflow_permanent_id = "wf-perm-id"
        self.workflow_run_id = "wf-run-id"

    def get_block_metadata(self, label: str) -> dict:
        """Return a defensive copy of the metadata recorded for *label* ({} if absent)."""
        metadata = self._block_metadata.get(label, {})
        return dict(metadata)
@pytest.mark.asyncio
async def test_jinja_branch_criteria_evaluates_truthy_with_workflow_context():
    """A Jinja expression can read workflow values and per-block metadata."""
    workflow_ctx = FakeWorkflowRunContext(
        values={"params": {"foo": "bar"}, "extra": "value"},
        block_metadata={"conditional": {"current_index": 1, "custom": "meta"}},
    )
    evaluation_ctx = BranchEvaluationContext(
        workflow_run_context=workflow_ctx,  # ensures template_data matches block parameter rendering
        block_label="conditional",
    )
    criteria = JinjaBranchCriteria(expression="{{ params.foo == 'bar' and current_index == 1 }}")
    evaluated = await criteria.evaluate(evaluation_ctx)
    assert evaluated is True
@pytest.mark.asyncio
async def test_jinja_branch_criteria_raises_on_missing_variable_strict(monkeypatch):
    """In strict templating mode an undefined variable raises MissingJinjaVariables."""
    monkeypatch.setattr(settings, "WORKFLOW_TEMPLATING_STRICTNESS", "strict")
    criteria = JinjaBranchCriteria(expression="{{ missing_value }}")
    with pytest.raises(MissingJinjaVariables):
        await criteria.evaluate(BranchEvaluationContext())
@pytest.mark.asyncio
async def test_jinja_branch_criteria_raises_on_template_error():
    """Syntactically broken Jinja raises FailedToFormatJinjaStyleParameter."""
    criteria = JinjaBranchCriteria(expression="{% for %}")  # invalid Jinja syntax
    with pytest.raises(FailedToFormatJinjaStyleParameter):
        await criteria.evaluate(BranchEvaluationContext())
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "expression,expected",
    [
        # Boolean-like strings (case insensitive)
        ("{{ 'true' }}", True),
        ("{{ 'True' }}", True),
        ("{{ 'TRUE' }}", True),
        ("{{ 'false' }}", False),
        ("{{ 'False' }}", False),
        ("{{ 'FALSE' }}", False),
        # Numeric strings
        ("{{ '1' }}", True),
        ("{{ '0' }}", False),
        ("{{ '42' }}", True),
        ("{{ '-1' }}", True),
        ("{{ '0.0' }}", False),
        ("{{ '0.1' }}", True),
        ("{{ '-0.5' }}", True),
        # Yes/No variants
        ("{{ 'yes' }}", True),
        ("{{ 'Yes' }}", True),
        ("{{ 'YES' }}", True),
        ("{{ 'y' }}", True),
        ("{{ 'Y' }}", True),
        ("{{ 'no' }}", False),
        ("{{ 'No' }}", False),
        ("{{ 'NO' }}", False),
        ("{{ 'n' }}", False),
        ("{{ 'N' }}", False),
        # On/Off
        ("{{ 'on' }}", True),
        ("{{ 'ON' }}", True),
        ("{{ 'off' }}", False),
        ("{{ 'OFF' }}", False),
        # Null variants
        ("{{ 'null' }}", False),
        ("{{ 'Null' }}", False),
        ("{{ 'NULL' }}", False),
        ("{{ 'none' }}", False),
        ("{{ 'None' }}", False),
        # Empty and whitespace
        ("{{ '' }}", False),
        ("{{ ' ' }}", False),
        ("{{ '\t\n' }}", False),
        # Arbitrary strings (non-empty = truthy)
        ("{{ 'some text' }}", True),
        ("{{ 'anything' }}", True),
        # Direct boolean comparisons (common use case)
        ("{{ 5 > 3 }}", True),
        ("{{ 1 == 0 }}", False),
    ],
)
async def test_jinja_branch_criteria_truthy_falsy_evaluation(expression: str, expected: bool):
    """Test that rendered template strings are properly evaluated as boolean."""
    # NOTE(review): the boolean coercion of the rendered string (yes/no, on/off,
    # numeric strings, whitespace-only) is presumably implemented inside
    # JinjaBranchCriteria.evaluate — confirm against that implementation.
    fake_ctx = FakeWorkflowRunContext(values={})
    branch_ctx = BranchEvaluationContext(workflow_run_context=fake_ctx, block_label="test")
    criteria = JinjaBranchCriteria(expression=expression)
    result = await criteria.evaluate(branch_ctx)
    assert result is expected, f"Expression {expression} should evaluate to {expected}, got {result}"
@pytest.mark.asyncio
async def test_jinja_branch_criteria_with_variable_comparison():
    """Test realistic scenario with variable comparisons."""
    workflow_context = FakeWorkflowRunContext(
        values={
            "comment_count": 150,
            "threshold": 100,
            "status": "active",
        }
    )
    evaluation_context = BranchEvaluationContext(workflow_run_context=workflow_context, block_label="test")
    # Numeric comparison, string comparison, and combined logic — all truthy here.
    truthy_expressions = [
        "{{ comment_count > threshold }}",
        "{{ status == 'active' }}",
        "{{ comment_count > threshold and status == 'active' }}",
    ]
    for expression in truthy_expressions:
        criteria = JinjaBranchCriteria(expression=expression)
        assert await criteria.evaluate(evaluation_context) is True

View File

@@ -0,0 +1,119 @@
"""Unit tests for bulk artifact creation functionality."""
import pytest
from skyvern.forge.sdk.artifact.manager import ArtifactBatchData, BulkArtifactCreationRequest
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.db.models import ArtifactModel
def test_artifact_batch_data_with_data():
    """ArtifactBatchData built from raw bytes leaves the path field unset."""
    artifact_model = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test",
        organization_id="org-1",
    )
    batch = ArtifactBatchData(artifact_model=artifact_model, data=b"test data")
    # Only the bytes payload should be populated.
    assert batch.artifact_model == artifact_model
    assert batch.data == b"test data"
    assert batch.path is None
def test_artifact_batch_data_with_path():
    """ArtifactBatchData built from a file path leaves the data field unset."""
    artifact_model = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test",
        organization_id="org-1",
    )
    batch = ArtifactBatchData(artifact_model=artifact_model, path="/tmp/test.png")
    # Only the filesystem path should be populated.
    assert batch.artifact_model == artifact_model
    assert batch.data is None
    assert batch.path == "/tmp/test.png"
def test_artifact_batch_data_with_both_raises_error():
    """Supplying both data and path is ambiguous and must be rejected."""
    artifact_model = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test",
        organization_id="org-1",
    )
    with pytest.raises(ValueError, match="Cannot specify both data and path"):
        ArtifactBatchData(
            artifact_model=artifact_model,
            data=b"test data",
            path="/tmp/test.png",
        )
def test_bulk_artifact_creation_request():
    """BulkArtifactCreationRequest preserves artifact order and the primary key."""
    first_model = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.LLM_PROMPT,
        uri="s3://bucket/test1",
        organization_id="org-1",
    )
    second_model = ArtifactModel(
        artifact_id="test-2",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test2",
        organization_id="org-1",
    )
    batch_entries = [
        ArtifactBatchData(artifact_model=first_model, data=b"data1"),
        ArtifactBatchData(artifact_model=second_model, data=b"data2"),
    ]
    request = BulkArtifactCreationRequest(artifacts=batch_entries, primary_key="task-123")
    assert len(request.artifacts) == 2
    assert request.primary_key == "task-123"
    # Order of the batch entries must be preserved.
    ids = [entry.artifact_model.artifact_id for entry in request.artifacts]
    assert ids == ["test-1", "test-2"]
def test_bulk_artifact_creation_performance_benefit():
    """
    Test to verify that bulk creation reduces database calls.
    This is a conceptual test to document the performance improvement.
    """
    # Before the optimization every artifact meant one INSERT; after it,
    # any number of artifacts costs exactly one bulk INSERT.
    num_artifacts = 10
    individual_insert_count = num_artifacts  # old approach: N inserts
    bulk_insert_count = 1  # new approach: single bulk insert
    assert bulk_insert_count < individual_insert_count
    # Bulk insertion cuts the database-call count by a factor of N.
    assert individual_insert_count / bulk_insert_count == num_artifacts
# Allow running this module directly (python this_file.py) for quick iteration
# outside the normal pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,343 @@
"""
Tests for click prompt parameterization in cached script generation.
When generating cached scripts, click action prompts (intention/reasoning) should
replace literal parameter values with f-string references to context.parameters[...],
so that re-runs with different values produce correct behavior.
"""
from typing import Any
import libcst as cst
from skyvern.core.script_generations.generate_script import (
MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB,
_action_to_stmt,
_build_parameterized_prompt_cst,
_build_value_to_param_lookup,
)
from skyvern.webeye.actions.actions import ActionType
def _make_action(
action_type: str,
field_name: str | None = None,
text: str = "",
option: str = "",
file_url: str = "",
) -> dict[str, Any]:
action: dict[str, Any] = {"action_type": action_type}
if field_name:
action["field_name"] = field_name
if text:
action["text"] = text
if option:
action["option"] = option
if file_url:
action["file_url"] = file_url
return action
# ---------------------------------------------------------------------------
# _build_value_to_param_lookup
# ---------------------------------------------------------------------------
class TestBuildValueToParamLookup:
    """Tests for _build_value_to_param_lookup(): collecting literal action
    values (input text, select options, file URLs) into a value -> parameter-name
    lookup used for prompt parameterization."""
    def test_collects_input_text_values(self) -> None:
        """input_text values with a field_name are collected into the lookup."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="patient_id", text="542-641-668"),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {"542-641-668": "patient_id"}
    def test_collects_select_option_values(self) -> None:
        """select_option values are collected keyed by their field_name."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.SELECT_OPTION, field_name="state", option="California"),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {"California": "state"}
    def test_collects_upload_file_values(self) -> None:
        """upload_file URLs are collected keyed by their field_name."""
        actions_by_task = {
            "task-1": [
                _make_action(
                    ActionType.UPLOAD_FILE,
                    field_name="document",
                    file_url="https://example.com/report.pdf",
                ),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {"https://example.com/report.pdf": "document"}
    def test_skips_actions_without_field_name(self) -> None:
        """Actions lacking a field_name contribute nothing to the lookup."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, text="some value without field name"),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {}
    def test_skips_short_values(self) -> None:
        """Values shorter than MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB are excluded."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="flag", text="No"),
                _make_action(ActionType.INPUT_TEXT, field_name="code", text="CA"),
                _make_action(ActionType.INPUT_TEXT, field_name="num", text="1"),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {}
    def test_boundary_value_at_min_length(self) -> None:
        """Values at exactly MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB should be included."""
        value = "x" * MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="field", text=value),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert value in lookup
    def test_sorted_by_descending_length(self) -> None:
        """Keys are ordered longest-first so longer values match before their substrings."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="short_field", text="abcd"),
                _make_action(ActionType.INPUT_TEXT, field_name="long_field", text="abcdefghij"),
                _make_action(ActionType.INPUT_TEXT, field_name="mid_field", text="abcdef"),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        keys = list(lookup.keys())
        assert keys == ["abcdefghij", "abcdef", "abcd"]
    def test_first_writer_wins_on_duplicate_values(self) -> None:
        """When two fields share a value, the first-seen field keeps the mapping."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="first_field", text="same-value"),
                _make_action(ActionType.INPUT_TEXT, field_name="second_field", text="same-value"),
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup["same-value"] == "first_field"
    def test_skips_click_actions(self) -> None:
        """Click actions are ignored even when they carry field_name/text."""
        actions_by_task = {
            "task-1": [
                {"action_type": ActionType.CLICK, "field_name": "click_field", "text": "some text"},
            ]
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {}
    def test_multiple_tasks(self) -> None:
        """Values are aggregated across every task's action list."""
        actions_by_task = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="patient_id", text="542-641-668"),
            ],
            "task-2": [
                _make_action(ActionType.INPUT_TEXT, field_name="doctor_name", text="Dr. Smith"),
            ],
        }
        lookup = _build_value_to_param_lookup(actions_by_task)
        assert lookup == {"542-641-668": "patient_id", "Dr. Smith": "doctor_name"}
    def test_empty_actions(self) -> None:
        """An empty task mapping yields an empty lookup."""
        lookup = _build_value_to_param_lookup({})
        assert lookup == {}
# ---------------------------------------------------------------------------
# _build_parameterized_prompt_cst
# ---------------------------------------------------------------------------
class TestBuildParameterizedPromptCst:
    """Tests for _build_parameterized_prompt_cst(): rewriting a literal prompt
    string into a libcst f-string whose matched values become
    context.parameters[...] expressions. Returns None when nothing matches."""
    def test_returns_none_when_no_matches(self) -> None:
        """No lookup value appears in the prompt, so no f-string is built."""
        result = _build_parameterized_prompt_cst(
            "Click the submit button",
            {"542-641-668": "patient_id"},
        )
        assert result is None
    def test_single_substitution(self) -> None:
        """A single matching value is replaced by a context.parameters reference."""
        result = _build_parameterized_prompt_cst(
            "Which card corresponds to the referral for ID 542-641-668?",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        assert isinstance(result, cst.FormattedString)
        code = cst.Module(body=[]).code_for_node(result)
        assert "context.parameters" in code
        assert "patient_id" in code
        assert "542-641-668" not in code
    def test_multiple_substitutions(self) -> None:
        """Several distinct values can be substituted within one prompt."""
        result = _build_parameterized_prompt_cst(
            "Find patient 542-641-668 with doctor Dr. Smith",
            {"542-641-668": "patient_id", "Dr. Smith": "doctor_name"},
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        assert "patient_id" in code
        assert "doctor_name" in code
        assert "542-641-668" not in code
        assert "Dr. Smith" not in code
    def test_substitution_at_start(self) -> None:
        """A value at the very start of the prompt becomes the first f-string part."""
        result = _build_parameterized_prompt_cst(
            "542-641-668 is the patient ID to search for",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        parts = result.parts
        # First part should be the expression (substitution at start)
        assert isinstance(parts[0], cst.FormattedStringExpression)
    def test_substitution_at_end(self) -> None:
        """A value at the very end of the prompt becomes the last f-string part."""
        result = _build_parameterized_prompt_cst(
            "Search for patient 542-641-668",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        parts = result.parts
        # Last part should be the expression (substitution at end)
        assert isinstance(parts[-1], cst.FormattedStringExpression)
    def test_empty_intention(self) -> None:
        """An empty prompt can never be parameterized."""
        result = _build_parameterized_prompt_cst("", {"542-641-668": "patient_id"})
        assert result is None
    def test_empty_lookup(self) -> None:
        """An empty lookup yields no substitution."""
        result = _build_parameterized_prompt_cst(
            "Which card corresponds to the referral for ID 542-641-668?",
            {},
        )
        assert result is None
    def test_longer_match_preferred_over_shorter(self) -> None:
        """When values overlap, the longer value (sorted first) takes precedence."""
        result = _build_parameterized_prompt_cst(
            "Enter 542-641-668-999 here",
            {
                "542-641-668-999": "full_id",
                "542-641-668": "partial_id",
            },
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        assert "full_id" in code
        assert "partial_id" not in code
    def test_generates_valid_fstring_syntax(self) -> None:
        """The generated f-string should be parseable Python."""
        result = _build_parameterized_prompt_cst(
            "Which card area corresponds to ID 542-641-668?",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        # Should be a valid f-string — verify it starts with f" or f'
        assert code.startswith("f'") or code.startswith('f"')
        # The full expression should be compilable
        compile(code, "<test>", "eval")
    def test_repeated_value_in_intention(self) -> None:
        """If the same value appears twice, both occurrences should be replaced."""
        result = _build_parameterized_prompt_cst(
            "Compare 542-641-668 with 542-641-668",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        # The literal should not appear at all
        assert "542-641-668" not in code
        # context.parameters should appear twice (once per occurrence)
        assert code.count("context.parameters") == 2
# ---------------------------------------------------------------------------
# Integration: _action_to_stmt with value_to_param
# ---------------------------------------------------------------------------
class TestActionToStmtClickParameterization:
    """End-to-end tests exercising _action_to_stmt for click actions."""
    def _render(self, stmt: cst.BaseStatement) -> str:
        # Serialize a libcst statement back into Python source text for assertions.
        return cst.Module(body=[stmt]).code
    def test_click_prompt_parameterized(self) -> None:
        """Click action with matching value in intention gets an f-string prompt."""
        act: dict[str, Any] = {
            "action_type": "click",
            "xpath": "//div[@class='card']",
            "intention": "Which card corresponds to the referral for ID 542-641-668?",
        }
        task: dict[str, Any] = {}
        value_to_param = {"542-641-668": "patient_id"}
        stmt = _action_to_stmt(act, task, value_to_param=value_to_param)
        code = self._render(stmt)
        assert "context.parameters" in code
        assert "patient_id" in code
        assert "542-641-668" not in code
        # Should be an f-string
        assert "f'" in code or 'f"' in code
    def test_click_prompt_literal_when_no_lookup(self) -> None:
        """Click action without value_to_param produces a plain string prompt."""
        act: dict[str, Any] = {
            "action_type": "click",
            "xpath": "//div[@class='card']",
            "intention": "Click the submit button",
        }
        task: dict[str, Any] = {}
        stmt = _action_to_stmt(act, task, value_to_param=None)
        code = self._render(stmt)
        assert "Click the submit button" in code
        assert "context.parameters" not in code
    def test_click_prompt_literal_when_no_match(self) -> None:
        """Click action with non-matching lookup produces a plain string prompt."""
        act: dict[str, Any] = {
            "action_type": "click",
            "xpath": "//button",
            "intention": "Click the submit button",
        }
        task: dict[str, Any] = {}
        value_to_param = {"542-641-668": "patient_id"}
        stmt = _action_to_stmt(act, task, value_to_param=value_to_param)
        code = self._render(stmt)
        assert "Click the submit button" in code
        assert "context.parameters" not in code
    def test_fill_action_unaffected_by_value_to_param(self) -> None:
        """Fill actions should still use the field_name mechanism, not value_to_param."""
        act: dict[str, Any] = {
            "action_type": "input_text",
            "xpath": "//input[@name='search']",
            "text": "542-641-668",
            "field_name": "patient_id",
        }
        task: dict[str, Any] = {}
        value_to_param = {"542-641-668": "patient_id"}
        stmt = _action_to_stmt(act, task, value_to_param=value_to_param)
        code = self._render(stmt)
        # Should use context.parameters via the field_name mechanism, not f-string
        assert "context.parameters" in code
        assert "patient_id" in code

View File

@@ -0,0 +1,545 @@
"""Tests for compute_conditional_scopes() function.
This function maps each block label to the conditional block label whose scope it belongs to.
It handles merge-point detection, nested conditionals, and deduplication of branch targets.
"""
from __future__ import annotations
from datetime import datetime, timezone
from skyvern.forge.sdk.workflow.models.block import (
BranchCondition,
ConditionalBlock,
HttpRequestBlock,
TaskBlock,
compute_conditional_scopes,
)
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
def _make_output_parameter(key: str) -> OutputParameter:
    """Build a minimal OutputParameter for a test block, with both timestamps
    set to the same aware-UTC instant."""
    timestamp = datetime.now(tz=timezone.utc)
    return OutputParameter(
        key=key,
        parameter_type="output",
        output_parameter_id=f"op_{key}",
        workflow_id="wf_test",
        created_at=timestamp,
        modified_at=timestamp,
    )
def _make_task_block(label: str, *, next_block_label: str | None = None) -> TaskBlock:
    """Build a minimal TaskBlock wired to ``next_block_label`` (None = terminal)."""
    kwargs = {
        "label": label,
        "url": "https://example.com",
        "output_parameter": _make_output_parameter(label),
        "next_block_label": next_block_label,
    }
    return TaskBlock(**kwargs)
def _make_http_block(label: str, *, next_block_label: str | None = None) -> HttpRequestBlock:
    """Build a minimal GET HttpRequestBlock wired to ``next_block_label``."""
    kwargs = {
        "label": label,
        "url": "https://example.com",
        "method": "GET",
        "output_parameter": _make_output_parameter(label),
        "next_block_label": next_block_label,
    }
    return HttpRequestBlock(**kwargs)
def _make_conditional_block(
    label: str,
    branches: list[tuple[str | None, bool]],
    *,
    next_block_label: str | None = None,
) -> ConditionalBlock:
    """Create a conditional block with the given branches.

    Args:
        label: Block label
        branches: List of (next_block_label, is_default) tuples
        next_block_label: Default next block for the conditional itself (usually None)
    """

    def _build_branch(target: str | None, is_default: bool) -> BranchCondition:
        # The default branch carries no criteria; every other branch gets a
        # trivially-true Jinja criteria so it is structurally valid.
        if is_default:
            return BranchCondition(next_block_label=target, is_default=True)
        return BranchCondition(
            next_block_label=target,
            criteria={"criteria_type": "jinja2_template", "expression": "{{ true }}"},
        )

    return ConditionalBlock(
        label=label,
        output_parameter=_make_output_parameter(label),
        branch_conditions=[_build_branch(target, is_default) for target, is_default in branches],
        next_block_label=next_block_label,
    )
class TestComputeConditionalScopes:
    """Tests for compute_conditional_scopes().

    The function maps each block label to the conditional whose scope it
    belongs to; a label that appears in ALL (deduplicated) branch chains is a
    merge point and is excluded from the scope mapping."""
    def test_simple_two_branch_conditional_with_merge(self):
        """Test a simple conditional with two branches that merge.
        Workflow:
            Conditional(C) -> Branch1 -> A -> MergePoint(M)
                           -> Branch2 -> B -> M
        Expected: A and B are scoped to C, M is NOT scoped (merge point).
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("B", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "M": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        assert scopes == {"A": "C", "B": "C"}
        assert "M" not in scopes  # M is a merge point
    def test_conditional_with_chain_before_merge(self):
        """Test branches with multiple blocks before merge point.
        Workflow:
            Conditional(C) -> Branch1 -> A -> B -> MergePoint(M)
                           -> Branch2 -> D -> M
        Expected: A, B, D are scoped to C. M is NOT scoped.
        """
        block_a = _make_task_block("A", next_block_label="B")
        block_b = _make_task_block("B", next_block_label="M")
        block_d = _make_task_block("D", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("D", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "D": block_d,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "B",
            "B": "M",
            "D": "M",
            "M": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        assert scopes == {"A": "C", "B": "C", "D": "C"}
        assert "M" not in scopes
    def test_conditional_with_terminal_branches(self):
        """Test branches that don't merge (terminate independently).
        Workflow:
            Conditional(C) -> Branch1 -> A (terminal)
                           -> Branch2 -> B (terminal)
        Expected: A and B are scoped to C since they don't appear in all branches.
        """
        block_a = _make_task_block("A")
        block_b = _make_task_block("B")
        cond = _make_conditional_block("C", [("A", False), ("B", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
        }
        default_next_map = {
            "C": None,
            "A": None,
            "B": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        assert scopes == {"A": "C", "B": "C"}
    def test_conditional_all_branches_terminal_none(self):
        """Test when all branches have None as target (no blocks to scope).
        Workflow:
            Conditional(C) -> Branch1 -> None
                           -> Branch2 -> None
        Expected: No scopes (no blocks in the branches).
        """
        cond = _make_conditional_block("C", [(None, False), (None, True)])
        label_to_block = {"C": cond}
        default_next_map = {"C": None}
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        assert scopes == {}
    def test_multiple_branches_same_target_deduplication(self):
        """Test that duplicate branch targets are deduplicated.
        Workflow:
            Conditional(C) -> Branch1 -> A -> M
                           -> Branch2 -> A -> M (same as Branch1)
                           -> Branch3 -> B -> M
        With deduplication, unique targets are [A, B], so num_branches = 2.
        Both chains go to M, so M is a merge point.
        A appears in only one chain (after dedup), B in another.
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("A", False), ("B", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "M": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # A and B are scoped to C, M is the merge point
        assert scopes == {"A": "C", "B": "C"}
        assert "M" not in scopes
    def test_nested_conditionals(self):
        """Test nested conditionals (conditional inside another's branch).
        Workflow:
            OuterCond(C1) -> Branch1 -> InnerCond(C2) -> BranchA -> X
                                                      -> BranchB -> Y
                          -> Branch2 -> Z -> MergePoint(M)
        Expected:
        - C2 is scoped to C1 (it's in C1's branch)
        - X and Y are scoped to C2 (inner conditional handles its own branches)
        - Z is scoped to C1
        - M might or might not be scoped depending on structure
        """
        block_x = _make_task_block("X")
        block_y = _make_task_block("Y")
        block_z = _make_task_block("Z", next_block_label="M")
        block_m = _make_task_block("M")
        inner_cond = _make_conditional_block("C2", [("X", False), ("Y", True)])
        outer_cond = _make_conditional_block("C1", [("C2", False), ("Z", True)])
        label_to_block = {
            "C1": outer_cond,
            "C2": inner_cond,
            "X": block_x,
            "Y": block_y,
            "Z": block_z,
            "M": block_m,
        }
        default_next_map = {
            "C1": None,
            "C2": None,  # Inner conditional doesn't have a default next
            "X": None,
            "Y": None,
            "Z": "M",
            "M": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # C2 is scoped to C1 (it's in C1's branch, and tracing stops at C2)
        assert scopes.get("C2") == "C1"
        # X and Y are scoped to C2 (inner conditional)
        assert scopes.get("X") == "C2"
        assert scopes.get("Y") == "C2"
        # Z is scoped to C1
        assert scopes.get("Z") == "C1"
    def test_no_conditionals_in_workflow(self):
        """Test workflow with no conditional blocks.
        Workflow:
            A -> B -> C
        Expected: No scopes.
        """
        block_a = _make_task_block("A", next_block_label="B")
        block_b = _make_task_block("B", next_block_label="C")
        block_c = _make_task_block("C")
        label_to_block = {
            "A": block_a,
            "B": block_b,
            "C": block_c,
        }
        default_next_map = {
            "A": "B",
            "B": "C",
            "C": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        assert scopes == {}
    def test_conditional_with_single_branch(self):
        """Test conditional with effectively one unique branch target.
        Workflow:
            Conditional(C) -> Branch1 -> A
                           -> Branch2 -> A (same target, deduplicated)
        After deduplication, num_branches = 1, and A appears in 1/1 chains,
        making it a "merge point" (appears in all branches).
        """
        block_a = _make_task_block("A")
        cond = _make_conditional_block("C", [("A", False), ("A", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
        }
        default_next_map = {
            "C": None,
            "A": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # A appears in all (1) branch chains, so it's treated as a merge point
        assert scopes == {}
    def test_three_branch_conditional_partial_merge(self):
        """Test three branches where only some merge.
        Workflow:
            Conditional(C) -> Branch1 -> A -> M
                           -> Branch2 -> B -> M
                           -> Branch3 -> D (terminal, no merge)
        M appears in 2/3 branches, so it's NOT a merge point.
        All of A, B, D, M should be scoped to C.
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_d = _make_task_block("D")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("B", False), ("D", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "D": block_d,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "D": None,
            "M": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # M only appears in 2/3 branches, so it's still inside the conditional scope
        assert scopes == {"A": "C", "B": "C", "D": "C", "M": "C"}
    def test_merge_point_with_blocks_after(self):
        """Test that blocks after the merge point are not scoped.
        Workflow:
            Conditional(C) -> Branch1 -> A -> M -> X -> Y
                           -> Branch2 -> B -> M
        M is the merge point (appears in both chains).
        X and Y come after M and should NOT be scoped.
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_m = _make_task_block("M", next_block_label="X")
        block_x = _make_task_block("X", next_block_label="Y")
        block_y = _make_task_block("Y")
        cond = _make_conditional_block("C", [("A", False), ("B", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "M": block_m,
            "X": block_x,
            "Y": block_y,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "M": "X",
            "X": "Y",
            "Y": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # A and B are scoped, M and everything after is NOT
        assert scopes == {"A": "C", "B": "C"}
        assert "M" not in scopes
        assert "X" not in scopes
        assert "Y" not in scopes
    def test_branch_to_nonexistent_block(self):
        """Test graceful handling when branch targets a non-existent block.
        This shouldn't happen in practice (validation catches it), but the
        function should handle it gracefully.
        """
        cond = _make_conditional_block("C", [("MISSING", False), ("A", True)])
        block_a = _make_task_block("A")
        label_to_block = {
            "C": cond,
            "A": block_a,
        }
        default_next_map = {
            "C": None,
            "A": None,
        }
        # Should not raise, MISSING just won't be in the results
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # Only A is scoped (MISSING is not in label_to_block)
        assert scopes == {"A": "C"}
    def test_empty_workflow(self):
        """Test with empty inputs."""
        scopes = compute_conditional_scopes({}, {})
        assert scopes == {}
    def test_conditional_only_no_other_blocks(self):
        """Test with only a conditional block and no branch targets.
        Workflow:
            Conditional(C) -> Branch1 -> None
                           -> Branch2 -> None
        """
        cond = _make_conditional_block("C", [(None, False), (None, True)])
        label_to_block = {"C": cond}
        default_next_map = {"C": None}
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        assert scopes == {}
    def test_asymmetric_branch_lengths(self):
        """Test branches with significantly different chain lengths.
        Workflow:
            Conditional(C) -> Branch1 -> A -> B -> C2 -> D -> M
                           -> Branch2 -> M
        Branch1 has a long chain, Branch2 goes directly to M.
        M is the only block in both chains, so it's the merge point.
        """
        block_a = _make_task_block("A", next_block_label="B")
        block_b = _make_task_block("B", next_block_label="C2")
        block_c2 = _make_task_block("C2", next_block_label="D")
        block_d = _make_task_block("D", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("M", True)])
        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "C2": block_c2,
            "D": block_d,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "B",
            "B": "C2",
            "C2": "D",
            "D": "M",
            "M": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # A, B, C2, D are in Branch1 only, so they're scoped
        # M appears in both branches, so it's the merge point
        assert scopes == {"A": "C", "B": "C", "C2": "C", "D": "C"}
        assert "M" not in scopes
    def test_multiple_independent_conditionals(self):
        """Test multiple conditionals at the same level (not nested).
        Workflow:
            C1 -> Branch1 -> A
               -> Branch2 -> B
            (after C1) -> C2 -> Branch1 -> X
                             -> Branch2 -> Y
        """
        block_a = _make_task_block("A", next_block_label="C2")
        block_b = _make_task_block("B", next_block_label="C2")
        block_x = _make_task_block("X")
        block_y = _make_task_block("Y")
        cond1 = _make_conditional_block("C1", [("A", False), ("B", True)])
        cond2 = _make_conditional_block("C2", [("X", False), ("Y", True)])
        label_to_block = {
            "C1": cond1,
            "C2": cond2,
            "A": block_a,
            "B": block_b,
            "X": block_x,
            "Y": block_y,
        }
        default_next_map = {
            "C1": None,
            "C2": None,
            "A": "C2",
            "B": "C2",
            "X": None,
            "Y": None,
        }
        scopes = compute_conditional_scopes(label_to_block, default_next_map)
        # A and B are scoped to C1
        # C2 is the merge point for C1 (appears in both A and B chains)
        # X and Y are scoped to C2
        assert scopes.get("A") == "C1"
        assert scopes.get("B") == "C1"
        assert "C2" not in scopes  # C2 is a merge point for C1
        assert scopes.get("X") == "C2"
        assert scopes.get("Y") == "C2"

View File

@@ -0,0 +1,811 @@
"""
Tests for conditional block script caching support.
This test file verifies that:
1. Workflows with conditional blocks can have scripts generated for cacheable blocks
2. The regeneration logic doesn't trigger unnecessary regeneration for unexecuted branches
3. Progressive caching works correctly across multiple runs
4. Cached blocks from unexecuted branches are preserved during script regeneration (SKY-7815)
Key bugs this tests against:
- Previously, the regeneration check compared cached blocks against ALL blocks in the workflow
definition, causing "missing" blocks from unexecuted branches to trigger regeneration
on EVERY run, flooding the database with redundant script operations.
- (SKY-7815) When regeneration was triggered for a legitimate reason, cached blocks from
unexecuted conditional branches were DROPPED because generate_workflow_script_python_code()
only iterated blocks from the transform output (executed blocks). This caused a regeneration
loop where blocks kept getting dropped and re-added.
"""
from unittest.mock import AsyncMock, patch
import pytest
from skyvern.core.script_generations.generate_script import ScriptBlockSource, generate_workflow_script_python_code
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
from skyvern.schemas.workflows import BlockType
from skyvern.services.workflow_script_service import workflow_has_conditionals
class TestConditionalBlockDetection:
    """Tests for workflow_has_conditionals() function."""

    @staticmethod
    def _make_workflow(block_types: list[BlockType]):
        """Build a minimal mock workflow holding one block per entry of block_types.

        The mock mirrors only the attributes workflow_has_conditionals() reads:
        workflow.workflow_definition.blocks, each block's block_type/label, and
        workflow.workflow_id. Previously these three mock classes were duplicated
        verbatim in both test methods; they are consolidated here.
        """

        class MockBlock:
            def __init__(self, block_type: BlockType):
                self.block_type = block_type
                self.label = f"block_{block_type.value}"

        class MockWorkflowDefinition:
            def __init__(self, blocks: list):
                self.blocks = blocks

        class MockWorkflow:
            def __init__(self, blocks: list):
                self.workflow_definition = MockWorkflowDefinition(blocks)
                self.workflow_id = "test_workflow"

        return MockWorkflow([MockBlock(block_type) for block_type in block_types])

    def test_workflow_without_conditionals(self) -> None:
        """Workflows without conditional blocks should return False."""
        # Workflow with only navigation and extraction blocks
        workflow = self._make_workflow([BlockType.NAVIGATION, BlockType.EXTRACTION])
        assert workflow_has_conditionals(workflow) is False

    def test_workflow_with_conditionals(self) -> None:
        """Workflows with conditional blocks should return True."""
        # Workflow with a conditional block between two cacheable blocks
        workflow = self._make_workflow(
            [BlockType.NAVIGATION, BlockType.CONDITIONAL, BlockType.EXTRACTION]
        )
        assert workflow_has_conditionals(workflow) is True
class TestConditionalBlockNotCached:
    """Membership checks against BLOCK_TYPES_THAT_SHOULD_BE_CACHED."""

    def test_conditional_not_in_cached_types(self) -> None:
        """The conditional block type must stay outside the cacheable set."""
        assert BlockType.CONDITIONAL not in BLOCK_TYPES_THAT_SHOULD_BE_CACHED

    def test_cacheable_types_exist(self) -> None:
        """The cacheable set includes the expected executable block types."""
        for expected_type in (BlockType.NAVIGATION, BlockType.EXTRACTION, BlockType.TASK):
            assert expected_type in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
class TestRegenerationLogicForConditionals:
    """
    Tests for the regeneration decision logic when conditionals are present.
    The key fix: for workflows WITH conditionals, missing labels from unexecuted
    branches should NOT trigger regeneration. This prevents the database flooding
    bug where every run caused unnecessary script regeneration.
    """

    def test_missing_labels_computation(self) -> None:
        """
        Test that the missing labels computation works correctly.
        For a workflow with branches A and B:
        - should_cache_block_labels = {A, B, START}
        - cached_block_labels = {A, START} (only A executed)
        - missing_labels = {B}
        Without the fix: missing_labels triggers regeneration every time.
        With the fix: missing_labels is ignored for workflows with conditionals.
        """
        # Mirror the production computation with plain sets.
        labels_that_should_cache = {"branch_a_extract", "branch_b_extract", "WORKFLOW_START_BLOCK"}
        labels_already_cached = {"branch_a_extract", "WORKFLOW_START_BLOCK"}
        labels_still_missing = labels_that_should_cache - labels_already_cached
        assert labels_still_missing == {"branch_b_extract"}
        # With conditionals present, missing labels must NOT schedule regeneration.
        workflow_is_conditional = True
        scheduled_updates: set[str] = set()
        if labels_still_missing and not workflow_is_conditional:
            scheduled_updates |= labels_still_missing
        # Nothing was scheduled because the workflow has conditionals.
        assert len(scheduled_updates) == 0

    def test_regeneration_triggered_without_conditionals(self) -> None:
        """
        Without conditionals, missing labels SHOULD trigger regeneration.
        This is the expected behavior for regular workflows where all blocks
        should eventually be cached.
        """
        labels_that_should_cache = {"block_1", "block_2", "WORKFLOW_START_BLOCK"}
        labels_already_cached = {"block_1", "WORKFLOW_START_BLOCK"}
        labels_still_missing = labels_that_should_cache - labels_already_cached
        assert labels_still_missing == {"block_2"}
        # Without conditionals, the missing label IS queued for regeneration.
        workflow_is_conditional = False
        scheduled_updates: set[str] = set()
        if labels_still_missing and not workflow_is_conditional:
            scheduled_updates |= labels_still_missing
        assert "block_2" in scheduled_updates

    def test_explicit_updates_still_work_with_conditionals(self) -> None:
        """
        Even with conditionals, explicit blocks_to_update from the caller
        should still trigger regeneration.
        This ensures that actual changes to executed blocks are still processed.
        """
        scheduled_updates: set[str] = {"explicitly_updated_block"}  # Supplied by the caller.
        # A non-empty update set always means regeneration, conditionals or not.
        regeneration_needed = bool(scheduled_updates)
        assert regeneration_needed is True
class TestProgressiveCachingConcept:
    """
    Tests documenting the progressive caching concept for conditional workflows.
    Progressive caching means:
    1. Run 1 takes branch A → caches blocks from A
    2. Run 2 takes branch B → caches blocks from B (preserves A's cache)
    3. Eventually all branches have cached blocks
    The key insight is that we DON'T regenerate just because some branches
    haven't executed yet.
    """

    def test_progressive_caching_scenario(self) -> None:
        """
        Simulate multiple runs with different branches.
        Run 1: Branch A executes
        Run 2: Branch A executes (should NOT regenerate - same blocks)
        Run 3: Branch B executes (should cache B, preserve A)
        """
        cache: set[str] = set()
        # Run 1: branch A executes and its blocks land in the cache.
        run1_executed = {"nav_block", "branch_a_extract"}
        cache |= run1_executed
        assert cache == {"nav_block", "branch_a_extract"}
        # Run 2: branch A repeats; nothing new, so no regeneration trigger.
        run2_executed = {"nav_block", "branch_a_extract"}
        assert not (run2_executed - cache)  # Nothing new to cache.
        # Run 3: branch B executes for the first time — one new block to cache.
        run3_executed = {"nav_block", "branch_b_extract"}
        assert run3_executed - cache == {"branch_b_extract"}
        # Both branches now coexist in the cache.
        cache |= run3_executed
        assert cache == {"nav_block", "branch_a_extract", "branch_b_extract"}
class TestConditionalBlockCodeGeneration:
    """Tests for conditional block handling in code generation."""

    def test_conditional_block_type_string(self) -> None:
        """The serialized value of the conditional block type is the literal string."""
        expected_value = "conditional"
        assert BlockType.CONDITIONAL.value == expected_value
# ---------------------------------------------------------------------------
# SKY-7815: Tests for cached block preservation during regeneration
# ---------------------------------------------------------------------------
class TestCachedBlockPreservationDuringRegeneration:
    """
    Tests verifying that cached blocks from unexecuted conditional branches
    are preserved when generate_workflow_script_python_code() regenerates a script.
    Bug (SKY-7815):
    When a workflow has conditional branches A and B:
    - Run 1 executes branch A → script has blocks from A
    - Run 2 executes branch B → regeneration triggered → transform only returns B's blocks
    - generate_workflow_script_python_code() only iterates transform output (B's blocks)
    - Cached blocks from A are loaded into cached_blocks dict but NEVER iterated
    - Result: A's blocks are DROPPED from the new script → regeneration loop
    Fix: After processing all blocks from the transform output, iterate remaining
    cached_blocks entries and preserve them in both the DB and script output.
    """

    @pytest.mark.asyncio
    async def test_cached_blocks_from_unexecuted_branch_are_preserved(self) -> None:
        """
        Core test: when only branch B's blocks are in the transform output,
        branch A's cached blocks should still appear in the generated script.
        """
        # Branch A's cached block (from a previous run)
        branch_a_code = (
            "async def branch_a_extract(page: SkyvernPage, context: RunContext) -> None:\n"
            "    await skyvern.extract(page, \"//div[@id='result']\")\n"
        )
        cached_blocks = {
            "branch_a_extract": ScriptBlockSource(
                label="branch_a_extract",
                code=branch_a_code,
                run_signature="await branch_a_extract(page, context)",
                workflow_run_id="wr_run1",
                workflow_run_block_id="wfrb_a",
                input_fields=None,
            ),
        }
        # Transform output only has branch B's block (branch B executed this run)
        blocks = [
            {
                "block_type": "navigation",
                "label": "branch_b_navigate",
                "task_id": "task_b",
                "navigation_goal": "Go to page B",
                "url": "https://example.com/b",
                "workflow_run_id": "wr_run2",
                "workflow_run_block_id": "wfrb_b",
            },
        ]
        actions_by_task = {
            "task_b": [
                {
                    "action_type": "click",
                    "action_id": "action_b1",
                    "xpath": "//button[@id='submit']",
                    "element_id": "submit",
                    "reasoning": "Click submit",
                    "intention": "Submit the form",
                    "confidence_float": 0.95,
                    "has_mini_agent": False,
                },
            ],
        }
        workflow = {
            "workflow_id": "wf_test",
            "workflow_permanent_id": "wpid_test",
            "title": "Test Conditional Workflow",
            "workflow_definition": {
                "parameters": [
                    {"parameter_type": "workflow", "key": "url", "default_value": "https://example.com"},
                ],
            },
        }
        workflow_run_request = {
            "workflow_id": "wpid_test",
            "parameters": {"url": "https://example.com"},
        }
        # Patch the schema generator (returns an empty schema) and the DB writer so
        # we can inspect exactly which blocks the generator persists.
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request=workflow_run_request,
                workflow=workflow,
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"branch_b_navigate", "__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )
            # The output should contain branch A's cached code
            assert "branch_a_extract" in result, (
                "Cached block from unexecuted branch A should be preserved in the script output"
            )
            # Verify create_or_update_script_block was called for the preserved block
            preserved_calls = [
                call
                for call in mock_create_block.call_args_list
                if call.kwargs.get("block_label") == "branch_a_extract"
            ]
            assert len(preserved_calls) == 1, (
                "create_or_update_script_block should be called for the preserved cached block"
            )
            # The preserved DB record must carry the ORIGINAL run's metadata, not run 2's.
            preserved_call = preserved_calls[0]
            assert preserved_call.kwargs["run_signature"] == "await branch_a_extract(page, context)"
            assert preserved_call.kwargs["workflow_run_id"] == "wr_run1"

    @pytest.mark.asyncio
    async def test_cached_blocks_without_run_signature_are_not_preserved(self) -> None:
        """Cached blocks without a run_signature should NOT be preserved."""
        cached_blocks = {
            "incomplete_block": ScriptBlockSource(
                label="incomplete_block",
                code="async def incomplete_block(): pass\n",
                run_signature=None,  # No run_signature
                workflow_run_id="wr_old",
                workflow_run_block_id="wfrb_old",
                input_fields=None,
            ),
        }
        blocks: list = []
        actions_by_task: dict = {}
        workflow = {
            "workflow_id": "wf_test",
            "title": "Test",
            "workflow_definition": {"parameters": []},
        }
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow=workflow,
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )
            # Incomplete block should NOT appear in the output
            assert "incomplete_block" not in result
            # create_or_update_script_block should NOT be called for incomplete block
            incomplete_calls = [
                call
                for call in mock_create_block.call_args_list
                if call.kwargs.get("block_label") == "incomplete_block"
            ]
            assert len(incomplete_calls) == 0

    @pytest.mark.asyncio
    async def test_cached_blocks_without_code_are_not_preserved(self) -> None:
        """Cached blocks without code should NOT be preserved."""
        cached_blocks = {
            "empty_block": ScriptBlockSource(
                label="empty_block",
                code="",  # Empty code
                run_signature="await empty_block(page, context)",
                workflow_run_id="wr_old",
                workflow_run_block_id="wfrb_old",
                input_fields=None,
            ),
        }
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=[],
                actions_by_task={},
                cached_blocks=cached_blocks,
                updated_block_labels={"__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )
            # Empty block should NOT appear
            empty_calls = [
                call for call in mock_create_block.call_args_list if call.kwargs.get("block_label") == "empty_block"
            ]
            assert len(empty_calls) == 0

    @pytest.mark.asyncio
    async def test_already_processed_blocks_are_not_duplicated(self) -> None:
        """
        Blocks that appear in both the transform output AND cached_blocks
        should NOT be duplicated. The transform output processing handles them.
        """
        block_code = (
            "async def shared_block(page: SkyvernPage, context: RunContext) -> None:\n"
            '    await skyvern.click(page, "//button")\n'
        )
        cached_blocks = {
            "shared_block": ScriptBlockSource(
                label="shared_block",
                code=block_code,
                run_signature="await shared_block(page, context)",
                workflow_run_id="wr_run1",
                workflow_run_block_id="wfrb_shared",
                input_fields=None,
            ),
        }
        # Same block also appears in the transform output (it executed this run too)
        blocks = [
            {
                "block_type": "navigation",
                "label": "shared_block",
                "task_id": "task_shared",
                "navigation_goal": "Navigate somewhere",
                "url": "https://example.com",
                "workflow_run_id": "wr_run2",
                "workflow_run_block_id": "wfrb_shared_run2",
            },
        ]
        actions_by_task = {
            "task_shared": [
                {
                    "action_type": "click",
                    "action_id": "action_1",
                    "xpath": "//button",
                    "element_id": "btn",
                    "reasoning": "Click",
                    "intention": "Click",
                    "confidence_float": 0.9,
                    "has_mini_agent": False,
                },
            ],
        }
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"shared_block", "__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )
            # The block should appear exactly once (from the transform output processing,
            # NOT duplicated by the preservation loop)
            shared_calls = [
                call for call in mock_create_block.call_args_list if call.kwargs.get("block_label") == "shared_block"
            ]
            # Should be called once from the normal task_v1 processing, NOT again from preservation
            assert len(shared_calls) == 1

    @pytest.mark.asyncio
    async def test_multiple_unexecuted_branches_all_preserved(self) -> None:
        """
        When a workflow has 3 conditional branches and only 1 executes,
        cached blocks from the other 2 branches should ALL be preserved.
        """

        def _make_cached_block(label: str) -> ScriptBlockSource:
            # Build a minimal-but-complete cached block (code + run_signature present).
            return ScriptBlockSource(
                label=label,
                code=f"async def {label}(page: SkyvernPage, context: RunContext) -> None:\n    pass\n",
                run_signature=f"await {label}(page, context)",
                workflow_run_id="wr_old",
                workflow_run_block_id=f"wfrb_{label}",
                input_fields=None,
            )

        cached_blocks = {
            "branch_a_extract": _make_cached_block("branch_a_extract"),
            "branch_b_navigate": _make_cached_block("branch_b_navigate"),
            # branch_c executed this run, so it's also in blocks below
        }
        # Only branch C's block is in the transform output
        blocks = [
            {
                "block_type": "extraction",
                "label": "branch_c_extract",
                "task_id": "task_c",
                "data_extraction_goal": "Extract C data",
                "workflow_run_id": "wr_run3",
                "workflow_run_block_id": "wfrb_c",
            },
        ]
        actions_by_task = {
            "task_c": [
                {
                    "action_type": "extract",
                    "action_id": "action_c1",
                    "xpath": "//div[@class='data']",
                    "element_id": "data",
                    "reasoning": "Extract",
                    "intention": "Extract data",
                    "confidence_float": 0.9,
                    "has_mini_agent": False,
                    "data_extraction_goal": "Extract C data",
                },
            ],
        }
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"branch_c_extract", "__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )
            # Both branch A and B should be preserved
            assert "branch_a_extract" in result, "Branch A cached block should be preserved"
            assert "branch_b_navigate" in result, "Branch B cached block should be preserved"
            assert "branch_c_extract" in result, "Branch C (executed) block should be present"
            # Verify DB entries were created for all 3 blocks + __start_block__
            all_labels = {call.kwargs.get("block_label") for call in mock_create_block.call_args_list}
            assert "branch_a_extract" in all_labels
            assert "branch_b_navigate" in all_labels
            assert "branch_c_extract" in all_labels
            assert "__start_block__" in all_labels

    @pytest.mark.asyncio
    async def test_preservation_without_script_context(self) -> None:
        """
        When script_id/script_revision_id/organization_id are not provided,
        cached blocks should still be added to the script output (just no DB calls).
        """
        branch_a_code = "async def branch_a(page: SkyvernPage, context: RunContext) -> None:\n    pass\n"
        cached_blocks = {
            "branch_a": ScriptBlockSource(
                label="branch_a",
                code=branch_a_code,
                run_signature="await branch_a(page, context)",
                workflow_run_id="wr_old",
                workflow_run_block_id="wfrb_a",
                input_fields=None,
            ),
        }
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=[],
                actions_by_task={},
                cached_blocks=cached_blocks,
                updated_block_labels={"__start_block__"},
                # No script context
                script_id=None,
                script_revision_id=None,
                organization_id=None,
            )
            # Code should still be in the output
            assert "branch_a" in result
            # But no DB calls should be made for preserved blocks
            preserved_calls = [
                call for call in mock_create_block.call_args_list if call.kwargs.get("block_label") == "branch_a"
            ]
            assert len(preserved_calls) == 0
class TestRegenerationLoopPrevention:
    """
    End-to-end tests for the regeneration loop prevention (SKY-7815).
    The regeneration loop happens when:
    1. Workflow has conditional branches A and B
    2. Run 1 caches branch A → script has A's blocks
    3. Run 2 executes branch B → triggers regeneration for B
    4. During regeneration, transform only returns B's blocks
    5. A's cached blocks are dropped from the new script
    6. Run 3 executes branch A → A is "missing" → triggers regeneration
    7. During regeneration, B's cached blocks are dropped → loop continues
    The fix has two parts:
    1. generate_script_if_needed: Don't add missing labels for conditional workflows
    2. generate_workflow_script_python_code: Preserve cached blocks from unexecuted branches
    """

    def test_regeneration_loop_scenario_is_prevented(self) -> None:
        """
        Simulate the full regeneration loop scenario and verify it's prevented.
        Exercises both halves of the fix working together: missing labels do not
        trigger regeneration for conditional workflows, and even when regeneration
        IS triggered for other reasons, cached blocks are preserved.
        """
        # --- Part 1: generate_script_if_needed logic ---
        # Workflow definition has blocks: nav, branch_a_extract, branch_b_extract.
        expected_cache_labels = {"nav_block", "branch_a_extract", "branch_b_extract", "__start_block__"}
        # After Run 1: only nav and branch A are cached.
        labels_cached_so_far = {"nav_block", "branch_a_extract", "__start_block__"}
        not_yet_cached = expected_cache_labels - labels_cached_so_far
        assert not_yet_cached == {"branch_b_extract"}
        workflow_is_conditional = True
        pending_regeneration: set[str] = set()
        # With conditionals, missing labels alone must never schedule regeneration.
        if not_yet_cached and not workflow_is_conditional:
            pending_regeneration |= not_yet_cached
        elif not_yet_cached and workflow_is_conditional:
            pass  # Skip - expected for conditional workflows.
        assert not pending_regeneration
        # --- Part 2: even when regeneration IS triggered ---
        # e.g. branch B executed this run and genuinely needs caching.
        pending_regeneration.add("branch_b_extract")
        # The transform output carries only branch B's chain.
        labels_from_transform = {"nav_block", "branch_b_extract"}
        # The old script's cache still holds branch A's data.
        labels_in_old_cache = {"nav_block", "branch_a_extract"}
        # The preservation loop picks up whatever the transform did not cover.
        labels_kept_from_cache = labels_in_old_cache - labels_from_transform
        assert labels_kept_from_cache == {"branch_a_extract"}, (
            "Branch A's block should be preserved even though it wasn't in the transform output"
        )
        # The regenerated script must contain the union of both sources.
        regenerated_labels = labels_from_transform | labels_kept_from_cache
        assert regenerated_labels == {"nav_block", "branch_a_extract", "branch_b_extract"}

    def test_no_regeneration_loop_across_three_runs(self) -> None:
        """
        Simulate 3 runs and verify no regeneration loop occurs.
        Run 1: Branch A → cache A
        Run 2: Branch B → regenerate (B is new) → A is preserved
        Run 3: Branch A → no regeneration needed (A is still cached)
        """
        # --- Run 1: branch A executes ---
        cache_after_run1 = {"nav_block", "branch_a_extract", "__start_block__"}
        # --- Run 2: branch B executes ---
        workflow_is_conditional = True
        all_cacheable = {"nav_block", "branch_a_extract", "branch_b_extract", "__start_block__"}
        missing_on_run2 = all_cacheable - cache_after_run1
        assert missing_on_run2 == {"branch_b_extract"}
        regen_targets_run2: set[str] = set()
        # Missing labels are NOT queued for conditional workflows.
        if missing_on_run2 and not workflow_is_conditional:
            regen_targets_run2 |= missing_on_run2
        # branch_b_extract IS queued because it actually executed.
        regen_targets_run2.add("branch_b_extract")
        # Regeneration happens, but branch A survives via the preservation loop.
        transform_labels_run2 = {"nav_block", "branch_b_extract"}
        preserved_labels_run2 = {"branch_a_extract"}  # From the cache, absent from the transform.
        cache_after_run2 = transform_labels_run2 | preserved_labels_run2 | {"__start_block__"}
        assert cache_after_run2 == {"nav_block", "branch_a_extract", "branch_b_extract", "__start_block__"}
        # --- Run 3: branch A executes again ---
        missing_on_run3 = all_cacheable - cache_after_run2
        assert not missing_on_run3, "No missing blocks after Run 2 because branch A was preserved"
        regen_targets_run3: set[str] = set()
        if missing_on_run3 and not workflow_is_conditional:
            regen_targets_run3 |= missing_on_run3
        # branch_a_extract already has cached code, so execution tracking does not
        # queue it (only blocks that DON'T have cached code get added).
        assert bool(regen_targets_run3) is False, "No regeneration needed on Run 3 - the loop is broken"

View File

@@ -0,0 +1,75 @@
import pytest
from skyvern.forge.sdk.api.custom_credential_client import CustomCredentialAPIClient
from skyvern.forge.sdk.schemas.credentials import CredentialType, SecretCredential
@pytest.fixture
def client() -> CustomCredentialAPIClient:
    """Provide a client pointed at a fake custom-credential endpoint."""
    api_client = CustomCredentialAPIClient(api_base_url="https://custom.example.com", api_token="token-123")
    return api_client
def test_credential_to_api_payload_with_label(client: CustomCredentialAPIClient) -> None:
    """A labeled secret serializes with its secret_label included."""
    labeled_secret = SecretCredential(secret_value="sk-secret", secret_label="api-key")
    expected_payload = {
        "type": "secret",
        "secret_value": "sk-secret",
        "secret_label": "api-key",
    }
    assert client._credential_to_api_payload(labeled_secret) == expected_payload
def test_credential_to_api_payload_without_label(client: CustomCredentialAPIClient) -> None:
    """An unlabeled secret serializes without a secret_label key."""
    unlabeled_secret = SecretCredential(secret_value="sk-secret-no-label")
    expected_payload = {
        "type": "secret",
        "secret_value": "sk-secret-no-label",
    }
    assert client._credential_to_api_payload(unlabeled_secret) == expected_payload
def test_api_response_to_credential_secret_with_label(client: CustomCredentialAPIClient) -> None:
    """A labeled secret API response maps onto a SecretCredential carrying the label."""
    api_response = {
        "type": "secret",
        "secret_value": "shhh",
        "secret_label": "prod-api",
    }
    item = client._api_response_to_credential(api_response, name="Prod API", item_id="cred_123")
    assert item.item_id == "cred_123"
    assert item.name == "Prod API"
    assert item.credential_type == CredentialType.SECRET
    assert isinstance(item.credential, SecretCredential)
    assert item.credential.secret_value == "shhh"
    assert item.credential.secret_label == "prod-api"
def test_api_response_to_credential_secret_without_label(client: CustomCredentialAPIClient) -> None:
    """A label-free secret API response produces a SecretCredential with secret_label None."""
    api_response = {
        "type": "secret",
        "secret_value": "token-only",
    }
    item = client._api_response_to_credential(api_response, name="Token", item_id="cred_456")
    assert item.item_id == "cred_456"
    assert item.name == "Token"
    assert item.credential_type == CredentialType.SECRET
    assert isinstance(item.credential, SecretCredential)
    assert item.credential.secret_value == "token-only"
    assert item.credential.secret_label is None
def test_api_response_to_credential_secret_missing_required_field(client: CustomCredentialAPIClient) -> None:
    """An API response lacking secret_value raises a ValueError."""
    malformed_response = {
        "type": "secret",
        "secret_label": "no-secret-value",
    }
    with pytest.raises(ValueError, match="Missing required secret fields from API"):
        client._api_response_to_credential(malformed_response, name="Broken Secret", item_id="cred_789")

View File

@@ -0,0 +1,391 @@
import os
import tempfile
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.forge.sdk.models import StepStatus
from skyvern.webeye.actions.actions import DownloadFileAction
from skyvern.webeye.actions.handler import handle_download_file_action
from skyvern.webeye.actions.responses import ActionFailure, ActionSuccess
from skyvern.webeye.scraper.scraped_page import ScrapedPage
from tests.unit.helpers import make_organization, make_step, make_task
@pytest.mark.asyncio
async def test_handle_download_file_action_with_byte_data() -> None:
    """When inline byte data is provided, the handler writes it straight to disk."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)
    # Stand-ins for the live browser objects the handler expects.
    page = MagicMock()
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    file_payload = b"test file content"
    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=file_payload,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )
    # Route the handler's download directory into a throwaway temp dir.
    with (
        tempfile.TemporaryDirectory() as download_dir,
        patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=download_dir),
    ):
        results = await handle_download_file_action(action, page, scraped_page, task, step)
        # download_triggered is set by the outer handle-action flow when in context.
        assert len(results) == 1
        assert isinstance(results[0], ActionSuccess)
        # The payload must land on disk under the action's file_name, byte-for-byte.
        written_path = os.path.join(download_dir, "test_file.txt")
        assert os.path.exists(written_path)
        with open(written_path, "rb") as saved_file:
            assert saved_file.read() == file_payload
@pytest.mark.asyncio
async def test_handle_download_file_action_with_download_url() -> None:
    """When a download_url is provided, the handler navigates to it and succeeds."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)
    # Stand-ins for the live browser objects; goto is awaited by the handler.
    page = MagicMock()
    page.goto = AsyncMock(return_value=None)
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    action = DownloadFileAction(
        file_name="downloaded_file.pdf",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )
    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        results = await handle_download_file_action(action, page, scraped_page, task, step)
    # The handler downloads via browser navigation, so goto must receive the URL.
    page.goto.assert_called_once()
    assert page.goto.call_args[0][0] == "https://example.com/file.pdf"
    assert len(results) == 1
    assert isinstance(results[0], ActionSuccess)
@pytest.mark.asyncio
async def test_handle_download_file_action_with_download_url_same_filename() -> None:
    """Navigation-based download works the same when file_name differs from the URL's basename."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)
    # Stand-ins for the live browser objects; goto is awaited by the handler.
    page = MagicMock()
    page.goto = AsyncMock(return_value=None)
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    action = DownloadFileAction(
        file_name="same_name.pdf",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )
    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        results = await handle_download_file_action(action, page, scraped_page, task, step)
    page.goto.assert_called_once()
    assert page.goto.call_args[0][0] == "https://example.com/file.pdf"
    assert len(results) == 1
    assert isinstance(results[0], ActionSuccess)
@pytest.mark.asyncio
async def test_handle_download_file_action_without_byte_or_url() -> None:
    """With neither byte data nor a download_url, the handler still reports success."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)
    # Stand-ins for the live browser objects the handler expects.
    page = MagicMock()
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=None,
        download_url=None,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )
    with (
        tempfile.TemporaryDirectory() as download_dir,
        patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=download_dir),
    ):
        results = await handle_download_file_action(action, page, scraped_page, task, step)
        # download_triggered is set by the outer handle-action flow when in context.
        assert len(results) == 1
        assert isinstance(results[0], ActionSuccess)
@pytest.mark.asyncio
async def test_handle_download_file_action_with_byte_priority() -> None:
    """When both byte and download_url are provided, the byte payload wins: no navigation occurs."""
    created_at = datetime.now(UTC)
    org = make_organization(created_at)
    task = make_task(created_at, org)
    step = make_step(created_at, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    page = MagicMock()
    page.goto = AsyncMock(return_value=None)
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    payload = b"byte data content"
    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=payload,
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with tempfile.TemporaryDirectory() as download_dir:
        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=download_dir):
            results = await handle_download_file_action(action, page, scraped_page, task, step)

        # Byte data takes priority over the URL, so the page is never navigated.
        page.goto.assert_not_called()
        assert len(results) == 1
        assert isinstance(results[0], ActionSuccess)

        # The payload must be written verbatim under the requested file name.
        written_path = os.path.join(download_dir, "test_file.txt")
        assert os.path.exists(written_path)
        with open(written_path, "rb") as fh:
            assert fh.read() == payload
@pytest.mark.asyncio
async def test_handle_download_file_action_with_file_name_empty() -> None:
    """An empty file_name makes the handler fall back to a generated UUID filename."""
    created_at = datetime.now(UTC)
    org = make_organization(created_at)
    task = make_task(created_at, org)
    step = make_step(created_at, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    page = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    payload = b"test content"
    action = DownloadFileAction(
        file_name="",  # Empty string, handler will use UUID
        byte=payload,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with tempfile.TemporaryDirectory() as download_dir:
        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=download_dir):
            results = await handle_download_file_action(action, page, scraped_page, task, step)

        assert len(results) == 1
        assert isinstance(results[0], ActionSuccess)

        # Exactly one file should exist; its generated name is unknown, so verify content only.
        entries = os.listdir(download_dir)
        assert len(entries) == 1
        with open(os.path.join(download_dir, entries[0]), "rb") as fh:
            assert fh.read() == payload
@pytest.mark.asyncio
async def test_handle_download_file_action_download_url_error() -> None:
    """A failing page.goto surfaces as ActionFailure carrying the exception details."""
    created_at = datetime.now(UTC)
    org = make_organization(created_at)
    task = make_task(created_at, org)
    step = make_step(created_at, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Navigation blows up with a generic (non-ERR_ABORTED) error.
    page = MagicMock()
    page.goto = AsyncMock(side_effect=Exception("Download failed"))
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    action = DownloadFileAction(
        file_name="test_file.txt",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        results = await handle_download_file_action(action, page, scraped_page, task, step)

    assert len(results) == 1
    failure = results[0]
    assert isinstance(failure, ActionFailure)
    assert failure.exception_type == "Exception"
    assert failure.exception_message == "Download failed"
@pytest.mark.asyncio
async def test_handle_download_file_action_file_write_error() -> None:
    """Test that when file write fails, should return ActionFailure.

    The download directory is made read-only (0o555) so writing the byte payload
    fails inside the handler. Two hardening fixes over the naive version:
    - skip when running as root, since root bypasses directory write permission
      checks and the write would succeed, making the test fail spuriously;
    - restore write permission afterwards so TemporaryDirectory cleanup cannot
      fail on platforms where removing the directory tree needs write access.
    """
    # Root ignores mode bits on directories; the simulated failure never happens.
    if hasattr(os, "geteuid") and os.geteuid() == 0:
        pytest.skip("running as root: directory write permissions are not enforced")
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)
    # Create mock objects
    page = MagicMock()
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    test_bytes = b"test content"
    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=test_bytes,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a read-only directory to simulate write failure
        read_only_dir = os.path.join(temp_dir, "readonly")
        os.makedirs(read_only_dir, mode=0o555)
        try:
            with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=read_only_dir):
                result = await handle_download_file_action(action, page, scraped_page, task, step)
        finally:
            # Restore write permission so TemporaryDirectory cleanup always succeeds.
            os.chmod(read_only_dir, 0o755)
        # Verify result should be ActionFailure
        assert len(result) == 1
        assert isinstance(result[0], ActionFailure)
@pytest.mark.asyncio
async def test_handle_download_file_action_download_url_err_aborted_swallowed() -> None:
    """net::ERR_ABORTED from page.goto (browser download flow) is swallowed; action succeeds."""
    created_at = datetime.now(UTC)
    org = make_organization(created_at)
    task = make_task(created_at, org)
    step = make_step(created_at, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Chromium aborts navigation when a download starts; the handler treats this as benign.
    page = MagicMock()
    page.goto = AsyncMock(side_effect=Exception("net::ERR_ABORTED at https://example.com/file.pdf"))
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )
    action = DownloadFileAction(
        file_name="test_file.txt",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        results = await handle_download_file_action(action, page, scraped_page, task, step)

    assert len(results) == 1
    assert isinstance(results[0], ActionSuccess)

View File

@@ -0,0 +1,183 @@
"""Tests for DAG validation when blocks reference the finally block.
The finally block is excluded from the DAG before validation. Any block whose
next_block_label points to the finally block must have that edge nullified so
_build_workflow_graph does not raise InvalidWorkflowDefinition for a missing label.
"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from skyvern.forge.sdk.workflow.exceptions import InvalidWorkflowDefinition
from skyvern.forge.sdk.workflow.models.block import (
BranchCondition,
ConditionalBlock,
HttpRequestBlock,
TaskBlock,
)
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
from skyvern.forge.sdk.workflow.service import WorkflowService
def _make_output_parameter(key: str) -> OutputParameter:
    """Build a minimal OutputParameter stub keyed by *key* for DAG tests."""
    timestamp = datetime.now(tz=timezone.utc)
    return OutputParameter(
        output_parameter_id=f"op_{key}",
        key=key,
        parameter_type="output",
        workflow_id="wf_test",
        created_at=timestamp,
        modified_at=timestamp,
    )
def _make_task_block(label: str, *, next_block_label: str | None = None) -> TaskBlock:
    """Build a TaskBlock with an auto-generated output parameter and optional next edge."""
    return TaskBlock(
        output_parameter=_make_output_parameter(label),
        label=label,
        url="https://example.com",
        next_block_label=next_block_label,
    )
def _make_http_block(label: str, *, next_block_label: str | None = None) -> HttpRequestBlock:
    """Build a GET HttpRequestBlock with an auto-generated output parameter."""
    return HttpRequestBlock(
        output_parameter=_make_output_parameter(label),
        label=label,
        url="https://example.com",
        method="GET",
        next_block_label=next_block_label,
    )
class TestStripFinallyBlockReferences:
    """Tests for WorkflowService._strip_finally_block_references."""

    def test_removes_finally_block_and_nullifies_edge(self):
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2", next_block_label="finally_block")
        finally_block = _make_http_block("finally_block")

        stripped = WorkflowService._strip_finally_block_references(
            [first, second, finally_block],
            "finally_block",
        )

        assert len(stripped) == 2
        assert "finally_block" not in [b.label for b in stripped]
        # block_2's edge into the finally block must be nullified.
        assert stripped[1].label == "block_2"
        assert stripped[1].next_block_label is None

    def test_conditional_branch_pointing_to_finally_is_nullified(self):
        plain = _make_task_block("block_1")
        conditional = ConditionalBlock(
            label="cond_block",
            output_parameter=_make_output_parameter("cond_block"),
            branch_conditions=[
                BranchCondition(next_block_label="block_1", is_default=True),
                BranchCondition(
                    next_block_label="finally_block",
                    criteria={"criteria_type": "jinja2_template", "expression": "{{ true }}"},
                ),
            ],
        )
        finally_block = _make_http_block("finally_block")

        stripped = WorkflowService._strip_finally_block_references(
            [plain, conditional, finally_block],
            "finally_block",
        )

        assert len(stripped) == 2
        conditional_after = next(b for b in stripped if b.label == "cond_block")
        assert all(
            branch.next_block_label != "finally_block" for branch in conditional_after.branch_conditions
        ), "Branch pointing to finally_block should have been nullified"

    def test_noop_when_no_finally_block(self):
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2")

        stripped = WorkflowService._strip_finally_block_references(
            [first, second],
            "nonexistent_finally",
        )

        # Nothing referenced the (absent) finally block, so the edges are untouched.
        assert len(stripped) == 2
        assert stripped[0].next_block_label == "block_2"
class TestBuildWorkflowGraphWithFinallyBlock:
    """Tests that _build_workflow_graph succeeds after stripping finally block references."""

    def test_dag_validation_with_block_pointing_to_finally_block(self):
        dag_blocks = WorkflowService._strip_finally_block_references(
            [
                _make_task_block("block_1", next_block_label="block_2"),
                _make_task_block("block_2", next_block_label="finally_block"),
                _make_http_block("finally_block"),
            ],
            "finally_block",
        )

        start_label, label_to_block, default_next_map = WorkflowService()._build_workflow_graph(dag_blocks)

        assert start_label == "block_1"
        assert set(label_to_block) == {"block_1", "block_2"}
        assert default_next_map["block_1"] == "block_2"
        assert default_next_map["block_2"] is None

    def test_dag_validation_with_conditional_block_branch_pointing_to_finally(self):
        conditional = ConditionalBlock(
            label="cond_block",
            output_parameter=_make_output_parameter("cond_block"),
            branch_conditions=[
                BranchCondition(next_block_label="block_1", is_default=True),
                BranchCondition(
                    next_block_label="finally_block",
                    criteria={"criteria_type": "jinja2_template", "expression": "{{ true }}"},
                ),
            ],
        )
        dag_blocks = WorkflowService._strip_finally_block_references(
            [conditional, _make_task_block("block_1"), _make_http_block("finally_block")],
            "finally_block",
        )

        start_label, label_to_block, default_next_map = WorkflowService()._build_workflow_graph(dag_blocks)

        assert start_label == "cond_block"
        assert set(label_to_block) == {"cond_block", "block_1"}

    def test_dag_validation_without_finally_block(self):
        blocks = [
            _make_task_block("block_1", next_block_label="block_2"),
            _make_task_block("block_2"),
        ]

        start_label, label_to_block, default_next_map = WorkflowService()._build_workflow_graph(blocks)

        assert start_label == "block_1"
        assert set(label_to_block) == {"block_1", "block_2"}
        assert default_next_map["block_1"] == "block_2"

    def test_dag_validation_fails_without_stripping_finally_block(self):
        """Without stripping, a block referencing the removed finally block causes an error."""
        # The finally block is excluded, but block_2's dangling edge is NOT nullified.
        dag_blocks = [
            _make_task_block("block_1", next_block_label="block_2"),
            _make_task_block("block_2", next_block_label="finally_block"),
        ]

        with pytest.raises(InvalidWorkflowDefinition, match="unknown next_block_label"):
            WorkflowService()._build_workflow_graph(dag_blocks)

View File

@@ -0,0 +1,412 @@
"""
Unit tests for ForLoop block support in cached scripts (SKY-7751).
These tests verify that ForLoop blocks are properly handled during:
1. Workflow transformation (transform_workflow_run.py)
2. Script generation (generate_script.py)
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import libcst as cst
import pytest
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_for_loop_statement
from skyvern.core.script_generations.transform_workflow_run import (
CodeGenInput,
transform_workflow_run_to_code_gen_input,
)
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
from skyvern.schemas.workflows import BlockType
class TestForLoopInCacheableBlocks:
    """Test that ForLoop is included in cacheable block types."""

    def test_forloop_in_block_types_that_should_be_cached(self) -> None:
        """ForLoop must appear in BLOCK_TYPES_THAT_SHOULD_BE_CACHED."""
        cacheable = BLOCK_TYPES_THAT_SHOULD_BE_CACHED
        assert BlockType.FOR_LOOP in cacheable
class TestForLoopTransformation:
    """Test the transformation of ForLoop blocks during script generation."""

    def test_forloop_child_blocks_identified_by_parent_id(self) -> None:
        """Child blocks inside a ForLoop are discoverable via parent_workflow_run_block_id."""
        parent = MagicMock()
        parent.workflow_run_block_id = "wfrb_forloop_123"
        parent.parent_workflow_run_block_id = None
        parent.block_type = BlockType.FOR_LOOP
        parent.label = "process_urls"
        parent.task_id = None

        child = MagicMock()
        child.workflow_run_block_id = "wfrb_child_456"
        child.parent_workflow_run_block_id = "wfrb_forloop_123"  # points at the ForLoop
        child.block_type = "task"
        child.label = "extract_data"
        child.task_id = "task_789"
        child.status = "completed"
        child.output = {"extracted": "data"}

        # Same filtering the transformation applies.
        children = [
            block for block in (parent, child) if block.parent_workflow_run_block_id == parent.workflow_run_block_id
        ]

        assert len(children) == 1
        assert children[0].label == "extract_data"
        assert children[0].task_id == "task_789"

    def test_child_run_blocks_by_label_mapping(self) -> None:
        """Child run blocks are keyed by label for lookup during generation."""
        extract = MagicMock()
        extract.label = "extract_data"
        extract.block_type = "extraction"
        extract.task_id = "task_1"

        navigate = MagicMock()
        navigate.label = "navigate_page"
        navigate.block_type = "navigation"
        navigate.task_id = "task_2"

        by_label = {block.label: block for block in (extract, navigate) if block.label}

        assert "extract_data" in by_label
        assert "navigate_page" in by_label
        assert by_label["extract_data"].task_id == "task_1"

    def test_forloop_definition_block_has_loop_blocks(self) -> None:
        """A ForLoop definition dict carries its children under loop_blocks."""
        definition = {
            "block_type": BlockType.FOR_LOOP,
            "label": "process_urls",
            "loop_variable_reference": "{{ urls }}",
            "loop_blocks": [
                {
                    "block_type": "extraction",
                    "label": "extract_data",
                    "data_extraction_goal": "Extract page content",
                },
                {
                    "block_type": "navigation",
                    "label": "navigate_next",
                    "navigation_goal": "Go to next page",
                },
            ],
        }

        nested = definition.get("loop_blocks", [])
        assert [entry["label"] for entry in nested] == ["extract_data", "navigate_next"]
class TestForLoopScriptGeneration:
    """Test script code generation for ForLoop blocks."""

    def test_build_for_loop_statement_signature(self) -> None:
        """_build_for_loop_statement returns a libcst For node for a minimal ForLoop dict.

        The previous assertions used hasattr() for ``target``/``iter``/``body``,
        which every ``cst.For`` node satisfies by construction; asserting the
        concrete node type pins the generated shape directly.
        """
        forloop_block = {
            "block_type": "for_loop",
            "label": "process_items",
            "loop_variable_reference": "{{ items }}",
            "loop_blocks": [
                {
                    "block_type": "extraction",
                    "label": "extract_item",
                    "data_extraction_goal": "Extract item details",
                },
            ],
        }
        # This should not raise an error
        result = _build_for_loop_statement("process_items", forloop_block)
        assert result is not None
        # A cst.For node carries target/iter/body by definition.
        assert isinstance(result, cst.For)
class TestForLoopChildBlockActions:
    """Test that actions from child blocks inside ForLoop are collected."""

    def test_task_block_in_forloop_should_collect_actions(self) -> None:
        """Task-type children carry a task_id and are eligible for action collection."""
        child = MagicMock()
        child.block_type = "task"
        child.task_id = "task_123"
        child.label = "search_item"

        # Action collection requires both a SCRIPT_TASK_BLOCKS type and a task_id.
        assert child.block_type in SCRIPT_TASK_BLOCKS
        assert child.task_id is not None

    def test_non_task_block_in_forloop_does_not_collect_actions(self) -> None:
        """Non-task children (e.g. goto_url) are skipped by action collection."""
        child = MagicMock()
        child.block_type = "goto_url"
        child.task_id = None
        child.label = "go_to_url"

        assert child.block_type not in SCRIPT_TASK_BLOCKS
class TestForLoopActionsHydration:
    """Test that actions from ForLoop child blocks are properly hydrated."""

    def test_actions_by_task_includes_forloop_child_actions(self) -> None:
        """actions_by_task must carry the actions recorded for a ForLoop child task."""
        nested_task_id = "task_in_forloop_123"
        recorded = [
            {
                "action_type": "input_text",
                "action_id": "action_1",
                "text": "search query",
                "xpath": "//input[@id='search']",
            },
            {
                "action_type": "click",
                "action_id": "action_2",
                "xpath": "//button[@type='submit']",
            },
        ]
        actions_by_task: dict[str, list[dict[str, Any]]] = {nested_task_id: recorded}

        assert nested_task_id in actions_by_task
        assert len(actions_by_task[nested_task_id]) == 2
        assert actions_by_task[nested_task_id][0]["action_type"] == "input_text"
@pytest.mark.asyncio
async def test_transform_forloop_block_integration() -> None:
    """
    Integration test for ForLoop block transformation.
    Builds a CodeGenInput containing a ForLoop block and verifies the expected
    structure: the loop definition, its child block's task linkage, and the
    actions recorded for that child task.
    """
    child_definition = {
        "block_type": "extraction",
        "label": "extract_data",
        "data_extraction_goal": "Get page content",
        "task_id": "task_789",
        "status": "completed",
        "output": {"content": "extracted data"},
    }
    code_gen_input = CodeGenInput(
        file_name="test_workflow.py",
        workflow_run={"workflow_id": "wpid_123"},
        workflow={"workflow_definition": {"blocks": []}},
        workflow_blocks=[
            {
                "block_type": "for_loop",
                "label": "process_urls",
                "loop_variable_reference": "{{ urls }}",
                "workflow_run_id": "wr_123",
                "workflow_run_block_id": "wfrb_456",
                "loop_blocks": [child_definition],
            }
        ],
        actions_by_task={
            "task_789": [
                {
                    "action_type": "extract",
                    "action_id": "action_123",
                    "xpath": "//div[@id='content']",
                }
            ]
        },
        task_v2_child_blocks={},
    )

    assert len(code_gen_input.workflow_blocks) == 1
    forloop = code_gen_input.workflow_blocks[0]
    assert forloop["block_type"] == "for_loop"
    assert len(forloop["loop_blocks"]) == 1
    assert forloop["loop_blocks"][0]["task_id"] == "task_789"
    assert "task_789" in code_gen_input.actions_by_task
@pytest.mark.asyncio
async def test_transform_forloop_block_with_mocked_db() -> None:
    """
    Full integration test for ForLoop block transformation with mocked database.

    This test verifies the actual transformation logic in transform_workflow_run.py
    correctly processes ForLoop blocks and their child blocks. All database and
    workflow-service access is mocked; `transform_workflow_run_to_code_gen_input`
    is the only real code under test.
    """
    # Mock workflow run response
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test_123"
    # model_dump supplies the serialized run payload the transformer embeds.
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test_123", "parameters": {}}
    )
    # Mock workflow with ForLoop block definition
    mock_forloop_definition = MagicMock()
    mock_forloop_definition.block_type = BlockType.FOR_LOOP
    mock_forloop_definition.label = "process_urls"
    mock_forloop_definition.loop_variable_reference = "{{ urls }}"
    # The serialized ForLoop definition includes one extraction child in loop_blocks.
    mock_forloop_definition.model_dump = MagicMock(
        return_value={
            "block_type": "for_loop",
            "label": "process_urls",
            "loop_variable_reference": "{{ urls }}",
            "loop_blocks": [
                {
                    "block_type": "extraction",
                    "label": "extract_data",
                    "data_extraction_goal": "Get page content",
                }
            ],
        }
    )
    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_123", "workflow_definition": {"blocks": []}})
    mock_workflow.workflow_definition.blocks = [mock_forloop_definition]
    # Mock workflow run blocks - ForLoop parent and extraction child
    # (child is linked to the parent via parent_workflow_run_block_id).
    mock_forloop_run_block = MagicMock()
    mock_forloop_run_block.workflow_run_block_id = "wfrb_forloop_123"
    mock_forloop_run_block.parent_workflow_run_block_id = None
    mock_forloop_run_block.block_type = BlockType.FOR_LOOP
    mock_forloop_run_block.label = "process_urls"
    mock_forloop_run_block.task_id = None
    mock_forloop_run_block.created_at = 1
    mock_child_run_block = MagicMock()
    mock_child_run_block.workflow_run_block_id = "wfrb_child_456"
    mock_child_run_block.parent_workflow_run_block_id = "wfrb_forloop_123"
    mock_child_run_block.block_type = "extraction"
    mock_child_run_block.label = "extract_data"
    mock_child_run_block.task_id = "task_extraction_789"
    mock_child_run_block.status = "completed"
    mock_child_run_block.output = {"extracted": "data"}
    # created_at orders blocks during transformation (parent=1, child=2).
    mock_child_run_block.created_at = 2
    # Mock task for the child block
    mock_task = MagicMock()
    mock_task.model_dump = MagicMock(
        return_value={
            "task_id": "task_extraction_789",
            "navigation_goal": "Extract page content",
        }
    )
    # Mock action for the task
    mock_action = MagicMock()
    mock_action.action_type = "extract"
    mock_action.model_dump = MagicMock(
        return_value={
            "action_type": "extract",
            "action_id": "action_123",
        }
    )
    mock_action.get_xpath = MagicMock(return_value="//div[@id='content']")
    mock_action.has_mini_agent = False
    # Set up patches
    # NOTE: get_workflow_run_response is patched at its defining module
    # (skyvern.services.workflow_service), while `app` is patched where the
    # transformer imports it (transform_workflow_run) — both targets matter.
    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(
            return_value=[
                mock_forloop_run_block,
                mock_child_run_block,
            ]
        )
        # B1 optimization: Mock batch methods instead of individual queries
        mock_task.task_id = "task_extraction_789"
        mock_action.task_id = "task_extraction_789"
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[mock_action])
        # Call the transformation
        result = await transform_workflow_run_to_code_gen_input(
            workflow_run_id="wr_test_123",
            organization_id="org_test_123",
        )
    # Verify ForLoop block is included
    assert len(result.workflow_blocks) == 1
    forloop_block = result.workflow_blocks[0]
    assert forloop_block["block_type"] == "for_loop"
    assert forloop_block["label"] == "process_urls"
    # Verify loop_blocks contain child block with task data
    loop_blocks = forloop_block.get("loop_blocks", [])
    assert len(loop_blocks) == 1
    child_block = loop_blocks[0]
    assert child_block["label"] == "extract_data"
    assert child_block.get("task_id") == "task_extraction_789"
    # Verify actions were collected for the child task
    assert "task_extraction_789" in result.actions_by_task
    actions = result.actions_by_task["task_extraction_789"]
    assert len(actions) == 1
    assert actions[0]["action_type"] == "extract"
class TestForLoopScriptExecution:
    """Test that generated ForLoop scripts can be executed."""

    def test_forloop_generates_async_for_statement(self) -> None:
        """The emitted loop is an `async for` statement."""
        block_definition = {
            "block_type": "for_loop",
            "label": "iterate_items",
            "loop_variable_reference": "{{ items_list }}",
            "complete_if_empty": True,
            "loop_blocks": [],
        }

        statement = _build_for_loop_statement("iterate_items", block_definition)

        assert isinstance(statement, cst.For)
        # `asynchronous` is populated only for `async for` loops.
        assert statement.asynchronous is not None

    def test_forloop_generates_skyvern_loop_call(self) -> None:
        """The loop iterates the result of a `.loop(...)` call."""
        block_definition = {
            "block_type": "for_loop",
            "label": "iterate_items",
            "loop_variable_reference": "{{ items_list }}",
            "loop_blocks": [],
        }

        statement = _build_for_loop_statement("iterate_items", block_definition)

        iterator = statement.iter
        assert isinstance(iterator, cst.Call)
        callee = iterator.func
        assert isinstance(callee, cst.Attribute)
        assert callee.attr.value == "loop"

View File

@@ -0,0 +1,478 @@
import json
from datetime import datetime, timezone
from enum import StrEnum
from unittest.mock import MagicMock
import pytest
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import FailedToFormatJinjaStyleParameter
from skyvern.forge.sdk.workflow.models.block import (
_JSON_TYPE_MARKER,
HttpRequestBlock,
_json_type_filter,
jinja_sandbox_env,
)
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
class TestJsonTypeFilter:
    """Unit tests for _json_type_filter's marker-wrapped JSON serialization."""

    @staticmethod
    def _payload(wrapped: str) -> object:
        """Decode the JSON text between the leading and trailing type markers."""
        return json.loads(wrapped[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)])

    @pytest.mark.parametrize(
        "value",
        [
            True,
            False,
            42,
            19.99,
            None,
            [1, 2, 3],
            {"a": 1, "b": "hello"},
            "hello",
            [],
            {},
        ],
    )
    def test_filter_wraps_with_marker(self, value: object) -> None:
        wrapped = _json_type_filter(value)
        assert wrapped.startswith(_JSON_TYPE_MARKER)
        assert wrapped.endswith(_JSON_TYPE_MARKER)

    @pytest.mark.parametrize(
        "value",
        [
            True,
            False,
            42,
            19.99,
            None,
            [1, 2, 3],
            {"a": 1, "b": "hello"},
            "hello",
        ],
    )
    def test_filter_json_is_parseable(self, value: object) -> None:
        assert self._payload(_json_type_filter(value)) == value

    def test_filter_handles_datetime(self) -> None:
        # Expected string matches str(datetime) formatting of the input.
        moment = datetime(2024, 1, 15, 12, 30, 45)
        assert self._payload(_json_type_filter(moment)) == "2024-01-15 12:30:45"

    def test_filter_handles_enum(self) -> None:
        class Status(StrEnum):
            completed = "completed"
            failed = "failed"

        assert self._payload(_json_type_filter(Status.completed)) == "completed"

    def test_filter_handles_nested_datetime_in_dict(self) -> None:
        record = {
            "status": "completed",
            "downloaded_files": [
                {"url": "https://example.com/file.pdf", "modified_at": datetime(2024, 1, 15, 12, 30, 45)}
            ],
        }
        decoded = self._payload(_json_type_filter(record))
        assert decoded["downloaded_files"][0]["modified_at"] == "2024-01-15 12:30:45"
class TestJinjaJsonFilterRegistration:
    """The sandboxed Jinja environment must expose _json_type_filter as `json`."""

    @staticmethod
    def _payload(rendered: str) -> object:
        """Decode the JSON text between the type markers of a rendered template."""
        return json.loads(rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)])

    def test_json_filter_is_registered(self) -> None:
        assert "json" in jinja_sandbox_env.filters
        assert jinja_sandbox_env.filters["json"] == _json_type_filter

    @pytest.mark.parametrize(
        "template,context,expected_json",
        [
            ("{{ flag | json }}", {"flag": True}, True),
            ("{{ flag | json }}", {"flag": False}, False),
            ("{{ count | json }}", {"count": 42}, 42),
            ("{{ price | json }}", {"price": 19.99}, 19.99),
            ("{{ null_val | json }}", {"null_val": None}, None),
            ("{{ items | json }}", {"items": [1, 2, 3]}, [1, 2, 3]),
            ("{{ data | json }}", {"data": {"a": 1}}, {"a": 1}),
            ("{{ str_val | json }}", {"str_val": "hello"}, "hello"),
        ],
    )
    def test_jinja_renders_json_filter(self, template: str, context: dict, expected_json: object) -> None:
        rendered = jinja_sandbox_env.from_string(template).render(context)
        # Output must be wrapped in markers on both ends...
        assert rendered.startswith(_JSON_TYPE_MARKER)
        assert rendered.endswith(_JSON_TYPE_MARKER)
        # ...and the payload must round-trip through JSON.
        assert self._payload(rendered) == expected_json

    def test_json_filter_with_nested_access(self) -> None:
        rendered = jinja_sandbox_env.from_string("{{ data.nested.value | json }}").render(
            {"data": {"nested": {"value": True}}}
        )
        assert self._payload(rendered) is True

    def test_json_filter_with_list_index(self) -> None:
        rendered = jinja_sandbox_env.from_string("{{ items[1] | json }}").render({"items": [10, 20, 30]})
        assert self._payload(rendered) == 20

    def test_json_filter_chains_with_default(self) -> None:
        # missing_val is undefined, so default(false) supplies the value before json.
        rendered = jinja_sandbox_env.from_string("{{ missing_val | default(false) | json }}").render({})
        assert self._payload(rendered) is False

    def test_json_filter_chains_with_default_list(self) -> None:
        rendered = jinja_sandbox_env.from_string("{{ items | default([]) | json }}").render({})
        assert self._payload(rendered) == []
class TestMarkerDetection:
    """Detection of marker-wrapped JSON output (mirrors _render_templates_in_json)."""

    def test_marker_detection_simple(self) -> None:
        wrapped = _json_type_filter(True)
        assert wrapped.startswith(_JSON_TYPE_MARKER)
        assert wrapped.endswith(_JSON_TYPE_MARKER)
        # Simulate the detection logic: strip markers, parse the remainder.
        inner = wrapped[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
        assert json.loads(inner) is True

    def test_marker_detection_complex_object(self) -> None:
        payload = {"users": [{"name": "Alice", "active": True}], "count": 1}
        wrapped = _json_type_filter(payload)
        inner = wrapped[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
        assert json.loads(inner) == payload

    def test_plain_string_not_detected_as_marker(self) -> None:
        text = "hello world"
        assert not text.startswith(_JSON_TYPE_MARKER)
        assert not text.endswith(_JSON_TYPE_MARKER)

    def test_partial_marker_not_detected(self) -> None:
        # A marker on only one side must not trigger detection (startswith AND endswith).
        for candidate in (f"{_JSON_TYPE_MARKER}true", f"true{_JSON_TYPE_MARKER}"):
            assert not (candidate.startswith(_JSON_TYPE_MARKER) and candidate.endswith(_JSON_TYPE_MARKER))
class TestWithoutJsonFilter:
    """Without | json, rendering falls back to plain Python str() output."""

    def test_standard_template_renders_string(self) -> None:
        result = jinja_sandbox_env.from_string("{{ flag }}").render({"flag": True})
        assert result == "True"  # Python str(True), not JSON "true"
        assert not result.startswith(_JSON_TYPE_MARKER)

    def test_standard_template_integer(self) -> None:
        result = jinja_sandbox_env.from_string("{{ count }}").render({"count": 42})
        assert result == "42"
        assert not result.startswith(_JSON_TYPE_MARKER)
class TestEdgeCasesAndLimitations:
    """Edge cases for the | json filter and its marker protocol."""

    @staticmethod
    def _decode(rendered: str) -> object:
        """Strip the surrounding type markers and parse the JSON payload."""
        return json.loads(rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)])

    def test_mixed_template_jinja_output_contains_marker(self) -> None:
        """Jinja renders mixed templates with markers embedded in output.

        This verifies what Jinja produces. The actual error handling happens
        in _render_templates_in_json (tested separately below).
        """
        rendered = jinja_sandbox_env.from_string("prefix_{{ flag | json }}_suffix").render({"flag": True})
        # The marker is present but buried inside the surrounding text.
        assert _JSON_TYPE_MARKER in rendered
        assert rendered.startswith("prefix_")
        assert rendered.endswith("_suffix")
        # So the startswith-AND-endswith detection correctly rejects it.
        assert not (rendered.startswith(_JSON_TYPE_MARKER) and rendered.endswith(_JSON_TYPE_MARKER))

    def test_deeply_nested_structure(self) -> None:
        data = {"level1": {"level2": {"level3": {"items": [1, 2, 3], "active": True}}}}
        rendered = jinja_sandbox_env.from_string("{{ data | json }}").render({"data": data})
        decoded = self._decode(rendered)
        assert decoded == data
        assert decoded["level1"]["level2"]["level3"]["active"] is True

    def test_special_characters_in_string_value(self) -> None:
        text = 'Hello "World"\nNew line\ttab'
        rendered = jinja_sandbox_env.from_string("{{ text | json }}").render({"text": text})
        assert self._decode(rendered) == text

    def test_unicode_characters(self) -> None:
        text = "Hello \u4e16\u754c \U0001f600"  # Chinese + emoji
        rendered = jinja_sandbox_env.from_string("{{ text | json }}").render({"text": text})
        assert self._decode(rendered) == text

    def test_empty_values(self) -> None:
        template = jinja_sandbox_env.from_string("{{ text | json }}")
        # Empty string, empty list, and empty dict all survive the round trip.
        for empty_value in ("", [], {}):
            rendered = template.render({"text": empty_value})
            assert self._decode(rendered) == empty_value
class TestEmbeddedMarkerErrorHandling:
    """Tests that embedded markers (| json mixed with other text) raise clear errors."""

    @staticmethod
    def _make_block(body: dict) -> HttpRequestBlock:
        """Build an HttpRequestBlock posting *body* with a throwaway output parameter."""
        timestamp = datetime.now(timezone.utc)
        return HttpRequestBlock(
            label="test-block",
            url="https://example.com",
            method="POST",
            body=body,
            output_parameter=OutputParameter(
                parameter_type=ParameterType.OUTPUT,
                key="http_output",
                description=None,
                output_parameter_id="output-1",
                workflow_id="workflow-1",
                created_at=timestamp,
                modified_at=timestamp,
                deleted_at=None,
            ),
        )

    @staticmethod
    def _make_context(values: dict) -> MagicMock:
        """Build a minimal workflow-run-context mock exposing *values* and no secrets."""
        ctx = MagicMock()
        ctx.values = values
        ctx.secrets = {}
        ctx.include_secrets_in_templates = False
        ctx.get_block_metadata = MagicMock(return_value={})
        return ctx

    def test_embedded_marker_raises_error(self) -> None:
        """Using | json with prefix/suffix text should raise FailedToFormatJinjaStyleParameter."""
        block = self._make_block({"id": "prefix-{{ val | json }}"})
        with pytest.raises(FailedToFormatJinjaStyleParameter) as exc_info:
            block.format_potential_template_parameters(self._make_context({"val": 123}))
        assert "can only be used for complete value replacement" in str(exc_info.value)

    def test_valid_json_filter_does_not_raise(self) -> None:
        """Using | json for complete value replacement should work without error."""
        block = self._make_block({"enabled": "{{ flag | json }}", "count": "{{ num | json }}"})
        # Should not raise.
        block.format_potential_template_parameters(self._make_context({"flag": True, "num": 42}))
        # Values were converted to native JSON types, not strings.
        assert block.body == {"enabled": True, "count": 42}
class TestWorkflowRunSummary:
    """Tests for the workflow_run_summary template variable."""

    @staticmethod
    def _summarize(run_id: str, outputs: dict) -> dict:
        """Run build_workflow_run_summary against a mock context holding *outputs*."""
        ctx = MagicMock()
        ctx.workflow_run_id = run_id
        ctx.workflow_run_outputs = outputs
        return WorkflowRunContext.build_workflow_run_summary(ctx)

    def test_build_workflow_run_summary_empty_outputs(self) -> None:
        """Test summary with no block outputs."""
        summary = self._summarize("wr_123", {})
        assert summary["workflow_run_id"] == "wr_123"
        assert summary["status"] is None
        assert summary["output"] == {"extracted_information": {}}
        assert summary["downloaded_files"] == []
        assert summary["errors"] == []
        assert summary["failure_reason"] is None

    def test_build_workflow_run_summary_merges_extracted_information(self) -> None:
        """Test that output.extracted_information is merged from all blocks."""
        summary = self._summarize(
            "wr_456",
            {
                "NavigationBlock": {
                    "status": "completed",
                    "extracted_information": {"title": "Example Page"},
                    "errors": [],
                    "downloaded_files": [],
                },
                "ExtractionBlock": {
                    "status": "completed",
                    "extracted_information": {"documents": [{"name": "doc1.pdf"}]},
                    "errors": [],
                    "downloaded_files": [],
                },
            },
        )
        # Merged flat across blocks, not keyed by block label.
        assert summary["output"]["extracted_information"] == {
            "title": "Example Page",
            "documents": [{"name": "doc1.pdf"}],
        }

    def test_build_workflow_run_summary_aggregates_downloaded_files(self) -> None:
        """Test that downloaded_files are aggregated from all blocks."""
        summary = self._summarize(
            "wr_789",
            {
                "Block1": {
                    "status": "completed",
                    "downloaded_files": [{"url": "file1.pdf"}],
                    "errors": [],
                },
                "Block2": {
                    "status": "completed",
                    "downloaded_files": [{"url": "file2.pdf"}, {"url": "file3.pdf"}],
                    "errors": [],
                },
            },
        )
        collected = summary["downloaded_files"]
        assert len(collected) == 3
        for name in ("file1.pdf", "file2.pdf", "file3.pdf"):
            assert {"url": name} in collected

    def test_build_workflow_run_summary_aggregates_errors(self) -> None:
        """Test that errors are aggregated from all blocks."""
        summary = self._summarize(
            "wr_errors",
            {
                "Block1": {
                    "status": "failed",
                    "errors": [{"message": "Error 1"}],
                    "failure_reason": "Block 1 failed",
                },
                "Block2": {
                    "status": "completed",
                    "errors": [{"message": "Warning"}],
                },
            },
        )
        assert len(summary["errors"]) == 2
        assert {"message": "Error 1"} in summary["errors"]
        assert {"message": "Warning"} in summary["errors"]
        assert summary["failure_reason"] == "Block 1 failed"

    def test_build_workflow_run_summary_uses_last_status(self) -> None:
        """Test that the last block's status is used."""
        summary = self._summarize(
            "wr_status",
            {
                "Block1": {"status": "completed", "errors": []},
                "Block2": {"status": "failed", "errors": []},
                "Block3": {"status": "completed", "errors": []},
            },
        )
        # Last block wins, even if an earlier one failed.
        assert summary["status"] == "completed"

    def test_status_converted_to_string_in_summary(self) -> None:
        """Test that TaskStatus enum is converted to string in summary."""
        summary = self._summarize(
            "wr_enum",
            {
                "Block1": {"status": TaskStatus.completed, "errors": []},
                "Block2": {"status": TaskStatus.timed_out, "errors": []},
            },
        )
        assert summary["status"] == "timed_out"
        assert type(summary["status"]) is str  # Not TaskStatus enum

    def test_workflow_run_summary_with_json_filter(self) -> None:
        """Test that workflow_run_summary works with | json filter in templates."""
        summary_value = {
            "workflow_run_id": "wr_template",
            "status": "completed",
            "output": {"extracted_information": {"documents": [{"name": "doc1.pdf"}]}},
            "downloaded_files": [{"url": "file.pdf"}],
            "errors": [],
            "failure_reason": None,
        }
        rendered = jinja_sandbox_env.from_string("{{ workflow_run_summary | json }}").render(
            {"workflow_run_summary": summary_value}
        )
        parsed = json.loads(rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)])
        assert parsed["workflow_run_id"] == "wr_template"
        assert parsed["status"] == "completed"
        assert parsed["output"]["extracted_information"]["documents"] == [{"name": "doc1.pdf"}]
        assert parsed["downloaded_files"] == [{"url": "file.pdf"}]
        assert parsed["errors"] == []
        assert parsed["failure_reason"] is None

View File

@@ -0,0 +1,9 @@
import pytest
from skyvern.forge.sdk.db import id as id_module
def test_generate_id_uniqueness_with_overflow() -> None:
    """generate_id must not repeat across many rapid calls.

    NOTE(review): the original signature requested a ``monkeypatch`` fixture
    that was never used — presumably intended to force a sequence-counter
    overflow, as the test name suggests. The dead fixture has been dropped;
    if overflow behaviour should actually be exercised, patch the counter in
    ``id_module`` here — TODO confirm with the module's author.
    """
    total_ids = 10000
    generated_ids = [id_module.generate_id() for _ in range(total_ids)]
    # Collapsing into a set reveals any duplicate ID.
    assert len(set(generated_ids)) == total_ids

View File

@@ -0,0 +1,22 @@
import pytest
from skyvern.forge.sdk.api.llm.utils import _coerce_response_to_dict
@pytest.mark.parametrize(
    ("response", "expected"),
    [
        ({"page_info": "Select country"}, ({"page_info": "Select country"}, False)),
        ([{"page_info": "First"}, {"page_info": "Second"}], ({"page_info": "First"}, False)),
        (["text", {"page_info": "First dict"}], ({"page_info": "First dict"}, False)),
        ([1, 2, 3], ({}, True)),
        ("not-a-dict", ({}, True)),
        ([], ({}, True)),
    ],
)
def test_coerce_response_to_dict_variants(response, expected):
    """Each variant either coerces to the expected dict or raises.

    ``expected`` is ``(value, may_raise)``. The original wrapped the whole
    body in ``try/except Exception``, which also caught the success-path
    ``AssertionError`` and re-attributed every failure to the wrong assert;
    the explicit branching below preserves the same pass/fail outcomes while
    keeping failure messages pointed at the real cause.
    """
    expected_value, may_raise = expected
    if may_raise:
        # These inputs are allowed to raise; if they return instead, the
        # result must still equal the documented fallback value.
        try:
            parsed = _coerce_response_to_dict(response)
        except Exception:
            return
        assert parsed == expected_value
    else:
        assert _coerce_response_to_dict(response) == expected_value

View File

@@ -0,0 +1,62 @@
"""Tests for MCP block tools (skyvern_block_schema, skyvern_block_validate)."""
from __future__ import annotations
import json
import pytest
from skyvern.cli.mcp_tools.blocks import skyvern_block_schema, skyvern_block_validate
@pytest.mark.asyncio
async def test_block_schema_task_redirects_to_navigation() -> None:
    """Requesting schema for 'task' should return navigation info with a deprecation warning."""
    result = await skyvern_block_schema(block_type="task")
    assert result["ok"] is True
    data = result["data"]
    assert data["block_type"] == "navigation"
    # The redirected schema must expose the navigation goal property.
    assert "navigation_goal" in data["schema"].get("properties", {})
    warnings = result["warnings"]
    assert len(warnings) > 0
    assert any("deprecated" in warning.lower() for warning in warnings)
@pytest.mark.asyncio
async def test_block_schema_unknown_type_returns_error() -> None:
    """Requesting schema for a nonexistent type should return an error with available types."""
    result = await skyvern_block_schema(block_type="invalid_xyz")
    assert result["ok"] is False
    error = result["error"]
    assert error is not None
    # The message echoes the bad type; the hint lists real alternatives.
    assert "invalid_xyz" in error["message"]
    assert "navigation" in error["hint"]
@pytest.mark.asyncio
async def test_block_validate_task_type_warns_deprecated() -> None:
    """Validating a 'task' block should succeed with a deprecation warning."""
    payload = json.dumps(
        {
            "block_type": "task",
            "label": "test",
            "url": "https://example.com",
            "navigation_goal": "do something",
        }
    )
    result = await skyvern_block_validate(block_json=payload)
    assert result["ok"] is True
    assert result["data"]["valid"] is True
    warnings = result["warnings"]
    assert len(warnings) > 0
    assert any("deprecated" in warning.lower() for warning in warnings)
@pytest.mark.asyncio
async def test_block_schema_no_type_lists_all() -> None:
    """Calling without a block_type should list all available types."""
    result = await skyvern_block_schema(block_type=None)
    assert result["ok"] is True
    listed = result["data"]["block_types"]
    # Real types are advertised; the deprecated 'task' alias is not.
    for expected in ("navigation", "extraction"):
        assert expected in listed
    assert "task" not in listed
    assert result["data"]["count"] > 0

View File

@@ -0,0 +1,270 @@
"""Tests for multi-field TOTP support in script generation."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.core.script_generations.generate_script import _annotate_multi_field_totp_sequence
from skyvern.core.script_generations.script_skyvern_page import ScriptSkyvernPage
from skyvern.webeye.actions.action_types import ActionType
class TestAnnotateMultiFieldTotpSequence:
    """Tests for _annotate_multi_field_totp_sequence function."""

    @staticmethod
    def _digit(field_name: str, text: str) -> dict:
        """Build an INPUT_TEXT action dict typing *text* into *field_name*."""
        return {"action_type": ActionType.INPUT_TEXT, "field_name": field_name, "text": text}

    @staticmethod
    def _assert_no_annotation(actions: list) -> None:
        """Assert that no action in *actions* carries totp_timing_info."""
        for action in actions:
            assert "totp_timing_info" not in action

    def test_empty_actions(self) -> None:
        """Empty action list returns unchanged."""
        assert _annotate_multi_field_totp_sequence([]) == []

    def test_less_than_4_actions_returns_unchanged(self) -> None:
        """Actions with fewer than 4 items return unchanged (minimum for TOTP)."""
        actions = [self._digit("totp", ch) for ch in "123"]
        self._assert_no_annotation(_annotate_multi_field_totp_sequence(actions))

    def test_4_digit_sequence_gets_annotated(self) -> None:
        """4 consecutive single-digit inputs with same field_name get annotated."""
        actions = [self._digit("totp_code", ch) for ch in "1234"]
        annotated = _annotate_multi_field_totp_sequence(actions)
        for position, action in enumerate(annotated):
            assert "totp_timing_info" in action
            info = action["totp_timing_info"]
            assert info["is_totp_sequence"] is True
            assert info["action_index"] == position
            assert info["total_digits"] == 4
            assert info["field_name"] == "totp_code"

    def test_6_digit_sequence_gets_annotated(self) -> None:
        """Standard 6-digit TOTP sequence gets properly annotated."""
        annotated = _annotate_multi_field_totp_sequence([self._digit("otp", ch) for ch in "123456"])
        for position, action in enumerate(annotated):
            assert action["totp_timing_info"]["action_index"] == position
            assert action["totp_timing_info"]["total_digits"] == 6

    def test_8_digit_sequence_gets_annotated(self) -> None:
        """8-digit sequence (some TOTP implementations) gets annotated."""
        annotated = _annotate_multi_field_totp_sequence([self._digit("code", str(i)) for i in range(8)])
        assert all("totp_timing_info" in action for action in annotated)
        assert annotated[0]["totp_timing_info"]["total_digits"] == 8
        assert annotated[7]["totp_timing_info"]["action_index"] == 7

    def test_3_digits_not_annotated(self) -> None:
        """3 consecutive digits should NOT be annotated (minimum is 4)."""
        actions = [self._digit("code", ch) for ch in "123"]
        actions.append({"action_type": ActionType.CLICK, "element_id": "submit"})
        self._assert_no_annotation(_annotate_multi_field_totp_sequence(actions))

    def test_different_field_names_not_grouped(self) -> None:
        """Actions with different field_names should not be grouped together."""
        actions = [
            self._digit("totp1", "1"),
            self._digit("totp1", "2"),
            self._digit("totp2", "3"),
            self._digit("totp2", "4"),
        ]
        # Neither run reaches the 4-digit minimum on its own.
        self._assert_no_annotation(_annotate_multi_field_totp_sequence(actions))

    def test_mixed_actions_with_totp_sequence(self) -> None:
        """TOTP sequence surrounded by non-TOTP actions still gets annotated."""
        actions = (
            [{"action_type": ActionType.CLICK, "element_id": "show_totp"}]
            + [self._digit("totp", ch) for ch in "123456"]
            + [{"action_type": ActionType.CLICK, "element_id": "submit"}]
        )
        annotated = _annotate_multi_field_totp_sequence(actions)
        # The clicks on either side remain untouched.
        assert "totp_timing_info" not in annotated[0]
        assert "totp_timing_info" not in annotated[7]
        # The six digits in between are indexed relative to the sequence start.
        for position in range(1, 7):
            info = annotated[position]["totp_timing_info"]
            assert info["action_index"] == position - 1
            assert info["total_digits"] == 6

    def test_multiple_sequences_in_action_list(self) -> None:
        """Multiple separate TOTP sequences in same action list get annotated separately."""
        actions = (
            [self._digit("code1", ch) for ch in "1234"]  # first sequence: 4 digits
            + [{"action_type": ActionType.CLICK, "element_id": "next"}]  # breaks the run
            + [self._digit("code2", ch) for ch in "567890"]  # second sequence: 6 digits
        )
        annotated = _annotate_multi_field_totp_sequence(actions)
        # First sequence (indices 0-3).
        for position in range(4):
            info = annotated[position]["totp_timing_info"]
            assert info["total_digits"] == 4
            assert info["field_name"] == "code1"
        # Click action (index 4) stays untouched.
        assert "totp_timing_info" not in annotated[4]
        # Second sequence (indices 5-10), re-indexed from zero.
        for position in range(5, 11):
            info = annotated[position]["totp_timing_info"]
            assert info["total_digits"] == 6
            assert info["field_name"] == "code2"
            assert info["action_index"] == position - 5

    def test_non_digit_text_not_annotated(self) -> None:
        """Actions with non-digit text should not be considered TOTP."""
        actions = [self._digit("field", ch) for ch in "abcd"]
        self._assert_no_annotation(_annotate_multi_field_totp_sequence(actions))

    def test_multi_digit_text_not_annotated(self) -> None:
        """Actions with multi-digit text should not be considered multi-field TOTP."""
        actions = [self._digit("code", pair) for pair in ("12", "34", "56", "78")]
        self._assert_no_annotation(_annotate_multi_field_totp_sequence(actions))

    def test_missing_field_name_not_annotated(self) -> None:
        """Actions without field_name should not be considered TOTP."""
        actions = [{"action_type": ActionType.INPUT_TEXT, "text": ch} for ch in "1234"]
        self._assert_no_annotation(_annotate_multi_field_totp_sequence(actions))
class TestGetTotpDigitBasic:
    """Basic tests for get_totp_digit in ScriptSkyvernPage."""

    @pytest.fixture
    def mock_skyvern_context(self) -> MagicMock:
        """Create a mock skyvern context."""
        ctx = MagicMock()
        # Only the workflow_run_id attribute is read by the code under test here.
        ctx.workflow_run_id = "wfr_test123"
        return ctx

    @pytest.mark.asyncio
    async def test_returns_single_digit(
        self,
        mock_skyvern_context: MagicMock,
    ) -> None:
        """get_totp_digit should return a single digit string."""
        # Empty credentials - will fall back to get_actual_value
        mock_workflow_context = MagicMock()
        mock_workflow_context.values = {}
        # Patch both module-level collaborators so no real context/app is touched.
        with patch("skyvern.core.script_generations.script_skyvern_page.skyvern_context") as mock_ctx_module:
            with patch("skyvern.core.script_generations.script_skyvern_page.app") as mock_app:
                mock_ctx_module.ensure_context.return_value = mock_skyvern_context
                mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context = AsyncMock(
                    return_value=mock_workflow_context
                )
                # spec= keeps the mock's attribute surface honest; an empty
                # cache forces the method to resolve the TOTP value fresh.
                page = MagicMock(spec=ScriptSkyvernPage)
                page._totp_sequence_cache = {}
                page.get_actual_value = AsyncMock(return_value="123456")
                # Call the unbound method with the mock standing in for self.
                result = await ScriptSkyvernPage.get_totp_digit(
                    page,
                    context=MagicMock(),
                    field_name="totp_code",
                    digit_index=0,
                )
                # Should return a single digit
                assert len(result) == 1
                assert result.isdigit()
                assert result == "1"  # First digit of "123456"

    @pytest.mark.asyncio
    async def test_returns_correct_digit_index(
        self,
        mock_skyvern_context: MagicMock,
    ) -> None:
        """get_totp_digit should return the correct digit for the given index."""
        mock_workflow_context = MagicMock()
        mock_workflow_context.values = {}
        with patch("skyvern.core.script_generations.script_skyvern_page.skyvern_context") as mock_ctx_module:
            with patch("skyvern.core.script_generations.script_skyvern_page.app") as mock_app:
                mock_ctx_module.ensure_context.return_value = mock_skyvern_context
                mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context = AsyncMock(
                    return_value=mock_workflow_context
                )
                page = MagicMock(spec=ScriptSkyvernPage)
                page._totp_sequence_cache = {}
                page.get_actual_value = AsyncMock(return_value="987654")
                # Test each digit index
                for idx, expected in enumerate("987654"):
                    result = await ScriptSkyvernPage.get_totp_digit(
                        page,
                        context=MagicMock(),
                        field_name="totp_code",
                        digit_index=idx,
                    )
                    assert result == expected, f"Expected digit {expected} at index {idx}, got {result}"

View File

@@ -0,0 +1,471 @@
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from zoneinfo import ZoneInfo
import pytest
from skyvern.forge.agent import ForgeAgent, SpeculativePlan
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.schemas.runs import RunEngine
from skyvern.schemas.steps import AgentStepOutput
from skyvern.webeye.actions.actions import ClickAction, CompleteAction, ExtractAction
from skyvern.webeye.actions.responses import ActionSuccess
from skyvern.webeye.scraper.scraped_page import ScrapedPage
from tests.unit.helpers import (
make_browser_state,
make_organization,
make_step,
make_task,
setup_parallel_verification_mocks,
)
@pytest.mark.asyncio
async def test_parallel_verification_triggers_data_extraction(monkeypatch: pytest.MonkeyPatch) -> None:
    """A verified CompleteAction on a task with a data-extraction goal triggers a follow-up ExtractAction.

    Drives ForgeAgent._handle_completed_step_with_parallel_verification with
    mocked collaborators and checks that the extraction payload reaches the
    task update call.
    """
    agent = ForgeAgent()
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    # A completed step with empty outputs, plus the next (not-yet-run) step.
    step_output = AgentStepOutput(action_results=[], actions_and_results=[])
    step = make_step(
        now,
        task,
        step_id="step-123",
        status=StepStatus.completed,
        order=0,
        output=step_output,
    )
    next_step = make_step(
        now,
        task,
        step_id="step-next",
        status=StepStatus.created,
        order=1,
        output=None,
    )
    complete_action = CompleteAction(reasoning="done", verified=True)
    extract_action = ExtractAction(
        reasoning="extract final data",
        data_extraction_goal=task.data_extraction_goal,
        data_extraction_schema=task.extracted_information_schema,
    )
    # Attach the identifiers the handler expects on a real action.
    extract_action.organization_id = task.organization_id
    extract_action.workflow_run_id = task.workflow_run_id
    extract_action.task_id = task.task_id
    extract_action.step_id = step.step_id
    extract_action.step_order = step.order
    extract_action.action_order = 1
    monkeypatch.setattr(agent, "create_extract_action", AsyncMock(return_value=extract_action))
    extraction_payload = {"quote": "42%"}
    # First handle_action response answers the complete action, the second the extraction.
    mocks = setup_parallel_verification_mocks(
        agent,
        step=step,
        task=task,
        monkeypatch=monkeypatch,
        next_step=next_step,
        complete_action=complete_action,
        handle_action_responses=[
            [ActionSuccess()],
            [ActionSuccess(data=extraction_payload)],
        ],
        extract_action=extract_action,
    )
    browser_state, scraped_page, page = make_browser_state()
    completed, last_step, next_created_step = await agent._handle_completed_step_with_parallel_verification(
        organization=organization,
        task=task,
        step=step,
        page=page,
        browser_state=browser_state,
        scraped_page=scraped_page,
        engine=RunEngine.skyvern_v1,
    )
    # The task completed on this step and no further step was created.
    assert completed is True
    assert last_step == step
    assert next_created_step is None
    # Both the complete action and the extract action went through the handler.
    assert mocks.handle_action.await_count == 2
    # The extraction payload is persisted via the task update call.
    extracted_information = mocks.update_task.await_args.kwargs["extracted_information"]
    assert extracted_information == extraction_payload
@pytest.mark.asyncio
async def test_parallel_verification_skips_extraction_without_navigation_goal(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without a navigation goal, the post-complete data-extraction pass must not run."""
    agent = ForgeAgent()
    now = datetime.now(UTC)
    organization = make_organization(now)
    # navigation_goal=None is the condition under test.
    task = make_task(now, organization, navigation_goal=None)
    step_output = AgentStepOutput(action_results=[], actions_and_results=[])
    step = make_step(
        now,
        task,
        step_id="step-123",
        status=StepStatus.completed,
        order=0,
        output=step_output,
    )
    setup_parallel_verification_mocks(
        agent,
        step=step,
        task=task,
        monkeypatch=monkeypatch,
        next_step=step,
        complete_action=CompleteAction(reasoning="done", verified=True),
        handle_action_responses=[[ActionSuccess()]],
    )
    # Spy on the extraction hook so we can assert it is never awaited.
    run_data_extraction_mock = AsyncMock()
    monkeypatch.setattr(agent, "_run_data_extraction_after_complete_action", run_data_extraction_mock)
    browser_state, scraped_page, page = make_browser_state()
    await agent._handle_completed_step_with_parallel_verification(
        organization=organization,
        task=task,
        step=step,
        page=page,
        browser_state=browser_state,
        scraped_page=scraped_page,
        engine=RunEngine.skyvern_v1,
    )
    run_data_extraction_mock.assert_not_awaited()
def test_task_validate_update_requires_extracted_information() -> None:
    """Completing a task that has a data-extraction goal without data must be rejected."""
    now = datetime.now(UTC)
    task = make_task(
        now,
        make_organization(now),
        data_extraction_goal="Need data",
    )
    # Marking the task completed with no extracted data violates the contract.
    with pytest.raises(ValueError):
        task.validate_update(TaskStatus.completed, extracted_information=None)
@pytest.mark.asyncio
async def test_agent_step_skips_user_goal_check_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the DISABLE_USER_GOAL_CHECK flag on, agent_step must not call check_user_goal_complete."""
    agent = ForgeAgent()
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization, navigation_goal="Reach confirmation page", workflow_run_id="workflow-1")
    step = make_step(
        now,
        task,
        step_id="step-disable",
        status=StepStatus.created,
        order=0,
        output=None,
    )
    browser_state, _, page = make_browser_state()
    browser_state.must_get_working_page = AsyncMock(return_value=page)
    browser_state.get_working_page = AsyncMock(return_value=page)

    async def _dummy_cleanup(*_args, **_kwargs) -> list[dict]:
        # No-op cleanup for the scraped page.
        return []

    # Minimal scraped page with a single screenshot so artifact recording has input.
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=_dummy_cleanup,
        _scrape_exclude=None,
    )
    scraped_page.screenshots = [b"image"]
    agent.build_and_record_step_prompt = AsyncMock(return_value=(scraped_page, "prompt", False, "extract-actions"))
    json_response: dict[str, object] = {"actions": [{"action_type": "CLICK", "element_id": "node-1"}]}
    agent.handle_potential_OTP_actions = AsyncMock(return_value=(json_response, []))
    click_action = ClickAction(
        element_id="node-1",
        organization_id=task.organization_id,
        workflow_run_id=task.workflow_run_id,
        task_id=task.task_id,
        step_id=step.step_id,
        step_order=step.order,
        action_order=0,
    )
    # Short-circuit parsing so the step executes exactly one click action.
    monkeypatch.setattr("skyvern.forge.agent.parse_actions", lambda *_, **__: [click_action])
    action_handler_mock = AsyncMock(return_value=[ActionSuccess()])
    monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", action_handler_mock)
    agent.record_artifacts_after_action = AsyncMock()
    agent._is_multi_field_totp_sequence = MagicMock(return_value=False)
    # The spy under test: must never be invoked when the flag is enabled.
    agent.check_user_goal_complete = AsyncMock()
    llm_handler_mock = AsyncMock(return_value=json_response)
    monkeypatch.setattr(
        "skyvern.forge.agent.LLMAPIHandlerFactory.get_override_llm_api_handler",
        lambda *_args, **_kwargs: llm_handler_mock,
    )
    monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.prepare_step_execution", AsyncMock())
    monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.post_action_execution", AsyncMock())
    # Remove real sleeps/jitter so the test runs instantly and deterministically.
    monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", AsyncMock(return_value=None))
    monkeypatch.setattr("skyvern.forge.agent.random.uniform", lambda *_args, **_kwargs: 0)

    async def fake_update_step(
        step: Step,
        status: StepStatus | None = None,
        output=None,
        is_last: bool | None = None,
        retry_index: int | None = None,
        **_kwargs,
    ) -> Step:
        # Mirror the real update by mutating the in-memory step instead of the DB.
        if status is not None:
            step.status = status
        if output is not None:
            step.output = output
        if is_last is not None:
            step.is_last = is_last
        if retry_index is not None:
            step.retry_index = retry_index
        return step

    agent.update_step = AsyncMock(side_effect=fake_update_step)

    async def feature_flag_side_effect(flag_name: str, *_args, **_kwargs) -> bool:
        # Only the goal-check kill switch is enabled; everything else stays off.
        if flag_name == "DISABLE_USER_GOAL_CHECK":
            return True
        return False

    monkeypatch.setattr(
        "skyvern.forge.agent.app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached",
        AsyncMock(side_effect=feature_flag_side_effect),
    )
    context = SkyvernContext(
        task_id=task.task_id,
        step_id=None,
        organization_id=task.organization_id,
        workflow_run_id=task.workflow_run_id,
        tz_info=ZoneInfo("UTC"),
    )
    skyvern_context.set(context)
    try:
        completed_step, detailed_output = await agent.agent_step(
            task=task,
            step=step,
            browser_state=browser_state,
            organization=organization,
        )
    finally:
        # Always clear the context var so other tests are unaffected.
        skyvern_context.reset()
    assert completed_step.status == StepStatus.completed
    assert detailed_output.actions_and_results is not None
    assert action_handler_mock.await_count == 1
    agent.record_artifacts_after_action.assert_awaited()
    # The point of the test: with the flag on, the goal check is never invoked.
    agent.check_user_goal_complete.assert_not_called()
@pytest.mark.asyncio
async def test_agent_step_persists_artifacts_when_using_speculative_plan(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Verify agent_step persists scrape artifacts when consuming a speculative plan.

    A SpeculativePlan pre-computed for this step is seeded onto the SkyvernContext;
    agent_step should use it instead of scraping live, and because the live-scrape
    path (which normally records artifacts) is skipped, it must call
    _persist_scrape_artifacts explicitly — exactly once.
    """
    agent = ForgeAgent()
    now = datetime.now(UTC)
    organization = make_organization(now)
    # navigation_goal=None steers the step toward the extract-action path.
    task = make_task(now, organization, navigation_goal=None)
    step = make_step(
        now,
        task,
        step_id="step-speculative",
        status=StepStatus.created,
        order=0,
        output=None,
    )
    browser_state, _, page = make_browser_state()
    browser_state.must_get_working_page = AsyncMock(return_value=page)
    browser_state.get_working_page = AsyncMock(return_value=page)

    async def _dummy_cleanup(*_args, **_kwargs) -> list[dict]:
        # ScrapedPage requires a cleanup callable; a no-op suffices in tests.
        return []

    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[{"tagName": "div", "children": []}],
        element_tree_trimmed=[{"tagName": "div", "children": []}],
        _browser_state=browser_state,
        _clean_up_func=_dummy_cleanup,
        _scrape_exclude=None,
    )
    scraped_page.html = "<html></html>"
    scraped_page.id_to_css_dict = {"node-1": "#node"}
    scraped_page.id_to_frame_dict = {"node-1": "frame-1"}
    scraped_page.screenshots = [b"image"]
    speculative_plan = SpeculativePlan(
        scraped_page=scraped_page,
        extract_action_prompt="unused",
        use_caching=False,
        llm_json_response=None,
        llm_metadata=None,
        prompt_name="extract-actions",
    )
    extract_action = ExtractAction(
        reasoning="collect data",
        data_extraction_goal=task.data_extraction_goal,
        data_extraction_schema=task.extracted_information_schema,
    )
    extract_action.organization_id = task.organization_id
    extract_action.workflow_run_id = task.workflow_run_id
    extract_action.task_id = task.task_id
    extract_action.step_id = step.step_id
    extract_action.step_order = step.order
    extract_action.action_order = 0
    agent.create_extract_action = AsyncMock(return_value=extract_action)
    agent.record_artifacts_after_action = AsyncMock()
    agent._persist_scrape_artifacts = AsyncMock()
    agent._is_multi_field_totp_sequence = MagicMock(return_value=False)
    action_handler_mock = AsyncMock(return_value=[ActionSuccess()])
    monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", action_handler_mock)
    monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.prepare_step_execution", AsyncMock())
    monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.post_action_execution", AsyncMock())
    # Remove real delays and jitter so the test is fast and deterministic.
    monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", AsyncMock(return_value=None))
    monkeypatch.setattr("skyvern.forge.agent.random.uniform", lambda *_args, **_kwargs: 0)
    monkeypatch.setattr("skyvern.forge.agent.app.DATABASE.create_action", AsyncMock())
    monkeypatch.setattr(
        "skyvern.forge.agent.app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached",
        AsyncMock(return_value=False),
    )

    async def fake_update_step(
        step: Step,
        status: StepStatus | None = None,
        output=None,
        is_last: bool | None = None,
        retry_index: int | None = None,
        **_kwargs,
    ) -> Step:
        # Mimic the real update_step: apply only the fields that were provided.
        if status is not None:
            step.status = status
        if output is not None:
            step.output = output
        if is_last is not None:
            step.is_last = is_last
        if retry_index is not None:
            step.retry_index = retry_index
        return step

    agent.update_step = AsyncMock(side_effect=fake_update_step)
    context = SkyvernContext(
        task_id=task.task_id,
        step_id=None,
        organization_id=task.organization_id,
        workflow_run_id=task.workflow_run_id,
        tz_info=ZoneInfo("UTC"),
    )
    # Seed the pre-computed plan for this step; agent_step should pick it up.
    context.speculative_plans[step.step_id] = speculative_plan
    skyvern_context.set(context)
    try:
        completed_step, detailed_output = await agent.agent_step(
            task=task,
            step=step,
            browser_state=browser_state,
            organization=organization,
        )
    finally:
        # Always clear the global context so later tests are unaffected.
        skyvern_context.reset()
    assert completed_step.status == StepStatus.completed
    assert detailed_output.actions is not None
    # The speculative path skips the live scrape, so artifacts must be
    # persisted through the explicit call — exactly once.
    agent._persist_scrape_artifacts.assert_awaited_once()
@pytest.mark.asyncio
async def test_persist_scrape_artifacts_records_all_files(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify _persist_scrape_artifacts records every artifact file for a scrape.

    With context.enable_speed_optimizations set, the economy elements tree is
    built (the full tree builder is not called), six artifacts are created
    (presumably screenshot, html, id maps, and trees — confirm against
    _persist_scrape_artifacts), and the final artifact is the
    visible-elements-tree-in-prompt built from the economy tree.
    """
    agent = ForgeAgent()
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(
        now,
        task,
        step_id="step-artifacts",
        status=StepStatus.created,
        order=0,
        output=None,
    )
    browser_state, _, _ = make_browser_state()

    async def _dummy_cleanup(*_args, **_kwargs) -> list[dict]:
        # ScrapedPage requires a cleanup callable; a no-op suffices in tests.
        return []

    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[{"tagName": "div"}],
        element_tree_trimmed=[{"tagName": "div"}],
        _browser_state=browser_state,
        _clean_up_func=_dummy_cleanup,
        _scrape_exclude=None,
    )
    scraped_page.html = "<html></html>"
    scraped_page.id_to_css_dict = {"node-1": "#node"}
    scraped_page.id_to_frame_dict = {"node-1": "frame-1"}
    scraped_page.element_tree = [{"tagName": "div"}]
    scraped_page.element_tree_trimmed = [{"tagName": "div"}]
    economy_tree_mock = MagicMock(return_value="<economy>")
    full_tree_mock = MagicMock(return_value="<full>")

    # Wrap the MagicMocks in plain functions so monkeypatch installs them as
    # regular methods on the class (self is forwarded to the underlying mock).
    def economy_wrapper(self, *args, **kwargs):
        return economy_tree_mock(self, *args, **kwargs)

    def full_wrapper(self, *args, **kwargs):
        return full_tree_mock(self, *args, **kwargs)

    monkeypatch.setattr(ScrapedPage, "build_economy_elements_tree", economy_wrapper)
    monkeypatch.setattr(ScrapedPage, "build_element_tree", full_wrapper)
    artifact_mock = AsyncMock()
    monkeypatch.setattr("skyvern.forge.agent.app.ARTIFACT_MANAGER.create_artifact", artifact_mock)
    context = SkyvernContext(
        task_id=task.task_id,
        step_id=None,
        organization_id=task.organization_id,
        workflow_run_id=task.workflow_run_id,
        tz_info=ZoneInfo("UTC"),
    )
    # Speed optimizations select the economy tree instead of the full tree.
    context.enable_speed_optimizations = True
    await agent._persist_scrape_artifacts(
        task=task,
        step=step,
        scraped_page=scraped_page,
        context=context,
    )
    assert artifact_mock.await_count == 6
    economy_tree_mock.assert_called_once()
    full_tree_mock.assert_not_called()
    # The last artifact written must be the prompt tree rendered from the
    # (mocked) economy tree output.
    last_call = artifact_mock.await_args_list[-1]
    assert last_call.kwargs["artifact_type"] == ArtifactType.VISIBLE_ELEMENTS_TREE_IN_PROMPT
    assert last_call.kwargs["data"] == b"<economy>"

View File

@@ -0,0 +1,58 @@
from unittest.mock import AsyncMock
import pytest
from skyvern.forge import app
from skyvern.forge.agent import (
EXTRACT_ACTION_PROMPT_NAME,
EXTRACT_ACTION_TEMPLATE,
ForgeAgent,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@pytest.mark.asyncio
async def test_prompt_caching_settings_respect_experiment(monkeypatch):
    """With no factory override, the experiment flag drives the caching settings,
    and the computed result is cached on the context (single provider call)."""
    forge_agent = ForgeAgent()
    ctx = SkyvernContext(run_id="wr_123", organization_id="org_456")
    provider = AsyncMock()
    provider.is_feature_enabled_cached.return_value = True
    monkeypatch.setattr(app, "EXPERIMENTATION_PROVIDER", provider)
    try:
        LLMAPIHandlerFactory.set_prompt_caching_settings(None)
        first_settings = await forge_agent._get_prompt_caching_settings(ctx)
        expected = {
            EXTRACT_ACTION_PROMPT_NAME: True,
            EXTRACT_ACTION_TEMPLATE: True,
        }
        assert first_settings == expected
        provider.is_feature_enabled_cached.assert_awaited_once_with(
            "PROMPT_CACHING_OPTIMIZATION",
            "wr_123",
            properties={"organization_id": "org_456"},
        )
        # The second lookup is served from the context cache — the provider
        # must not be consulted again.
        await forge_agent._get_prompt_caching_settings(ctx)
        assert provider.is_feature_enabled_cached.await_count == 1
    finally:
        LLMAPIHandlerFactory.set_prompt_caching_settings(None)
@pytest.mark.asyncio
async def test_prompt_caching_settings_use_override(monkeypatch):
    """An explicit factory override wins outright; the experiment provider is bypassed."""
    forge_agent = ForgeAgent()
    ctx = SkyvernContext(run_id="wr_789", organization_id="org_987")
    provider = AsyncMock()
    monkeypatch.setattr(app, "EXPERIMENTATION_PROVIDER", provider)
    try:
        LLMAPIHandlerFactory.set_prompt_caching_settings({"extract-actions": True})
        override_settings = await forge_agent._get_prompt_caching_settings(ctx)
        assert override_settings == {"extract-actions": True}
        provider.is_feature_enabled_cached.assert_not_called()
    finally:
        LLMAPIHandlerFactory.set_prompt_caching_settings(None)

View File

@@ -0,0 +1,97 @@
from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
def test_sanitize_postgres_text__normal_text() -> None:
    """Ordinary printable text must round-trip unchanged."""
    sample = "Hello, World! This is a normal PDF text with numbers 123 and symbols @#$%."
    assert sanitize_postgres_text(sample) == sample
def test_sanitize_postgres_text__with_nul_bytes() -> None:
    """Every NUL byte (0x00) is stripped out."""
    assert sanitize_postgres_text("Hello\x00World\x00Test") == "HelloWorldTest"
def test_sanitize_postgres_text__with_control_characters() -> None:
    """Problematic C0 control characters are removed."""
    dirty = "Hello\x01\x02\x03World\x08\x0b\x0c\x0e\x1fTest"
    assert sanitize_postgres_text(dirty) == "HelloWorldTest"
def test_sanitize_postgres_text__preserve_whitespace() -> None:
    """Tab, newline, and carriage return all survive sanitization."""
    sample = "Hello\tWorld\nNew Line\rCarriage Return"
    cleaned = sanitize_postgres_text(sample)
    assert cleaned == sample
    for whitespace_char in ("\t", "\n", "\r"):
        assert whitespace_char in cleaned
def test_sanitize_postgres_text__empty_string() -> None:
    """An empty input yields an empty output."""
    assert sanitize_postgres_text("") == ""
def test_sanitize_postgres_text__mixed_case() -> None:
    """Mixed content keeps valid whitespace while dropping NULs and controls."""
    dirty = "PDF Text\x00with NUL\tbytes\nand\x01control\x08chars\rand normal text."
    # \r is legitimate whitespace and must be preserved.
    cleaned = sanitize_postgres_text(dirty)
    assert cleaned == "PDF Textwith NUL\tbytes\nandcontrolchars\rand normal text."
def test_sanitize_postgres_text__multiple_nul_bytes() -> None:
    """Runs of consecutive NUL bytes are removed in full."""
    assert sanitize_postgres_text("Start\x00\x00\x00Middle\x00\x00End") == "StartMiddleEnd"
def test_sanitize_postgres_text__unicode_text() -> None:
    """Non-ASCII text, accents, and emoji pass through untouched."""
    sample = "中文测试 Unicode: café, naïve, Ω, emoji 😀"
    assert sanitize_postgres_text(sample) == sample
def test_sanitize_postgres_text__real_world_pdf_scenario() -> None:
    """Typical PDF-extraction artifacts are cleaned while layout whitespace is kept."""
    extracted = "Invoice\x00Number:\t12345\nDate:\t2024-01-01\x00\nTotal:\t$100.00\x01\x02"
    cleaned = sanitize_postgres_text(extracted)
    assert cleaned == "InvoiceNumber:\t12345\nDate:\t2024-01-01\nTotal:\t$100.00"
def test_sanitize_postgres_text__only_control_characters() -> None:
    """Input made solely of problematic bytes collapses to an empty string."""
    assert sanitize_postgres_text("\x00\x01\x02\x03\x08") == ""
def test_sanitize_postgres_text__preserves_spaces_and_punctuation() -> None:
    """Spaces and everyday punctuation are left intact."""
    sample = "Hello, World! How are you? I'm fine. Test@example.com"
    assert sanitize_postgres_text(sample) == sample
def test_sanitize_postgres_text__newlines_and_paragraphs() -> None:
    """Paragraph-separating newlines are all preserved."""
    sample = "Paragraph 1\n\nParagraph 2\n\nParagraph 3"
    cleaned = sanitize_postgres_text(sample)
    assert cleaned == sample
    assert cleaned.count("\n") == 4

View File

@@ -0,0 +1,756 @@
"""
Tests for script generation race condition (SKY-7653).
The race condition occurs when script generation runs during workflow execution
before all actions have been saved to the database. This results in:
1. `generate_workflow_parameters_schema` not finding INPUT_TEXT actions
2. No field_name mappings being generated
3. Generated script having hardcoded values instead of context.parameters[field_name]
"""
from typing import Any
import pytest
from skyvern.core.script_generations import generate_workflow_parameters as gwp
from skyvern.core.script_generations.generate_workflow_parameters import (
CUSTOM_FIELD_ACTIONS,
GeneratedFieldMapping,
generate_workflow_parameters_schema,
hydrate_input_text_actions_with_field_names,
)
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
from skyvern.webeye.actions.actions import ActionType
def make_input_text_action(
    task_id: str,
    action_id: str,
    text: str,
    intention: str = "",
    field_name: str | None = None,
) -> dict[str, Any]:
    """Build a minimal INPUT_TEXT action dict; field_name is attached only when given."""
    base: dict[str, Any] = dict(
        action_type=ActionType.INPUT_TEXT,
        action_id=action_id,
        task_id=task_id,
        text=text,
        intention=intention,
        element_id="element_1",
        xpath="//input[@id='test']",
    )
    return {**base, "field_name": field_name} if field_name else base
def make_click_action(task_id: str, action_id: str) -> dict[str, Any]:
    """Build a minimal CLICK action dict for the given task."""
    return dict(
        action_type=ActionType.CLICK,
        action_id=action_id,
        task_id=task_id,
        element_id="element_2",
        xpath="//button[@id='submit']",
    )
class TestRaceConditionScenarios:
    """Scenarios demonstrating the script-generation race condition."""

    def test_hydrate_adds_field_name_to_actions(self) -> None:
        """hydrate_input_text_actions_with_field_names attaches mapped field names."""
        task_id, action_id = "task-123", "action-456"
        actions_by_task = {
            task_id: [make_input_text_action(task_id, action_id, "Urdaneta", "Enter facility name")],
        }
        mappings = {f"{task_id}:{action_id}": "facility_name"}
        hydrated = hydrate_input_text_actions_with_field_names(actions_by_task, mappings)
        assert hydrated[task_id][0].get("field_name") == "facility_name"

    def test_hydrate_without_mappings_no_field_name(self) -> None:
        """With empty mappings, no field_name is attached.

        This simulates the race condition: script generation ran before the
        LLM could produce mappings (no INPUT_TEXT actions had been saved).
        """
        task_id, action_id = "task-123", "action-456"
        actions_by_task = {
            task_id: [make_input_text_action(task_id, action_id, "Urdaneta", "Enter facility name")],
        }
        hydrated = hydrate_input_text_actions_with_field_names(actions_by_task, {})
        assert "field_name" not in hydrated[task_id][0]

    def test_race_condition_empty_actions_produces_empty_schema(self) -> None:
        """With no actions recorded yet, the schema-relevant action set is empty.

        This mirrors script generation running before any actions execute; the
        async LLM call is never reached because nothing matches.
        """
        actions_by_task: dict[str, list[dict[str, Any]]] = {}
        # Reproduce only the synchronous action-discovery logic.
        matched = [
            action
            for actions in actions_by_task.values()
            for action in actions
            if action.get("action_type", "") in CUSTOM_FIELD_ACTIONS
        ]
        assert not matched

    def test_race_condition_only_click_actions_no_schema(self) -> None:
        """A CLICK-only history (INPUT_TEXT not yet saved) yields no field mappings."""
        task_id = "task-123"
        actions_by_task = {
            task_id: [
                make_click_action(task_id, "action-1"),
                make_click_action(task_id, "action-2"),
            ],
        }
        matched = [
            action
            for actions in actions_by_task.values()
            for action in actions
            if action.get("action_type", "") in CUSTOM_FIELD_ACTIONS
        ]
        # No INPUT_TEXT actions found — no schema would be generated.
        assert not matched
class TestCodeGenerationWithoutFieldName:
    """Impact of the race condition on generated code.

    A missing field_name makes the generator emit a hardcoded value instead of
    a context.parameters lookup.
    """

    def test_action_without_field_name_produces_hardcoded_value(self) -> None:
        """No field_name (race condition) → the text value gets hardcoded."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",  # this is what ends up hardcoded
            intention="Enter facility name",
            field_name=None,  # missing because of the race condition
        )
        # action_handler_body branches on act.get("field_name") when choosing
        # between context.parameters[field_name] and a literal value.
        assert action.get("field_name") is None
        assert action.get("text") == "Urdaneta"

    def test_action_with_field_name_produces_parameter_reference(self) -> None:
        """Present field_name → generated code reads context.parameters[field_name]."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",  # original value, unused once parameterized
            intention="Enter facility name",
            field_name="facility_name",
        )
        assert action.get("field_name") == "facility_name"
class TestFieldMappingGeneration:
    """Shape checks for the field-mapping generation structures."""

    def test_field_mapping_structure(self) -> None:
        """GeneratedFieldMapping exposes field_mappings and schema_fields as given."""
        mapping = GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )
        assert mapping.field_mappings["action_index_1"] == "facility_name"
        assert mapping.schema_fields["facility_name"]["type"] == "str"

    def test_action_index_to_field_mapping_key_format(self) -> None:
        """Mapping keys follow the task_id:action_id format used by the generator."""
        task_id, action_id = "task-123", "action-456"
        assert f"{task_id}:{action_id}" == "task-123:action-456"
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_empty_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Integration: no actions → empty schema and no mappings.

    Confirms the race-condition behavior — when script generation runs before
    any INPUT_TEXT actions are saved, the function returns early and the
    prompt/LLM machinery is never reached.
    """
    no_actions: dict[str, list[dict[str, Any]]] = {}
    schema_code, mappings = await generate_workflow_parameters_schema(no_actions)
    assert "pass" in schema_code  # the empty schema body is just `pass`
    assert mappings == {}
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_with_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Integration: with INPUT_TEXT actions present, the LLM path produces a
    schema and per-action field mappings (i.e. script gen ran after saves)."""

    async def fake_llm_mapping(custom_field_actions):
        # Stand-in for the real LLM call.
        return GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

    monkeypatch.setattr(gwp, "_generate_field_names_with_llm", fake_llm_mapping)
    task_id, action_id = "task-123", "action-456"
    actions_by_task = {
        task_id: [make_input_text_action(task_id, action_id, "Urdaneta", "Enter facility name")],
    }
    schema_code, mappings = await generate_workflow_parameters_schema(actions_by_task)
    # Schema contains the generated field and the wrapper class.
    assert "facility_name" in schema_code
    assert "GeneratedWorkflowParameters" in schema_code
    # Our action is mapped under the task_id:action_id key.
    assert mappings.get(f"{task_id}:{action_id}") == "facility_name"
class TestRaceConditionTimingScenario:
    """
    Document the timing scenario that causes the race condition.

    Timeline:
    1. T+0s: CLICK action executes, post_action_execution triggered
    2. T+0.1s: Script generation starts (asyncio.create_task)
    3. T+0.2s: Script generation queries database for actions - finds only CLICK
    4. T+0.3s: Script generation completes with no field mappings
    5. T+6s: INPUT_TEXT action executes, saved to database
    6. T+6.1s: Another script generation triggered, but first (wrong) script already saved

    The result is a script with hardcoded values like `value = 'Urdaneta'`
    instead of `value = context.parameters['facility_name']`
    """

    def test_timing_scenario_documentation(self) -> None:
        """Walk both phases of the timeline and check what script gen would see."""

        def input_text_subset(actions_by_task: dict[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
            # The INPUT_TEXT slice that field mapping would operate on.
            return [
                act
                for acts in actions_by_task.values()
                for act in acts
                if act["action_type"] == ActionType.INPUT_TEXT
            ]

        # Phase 1: after CLICK, before INPUT_TEXT — script gen finds nothing.
        phase_one = {"task-123": [make_click_action("task-123", "action-1")]}
        assert len(input_text_subset(phase_one)) == 0

        # Phase 2: INPUT_TEXT saved — found now, but the first (wrong) script
        # has already been persisted.
        phase_two = {
            "task-123": [
                make_click_action("task-123", "action-1"),
                make_input_text_action("task-123", "action-2", "Urdaneta", "Enter facility name"),
            ],
        }
        found = input_text_subset(phase_two)
        assert len(found) == 1
        assert found[0]["text"] == "Urdaneta"
class TestFinalizeParameter:
    """
    Tests for the `finalize` parameter in generate_script_if_needed.

    The fix (SKY-7653) uses a smart finalize approach: regeneration happens
    only when the script_gen_had_incomplete_actions flag is set, so scripts
    that are already complete are never regenerated needlessly.
    """

    def test_finalize_with_incomplete_actions_triggers_regeneration(self) -> None:
        """finalize=True plus the incomplete-actions flag forces cacheable blocks
        to be regenerated; non-cacheable blocks (e.g. wait) are excluded."""

        class _Block:
            # Minimal stand-in for a workflow-definition block.
            def __init__(self, label: str, block_type: str):
                self.label = label
                self.block_type = block_type

        definition_blocks = [
            _Block("login_step", "task"),  # a cached block type
            _Block("search_step", "task"),
            _Block("wait_block", "wait"),  # never cached
        ]
        finalize = True
        ctx = SkyvernContext(script_gen_had_incomplete_actions=True)
        pending: set[str] = set()
        # Mirror of the finalize logic in generate_script_if_needed.
        if finalize and ctx.script_gen_had_incomplete_actions:
            pending.update(
                blk.label
                for blk in definition_blocks
                if blk.label and blk.block_type in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
            )
        assert "login_step" in pending
        assert "search_step" in pending
        assert "wait_block" not in pending

    def test_finalize_without_incomplete_actions_skips_regeneration(self) -> None:
        """finalize=True without the flag → nothing regenerated (the optimization)."""
        finalize = True
        ctx = SkyvernContext(script_gen_had_incomplete_actions=False)
        pending: set[str] = set()
        if finalize and ctx.script_gen_had_incomplete_actions:
            pending.add("some_block")  # unreachable: generation data was complete
        assert len(pending) == 0

    def test_without_finalize_no_forced_regeneration(self) -> None:
        """Without finalize=True, no blocks are force-added for regeneration."""
        finalize = False
        pending: set[str] = set()
        if finalize:
            pending.add("some_block")  # unreachable
        assert len(pending) == 0
class TestCodeGenerationLogic:
    """
    Exercise the branch from generate_script.py (~lines 401-429) that chooses
    between context.parameters[field_name] and a hardcoded text value based on
    act.get("field_name").
    """

    @staticmethod
    def _chosen_code_path(action: dict[str, Any]) -> str:
        # Mirror of the generator's branch: a truthy field_name selects the
        # parameter lookup; otherwise the literal text is hardcoded.
        return "context.parameters" if action.get("field_name") else "hardcoded"

    def test_code_generation_path_without_field_name(self) -> None:
        """A missing field_name selects the hardcoded-value branch."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",
            intention="Enter facility name",
            field_name=None,
        )
        assert self._chosen_code_path(action) == "hardcoded"
        assert action.get("text") == "Urdaneta"

    def test_code_generation_path_with_field_name(self) -> None:
        """A present field_name selects the context.parameters branch."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",
            intention="Enter facility name",
            field_name="facility_name",
        )
        assert self._chosen_code_path(action) == "context.parameters"
        assert action.get("field_name") == "facility_name"

    def test_demonstrates_race_condition_consequence(self) -> None:
        """
        Show why the race condition matters for cached-script reuse.

        When script generation runs before the INPUT_TEXT action is saved, no
        field mappings exist, field_name is never hydrated, and the generated
        script hardcodes the value — so it cannot be re-run with different
        parameters. Generation after all actions are saved parameterizes it.
        """
        # Early generation: field_name missing → code would be value = "Urdaneta".
        early_action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",  # gets baked into the script
            field_name=None,  # missing due to the race condition
        )
        assert early_action.get("field_name") is None
        # Re-running with a "Pok Pok" parameter would still use "Urdaneta".

        # Proper generation (after saves): field_name present → code would be
        # value = context.parameters["facility_name"], honoring new inputs.
        proper_action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",
            field_name="facility_name",
        )
        assert proper_action.get("field_name") is not None
class TestSkipActionsWithoutData:
    """
    Tests for the smart finalize approach that skips actions without data.

    Race condition SKY-7653 is mitigated by:
    1. skipping data-less actions mid-run (prevents bad field mappings),
    2. setting script_gen_had_incomplete_actions when anything is skipped,
    3. regenerating at finalize only when that flag is set.

    A first run hit by the race regenerates once at the end; subsequent runs
    with an already-complete script skip regeneration entirely.
    """

    @staticmethod
    def _filter_custom_field_actions(actions: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Mirror of the filtering in generate_workflow_parameters_schema:
        # keep only custom-field actions whose value field is populated.
        kept: list[dict[str, Any]] = []
        for action in actions:
            action_type = action.get("action_type", "")
            if action_type not in CUSTOM_FIELD_ACTIONS:
                continue
            value = ""
            if action_type == ActionType.INPUT_TEXT:
                value = action.get("text", "")
            elif action_type == ActionType.SELECT_OPTION:
                value = action.get("option", "")
            elif action_type == ActionType.UPLOAD_FILE:
                value = action.get("file_url", "")
            if value:
                kept.append(action)
        return kept

    def test_input_text_without_text_is_skipped(self) -> None:
        """INPUT_TEXT with empty text (race condition) is excluded from mapping."""
        incomplete = {
            "action_type": ActionType.INPUT_TEXT,
            "action_id": "action-456",
            "task_id": "task-123",
            "text": "",  # empty — not yet saved
            "intention": "Enter facility name",
        }
        assert len(self._filter_custom_field_actions([incomplete])) == 0

    def test_input_text_with_text_is_included(self) -> None:
        """INPUT_TEXT with a saved text value is kept for mapping."""
        complete = {
            "action_type": ActionType.INPUT_TEXT,
            "action_id": "action-456",
            "task_id": "task-123",
            "text": "Urdaneta",  # populated
            "intention": "Enter facility name",
        }
        assert len(self._filter_custom_field_actions([complete])) == 1

    def test_select_option_without_option_is_skipped(self) -> None:
        """SELECT_OPTION with an empty option is excluded from mapping."""
        incomplete = {
            "action_type": ActionType.SELECT_OPTION,
            "action_id": "action-789",
            "task_id": "task-123",
            "option": "",  # empty — not yet saved
        }
        assert len(self._filter_custom_field_actions([incomplete])) == 0

    def test_upload_file_without_file_url_is_skipped(self) -> None:
        """UPLOAD_FILE with an empty file_url is excluded from mapping."""
        incomplete = {
            "action_type": ActionType.UPLOAD_FILE,
            "action_id": "action-101",
            "task_id": "task-123",
            "file_url": "",  # empty — not yet saved
        }
        assert len(self._filter_custom_field_actions([incomplete])) == 0
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_skips_empty_actions_and_sets_flag(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Integration test for the smart finalize approach:
    1. Incomplete actions are skipped mid-run (prevents bad field mappings).
    2. The context flag is set (triggers finalize regeneration only when needed).
    """
    # Track the flag via a real context.
    context = SkyvernContext()
    skyvern_context.set(context)
    # Record LLM invocations; none are expected because the only action is empty.
    llm_invocations: list = []

    async def mock_generate_field_names_with_llm(custom_field_actions):
        llm_invocations.append(custom_field_actions)
        return GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

    monkeypatch.setattr(gwp, "_generate_field_names_with_llm", mock_generate_field_names_with_llm)
    task_id = "task-123"
    # An empty text simulates the race condition where the action is not yet saved.
    actions_by_task = {
        task_id: [
            {
                "action_type": ActionType.INPUT_TEXT,
                "action_id": "action-456",
                "task_id": task_id,
                "text": "",  # Empty - not yet saved
                "intention": "Enter facility name",
            },
        ]
    }
    try:
        schema_code, action_field_mappings = await generate_workflow_parameters_schema(actions_by_task)
        # The incomplete action was skipped, so the LLM was never consulted.
        assert not llm_invocations
        # An empty schema is produced.
        assert "pass" in schema_code
        assert action_field_mappings == {}
        # The flag tells finalize to regenerate once complete data exists.
        assert context.script_gen_had_incomplete_actions is True
    finally:
        skyvern_context.reset()
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_with_complete_actions_no_flag(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Integration test: complete actions must NOT set the incomplete-actions
    context flag, so finalize skips the costly regeneration.
    """
    context = SkyvernContext()
    skyvern_context.set(context)

    async def mock_generate_field_names_with_llm(custom_field_actions):
        return GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

    monkeypatch.setattr(gwp, "_generate_field_names_with_llm", mock_generate_field_names_with_llm)
    task_id = "task-123"
    # All values populated - no race condition in play.
    actions_by_task = {
        task_id: [
            {
                "action_type": ActionType.INPUT_TEXT,
                "action_id": "action-456",
                "task_id": task_id,
                "text": "Urdaneta",  # Has value - complete
                "intention": "Enter facility name",
            },
        ]
    }
    try:
        schema_code, action_field_mappings = await generate_workflow_parameters_schema(actions_by_task)
        # The schema includes the mapped field.
        assert "facility_name" in schema_code
        # No regeneration flag - nothing was incomplete.
        assert context.script_gen_had_incomplete_actions is False
    finally:
        skyvern_context.reset()

View File

@@ -0,0 +1,224 @@
"""
Unit tests for ScriptSkyvernPage, specifically testing _wait_for_page_ready_before_action.
This test file exists to prevent regressions like the AttributeError bug where
self._page was used instead of self.page (see PR #8425, SKY-7676).
"""
import inspect
import re
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.config import settings
from skyvern.core.script_generations.script_skyvern_page import ScriptSkyvernPage
def create_mock_page():
    """Create a mock Playwright Page object with required attributes."""
    fake_page = MagicMock()
    fake_page.url = "https://example.com"
    # Playwright's Page base class expects these internals to be present.
    fake_page._loop = MagicMock()
    fake_page._impl_obj = fake_page
    return fake_page
@pytest.fixture
def mock_scraped_page():
    """Create a mock ScrapedPage object."""
    scraped = MagicMock()
    # The code under test reaches into the private browser state.
    scraped._browser_state = MagicMock()
    return scraped
@pytest.fixture
def mock_ai():
    """Create a mock SkyvernPageAi object."""
    ai_stub = MagicMock()
    return ai_stub
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_calls_skyvern_frame(mock_scraped_page, mock_ai):
    """
    Regression test for the bug in PR #8273: _wait_for_page_ready_before_action
    must operate on self.page (SkyvernPage stores the Playwright page there),
    not self._page, and must forward the configured readiness timeouts.
    """
    fake_page = create_mock_page()
    # Avoid running Playwright's real Page.__init__ machinery.
    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=fake_page,
            ai=mock_ai,
        )
        frame_stub = MagicMock()
        frame_stub.wait_for_page_ready = AsyncMock()
        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
            return_value=frame_stub,
        ) as create_instance_mock:
            await script_page._wait_for_page_ready_before_action()
            # Exactly one frame is created, keyed off the page object.
            create_instance_mock.assert_called_once()
            call_kwargs = create_instance_mock.call_args.kwargs
            assert "frame" in call_kwargs, "create_instance should be called with frame argument"
            assert call_kwargs["frame"] is not None, "frame should not be None"
            # The readiness call must carry the configured timeouts.
            frame_stub.wait_for_page_ready.assert_called_once_with(
                network_idle_timeout_ms=settings.PAGE_READY_NETWORK_IDLE_TIMEOUT_MS,
                loading_indicator_timeout_ms=settings.PAGE_READY_LOADING_INDICATOR_TIMEOUT_MS,
                dom_stable_ms=settings.PAGE_READY_DOM_STABLE_MS,
                dom_stability_timeout_ms=settings.PAGE_READY_DOM_STABILITY_TIMEOUT_MS,
            )
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_handles_no_page(mock_scraped_page, mock_ai):
    """
    _wait_for_page_ready_before_action must return early when self.page is None.
    """
    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        # Construct normally, then drop the page reference afterwards.
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=create_mock_page(),
            ai=mock_ai,
        )
        # Simulate the page having been closed after construction.
        script_page.page = None
        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
        ) as create_instance_mock:
            # Must not raise, and must bail out before touching SkyvernFrame.
            await script_page._wait_for_page_ready_before_action()
            create_instance_mock.assert_not_called()
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_catches_exceptions(mock_scraped_page, mock_ai):
    """
    Failures during frame creation must be swallowed: page-readiness checks are
    best-effort and may never block action execution.
    """
    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=create_mock_page(),
            ai=mock_ai,
        )
        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
            side_effect=Exception("Simulated page readiness error"),
        ):
            # Must not propagate the exception.
            await script_page._wait_for_page_ready_before_action()
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_catches_wait_for_page_ready_exceptions(mock_scraped_page, mock_ai):
    """
    Exceptions raised by wait_for_page_ready itself must be caught and logged
    rather than propagated.
    """
    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=create_mock_page(),
            ai=mock_ai,
        )
        # Frame is created fine, but readiness polling times out.
        frame_stub = MagicMock()
        frame_stub.wait_for_page_ready = AsyncMock(side_effect=TimeoutError("Page never became idle"))
        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
            return_value=frame_stub,
        ):
            # Must not propagate the TimeoutError.
            await script_page._wait_for_page_ready_before_action()
@pytest.mark.asyncio
async def test_wait_for_page_ready_attribute_access_regression():
    """
    Regression test: Verify that the code accesses self.page, not self._page.

    The original bug (fixed in PR #8425) used self._page which caused:
        AttributeError: 'ScriptSkyvernPage' object has no attribute '_page'

    This test directly inspects the source code to ensure self._page is not used.
    """
    source = inspect.getsource(ScriptSkyvernPage._wait_for_page_ready_before_action)
    # The fixed code should use self.page
    assert "self.page" in source, "Method should access self.page"
    # Strip docstrings and comments first so that occurrences there (e.g. a
    # note explaining the fix) don't trigger a false positive.
    source_no_docstrings = re.sub(r'""".*?"""', "", source, flags=re.DOTALL)
    source_no_docstrings = re.sub(r"'''.*?'''", "", source_no_docstrings, flags=re.DOTALL)
    source_no_comments = re.sub(r"#.*$", "", source_no_docstrings, flags=re.MULTILINE)
    lines_with_code = [
        line for line in source_no_comments.split("\n") if line.strip() and not line.strip().startswith("#")
    ]
    code_only = "\n".join(lines_with_code)
    # Check for the bug pattern in actual code only.
    if "self._page" in code_only:
        # Find the line in the original source for better error reporting.
        for i, line in enumerate(source.split("\n"), 1):
            if "self._page" in line and not line.strip().startswith("#"):
                pytest.fail(
                    f"Found 'self._page' in code at line {i}: {line.strip()}\n"
                    "This is a regression! SkyvernPage uses self.page, not self._page."
                )
        # Bug fix: previously the test passed silently when 'self._page' was
        # detected in the stripped code but the reporting loop above failed to
        # pinpoint the offending raw-source line. Always fail once detected.
        pytest.fail(
            "Found 'self._page' in the method body (could not locate the exact line).\n"
            "This is a regression! SkyvernPage uses self.page, not self._page."
        )

View File

@@ -0,0 +1,31 @@
import pytest
from skyvern.forge.sdk.schemas.credentials import (
CreateCredentialRequest,
CredentialType,
SecretCredential,
)
class TestSecretCredentialModels:
    """Schema tests for SecretCredential and its create-request wrapper."""

    def test_secret_credential_creation(self) -> None:
        """A secret credential retains both its value and its label."""
        credential = SecretCredential(secret_value="sk-abc123", secret_label="API Key")
        assert credential.secret_value == "sk-abc123"
        assert credential.secret_label == "API Key"

    def test_secret_credential_optional_type(self) -> None:
        """The label is optional and defaults to None."""
        credential = SecretCredential(secret_value="token123")
        assert credential.secret_label is None

    def test_non_empty_validation(self) -> None:
        """An empty secret value is rejected by validation."""
        with pytest.raises(ValueError):
            SecretCredential(secret_value="")

    def test_create_request_with_secret(self) -> None:
        """A create request carries the SECRET type together with the credential."""
        request = CreateCredentialRequest(
            name="My API Key",
            credential_type=CredentialType.SECRET,
            credential=SecretCredential(secret_value="sk-12345"),
        )
        assert request.credential_type == CredentialType.SECRET
        assert request.credential.secret_value == "sk-12345"

View File

@@ -0,0 +1,13 @@
import pytest
from freezegun import freeze_time
from skyvern.forge.sdk.core.security import create_access_token, generate_skyvern_webhook_signature
@pytest.mark.skip(reason="Skipping test_generate_skyvern_signature")
@freeze_time("2023-11-30 00:00:00")
def test_generate_skyvern_signature() -> None:
    """Webhook signatures must be deterministic given a frozen clock and fixed payload."""
    token = create_access_token("o_12345")
    body = {"task_id": "t_12345", "float": 1.0}
    signed = generate_skyvern_webhook_signature(body, token)
    assert signed.signature == "1fac4204e1abc7cb0bdf1a42eb17d27f6f1feba065d5726777d5eb77581298c1"

View File

@@ -0,0 +1,210 @@
import sys
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import pytest
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.workflow.models.block import TextPromptBlock
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
# Grab the already-imported block module object so tests can monkeypatch its
# module-level globals (app handlers, LLMAPIHandlerFactory, helper functions).
block_module = sys.modules["skyvern.forge.sdk.workflow.models.block"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("model_name", "expected_llm_key"),
    [
        ("gemini-2.5-flash", "VERTEX_GEMINI_2.5_FLASH"),
        ("gemini-3-pro-preview", "VERTEX_GEMINI_3.0_PRO"),
    ],
)
async def test_text_prompt_block_uses_selected_model(monkeypatch, model_name, expected_llm_key):
    """A model selected on the block must route send_prompt through the override
    handler resolved for the matching LLM key."""
    now = datetime.now(timezone.utc)
    output_parameter = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output",
        description=None,
        output_parameter_id="output-1",
        workflow_id="workflow-1",
        created_at=now,
        modified_at=now,
        deleted_at=None,
    )
    block = TextPromptBlock(
        label="text-block",
        llm_key="AZURE_OPENAI_GPT4_1",
        prompt="Explain the status.",
        parameters=[],
        json_schema=None,
        output_parameter=output_parameter,
        model={"model_name": model_name},
    )
    captured: dict[str, str] = {}
    fake_default_handler = AsyncMock()

    async def fake_resolve_default_llm_handler(*args, **kwargs):
        # Stand-in for the block's default-handler resolution.
        return fake_default_handler

    async def fake_handler(*, prompt: str, prompt_name: str, **kwargs):
        # Records what the block actually sends to the LLM.
        captured["prompt"] = prompt
        captured["prompt_name"] = prompt_name
        return {"llm_response": "ok"}

    def fake_get_override_handler(llm_key: str | None, *, default):
        # Records which llm_key the block resolved from its model selection.
        captured["llm_key"] = llm_key if llm_key else "default"
        return fake_handler if llm_key else default

    # Patch the module-level app handler directly; monkeypatch handles the rest.
    block_module.app.LLM_API_HANDLER = fake_default_handler
    LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        LLMAPIHandlerFactory,
        "get_override_llm_api_handler",
        fake_get_override_handler,
        raising=False,
    )
    monkeypatch.setattr(
        TextPromptBlock,
        "_resolve_default_llm_handler",
        fake_resolve_default_llm_handler,
        raising=False,
    )
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )
    response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")
    # The model selection maps to the expected LLM key and the prompt flows
    # through the override handler untouched.
    assert captured["llm_key"] == expected_llm_key
    assert captured["prompt_name"] == "text-prompt"
    assert response == {"llm_response": "ok"}
@pytest.mark.asyncio
async def test_text_prompt_block_uses_workflow_handler_when_no_override(monkeypatch):
    """With no llm_key/model override and no prompt-type handler configured,
    the block must fall back to the secondary workflow LLM handler."""
    now = datetime.now(timezone.utc)
    output_parameter = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output",
        description=None,
        output_parameter_id="output-2",
        workflow_id="workflow-1",
        created_at=now,
        modified_at=now,
        deleted_at=None,
    )
    block = TextPromptBlock(
        label="text-block",
        llm_key=None,
        prompt="Summarize status.",
        parameters=[],
        json_schema=None,
        output_parameter=output_parameter,
        model=None,
    )
    captured: dict[str, str] = {}
    fake_secondary_handler = AsyncMock(return_value={"llm_response": "secondary"})

    async def fake_prompt_type_handler(*args, **kwargs):
        # No prompt-type specific handler configured.
        return None

    def fake_get_override_handler(llm_key: str | None, *, default):
        # Records the resolved key and the default handler the factory was given.
        captured["llm_key"] = llm_key if llm_key else "default"
        captured["default_handler"] = default
        return default

    # With no override, the secondary handler should win over the primary.
    block_module.app.SECONDARY_LLM_API_HANDLER = fake_secondary_handler
    block_module.app.LLM_API_HANDLER = AsyncMock()
    LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        LLMAPIHandlerFactory,
        "get_override_llm_api_handler",
        fake_get_override_handler,
        raising=False,
    )
    monkeypatch.setattr(
        block_module,
        "get_llm_handler_for_prompt_type",
        fake_prompt_type_handler,
        raising=False,
    )
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )
    response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")
    # No override key resolved; the secondary handler was used as the default.
    assert captured["llm_key"] == "default"
    assert captured["default_handler"] == fake_secondary_handler
    fake_secondary_handler.assert_awaited_once()
    assert response == {"llm_response": "secondary"}
@pytest.mark.asyncio
async def test_text_prompt_block_prefers_prompt_type_config_over_secondary(monkeypatch):
    """A prompt-type-configured handler must take precedence over the secondary
    workflow LLM handler when no explicit llm_key/model override is set."""
    now = datetime.now(timezone.utc)
    output_parameter = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output",
        description=None,
        output_parameter_id="output-3",
        workflow_id="workflow-1",
        created_at=now,
        modified_at=now,
        deleted_at=None,
    )
    block = TextPromptBlock(
        label="text-block",
        llm_key=None,
        prompt="Provide summary.",
        parameters=[],
        json_schema=None,
        output_parameter=output_parameter,
        model=None,
    )
    captured: dict[str, str] = {}
    prompt_config_handler = AsyncMock(return_value={"llm_response": "config"})

    async def fake_prompt_type_handler(*args, **kwargs):
        # A prompt-type handler IS configured this time.
        return prompt_config_handler

    def fake_get_override_handler(llm_key: str | None, *, default):
        # Records which default the factory received; should be the config handler.
        captured["default_handler"] = default
        return default

    block_module.app.SECONDARY_LLM_API_HANDLER = AsyncMock()
    block_module.app.LLM_API_HANDLER = AsyncMock()
    LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        LLMAPIHandlerFactory,
        "get_override_llm_api_handler",
        fake_get_override_handler,
        raising=False,
    )
    monkeypatch.setattr(
        block_module,
        "get_llm_handler_for_prompt_type",
        fake_prompt_type_handler,
        raising=False,
    )
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )
    response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")
    # The prompt-type handler was chosen as the default and actually invoked.
    assert captured["default_handler"] == prompt_config_handler
    prompt_config_handler.assert_awaited_once()
    assert response == {"llm_response": "config"}

View File

@@ -0,0 +1,71 @@
from types import SimpleNamespace
import pytest
from skyvern.forge.sdk.schemas.credentials import CredentialVaultType
from skyvern.forge.sdk.workflow import context_manager as cm
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import TaskV2Block
@pytest.mark.asyncio
async def test_register_credential_parameter_uses_db_totp_identifier(monkeypatch: pytest.MonkeyPatch) -> None:
    """Registering a credential parameter must fall back to the TOTP identifier
    stored on the DB credential when the vault item itself has none."""
    # DB row carries the TOTP identifier; the vault item (below) does not.
    db_credential = SimpleNamespace(
        credential_id="cred-1",
        organization_id="org-1",
        vault_type=CredentialVaultType.BITWARDEN,
        totp_identifier="user@example.com",
    )

    class FakeCredential:
        # Vault-side credential with no TOTP info of its own.
        def __init__(self) -> None:
            self.totp_identifier = None
            self.totp = None

        def model_dump(self) -> dict:
            return {}

    class FakeCredentialItem:
        def __init__(self) -> None:
            self.credential = FakeCredential()

    class FakeCredentialService:
        async def get_credential_item(self, _db_credential: object) -> FakeCredentialItem:
            return FakeCredentialItem()

    class FakeDatabase:
        async def get_credential(self, credential_id: str, organization_id: str) -> object:
            # Sanity-check the lookup arguments the context passes through.
            assert credential_id == "cred-1"
            assert organization_id == "org-1"
            return db_credential

    # Swap the real forge app for a minimal namespace exposing only what
    # _register_credential_parameter_value touches.
    fake_app = SimpleNamespace(
        DATABASE=FakeDatabase(),
        CREDENTIAL_VAULT_SERVICES={CredentialVaultType.BITWARDEN: FakeCredentialService()},
    )
    monkeypatch.setattr(cm, "app", fake_app)
    context = WorkflowRunContext(
        workflow_title="title",
        workflow_id="wf-1",
        workflow_permanent_id="wfp-1",
        workflow_run_id="wr-1",
        aws_client=SimpleNamespace(),
    )
    parameter = SimpleNamespace(key="credential_param")
    organization = SimpleNamespace(organization_id="org-1")
    await context._register_credential_parameter_value("cred-1", parameter, organization)
    # The DB-level TOTP identifier is what gets registered under the parameter key.
    assert context.get_credential_totp_identifier("credential_param") == "user@example.com"
def test_task_v2_block_resolves_totp_identifier_from_context() -> None:
    """The TOTP identifier falls back to the workflow-run context unless set on the block."""
    run_context = SimpleNamespace(credential_totp_identifiers={"credential_param": "user@example.com"})
    # No identifier on the block: resolve from the context mapping.
    implicit_block = TaskV2Block.model_construct(totp_identifier=None)
    assert implicit_block._resolve_totp_identifier(run_context) == "user@example.com"
    # An explicit identifier on the block always wins.
    explicit_block = TaskV2Block.model_construct(totp_identifier="provided@example.com")
    assert explicit_block._resolve_totp_identifier(run_context) == "provided@example.com"

View File

@@ -0,0 +1,29 @@
from skyvern.utils.url_validators import encode_url
def test_encode_url_basic():
    """Test basic URL encoding with simple path"""
    raw = "https://example.com/path with spaces"
    assert encode_url(raw) == "https://example.com/path%20with%20spaces"
def test_encode_url_with_query_params():
    """Test URL encoding with query parameters"""
    raw = "https://example.com/search?q=hello world&type=test"
    assert encode_url(raw) == "https://example.com/search?q=hello%20world&type=test"
def test_encode_url_with_special_chars():
    """Test URL encoding with special characters"""
    # Already-valid reserved characters must pass through unchanged.
    raw = "https://example.com/path/with/special#chars?param=value&other=test@123"
    assert encode_url(raw) == "https://example.com/path/with/special#chars?param=value&other=test@123"
def test_encode_url_with_pre_encoded_chars():
    """Test URL encoding with pre-encoded characters in query parameters"""
    # %20 sequences must not be double-encoded while raw spaces still are.
    raw = "https://example.com/search?q=hello world&type=test%20test"
    assert encode_url(raw) == "https://example.com/search?q=hello%20world&type=test%20test"

View File

@@ -0,0 +1,48 @@
import re
import pytest
from skyvern.utils.templating import Constants, get_missing_variables
@pytest.mark.parametrize(
    "template,data,expected",
    [
        ("", {}, set()),
        ("Hello {{ name }}", {"name": "World"}, set()),
        ("Hello {{ name }}", {"age": 30}, {"name"}),
        ("{{ one }}", {"one": 1, "two": 2}, set()),  # extra vars allowed
        # nested (dotted) variables
        ("{{ user.name }}", {"user": {"name": "Alice"}}, set()),
        ("{{ user.name }}", {"user": {"age": 30}}, {"user.name"}),
        # list access
        ("{{ items[0] }}", {}, {"items"}),
        ("{{ items[0] }}", {"items": [1, 2, 3]}, set()),
        ("{{ items[0] }}", {"items": []}, {"items[0]"}),
        # deeply nested lists and dicts
        ("{{ data.users[0].name }}", {"data": {"users": [{"name": "Bob"}]}}, set()),
        ("{{ data.users[0].name }}", {"data": {"users": [{}]}}, {"data.users[0].name"}),
        ("{{ data.users[0].name }}", {"data": {}}, {"data.users[0].name"}),
    ],
)
def test_get_missing_variables(template, data, expected):
    """Every unresolved template variable path is reported, and nothing else."""
    assert get_missing_variables(template, data) == expected
@pytest.mark.parametrize(
    "template,expected",
    [
        ("{{ var }}", {"var"}),
        ("{{ var.attr }}", {"var.attr"}),
        ("{{ var[0] }}", {"var[0]"}),
        ("{{ var['key'] }}", {"var['key']"}),
        ('{{ var["key"] }}', {'var["key"]'}),
        ("{{ var.attr[0] }}", {"var.attr[0]"}),
        ("No variables here", set()),
        ("{{ var1 }} and {{ var2.attr }}", {"var1", "var2.attr"}),
    ],
)
def test_regex_missing_variable_pattern(template, expected):
    """The missing-variable regex extracts each referenced variable path verbatim."""
    found = set(re.findall(Constants.MissingVariablePattern, template))
    assert found == expected

View File

@@ -0,0 +1,176 @@
"""
Tests for Vertex AI cache model name extraction from LLMRouterConfig.
This tests the fix for the issue where GEMINI_3_0_FLASH_WITH_FALLBACK was
incorrectly using 'gemini-3.0-flash' instead of 'gemini-3-flash-preview'.
"""
import re
from dataclasses import dataclass
@dataclass
class MockLLMRouterModelConfig:
    # Mirrors one entry of a router's model_list: the router-internal group
    # name plus the litellm parameters (including the provider model string).
    model_name: str
    litellm_params: dict
@dataclass
class MockLLMRouterConfig:
    # Mirrors LLMRouterConfig: model_name here is only a router identifier
    # (not a real Vertex model); the primary model is located by matching
    # main_model_group against model_list entries.
    model_name: str
    model_list: list
    main_model_group: str
    required_env_vars: list | None = None  # normalized to [] in __post_init__

    def __post_init__(self):
        if self.required_env_vars is None:
            self.required_env_vars = []
@dataclass
class MockLLMConfig:
    # Mirrors a direct (non-router) LLMConfig.
    model_name: str
    required_env_vars: list | None = None  # normalized to [] in __post_init__
    litellm_params: dict | None = None

    def __post_init__(self):
        if self.required_env_vars is None:
            self.required_env_vars = []
class TestVertexCacheModelExtraction:
    """Test that model names are correctly extracted for Vertex AI caching."""

    def _extract_model_name(self, llm_config, resolved_llm_key: str) -> str:
        """
        Mimics the model name extraction logic from _create_vertex_cache_for_task.
        """
        # NOTE(review): this is a copy of the production extraction cascade —
        # keep it in sync with _create_vertex_cache_for_task or the tests lose value.
        model_name = "gemini-2.5-flash"  # Default
        extracted_name = None
        # For router configs (LLMRouterConfig), extract from model_list primary model FIRST
        # This must be checked before model_name since router model_name is just an identifier
        # (e.g., "gemini-3.0-flash-gpt-5-mini-fallback-router"), not an actual Vertex model
        if hasattr(llm_config, "model_list") and hasattr(llm_config, "main_model_group"):
            # Find the primary model in model_list by matching main_model_group
            for model_entry in llm_config.model_list:
                if model_entry.model_name == llm_config.main_model_group:
                    # Extract actual model name from litellm_params
                    model_param = model_entry.litellm_params.get("model", "")
                    if "vertex_ai/" in model_param:
                        extracted_name = model_param.split("/")[-1]
                    elif model_param.startswith("gemini-"):
                        extracted_name = model_param
                    break
        # Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-"
        if not extracted_name and hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str):
            if "vertex_ai/" in llm_config.model_name:
                # Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
                extracted_name = llm_config.model_name.split("/")[-1]
            elif llm_config.model_name.startswith("gemini-"):
                # Already in correct format
                extracted_name = llm_config.model_name
        # For router/fallback configs, extract from api_base or infer from key name
        if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params:
            params = llm_config.litellm_params
            api_base = params.get("api_base") if isinstance(params, dict) else getattr(params, "api_base", None)
            if api_base and isinstance(api_base, str) and "/models/" in api_base:
                # Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash"
                extracted_name = api_base.split("/models/")[-1]
        # For router configs without api_base, infer from the llm_key itself
        if not extracted_name:
            # Extract version from llm_key (e.g. "GEMINI_2_5_..." -> "2.5")
            version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
            version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
            # Determine flavor: pro / flash-lite / flash (default)
            if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
                extracted_name = f"gemini-{version}-pro"
            elif "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
                extracted_name = f"gemini-{version}-flash-lite"
            else:
                # Default to flash flavor
                extracted_name = f"gemini-{version}-flash"
        if extracted_name:
            model_name = extracted_name
        # Normalize model name to the canonical Vertex identifier
        # Preserve preview suffixes so we don't strip required identifiers (e.g., gemini-3-flash-preview).
        match = re.search(r"(gemini-\d+(?:\.\d+)?-(?:flash-lite|flash|pro)(?:-preview)?)", model_name, re.IGNORECASE)
        if match:
            model_name = match.group(1).lower()
        return model_name

    def test_router_config_extracts_gemini_3_flash_preview(self):
        """
        GEMINI_3_0_FLASH_WITH_FALLBACK should extract 'gemini-3-flash-preview',
        NOT 'gemini-3.0-flash'.
        """
        # Create a mock router config that matches the real GEMINI_3_0_FLASH_WITH_FALLBACK
        router_config = MockLLMRouterConfig(
            model_name="gemini-3.0-flash-gpt-5-mini-fallback-router",
            model_list=[
                MockLLMRouterModelConfig(
                    model_name="vertex-gemini-3.0-flash",
                    litellm_params={"model": "vertex_ai/gemini-3-flash-preview"},
                ),
                MockLLMRouterModelConfig(
                    model_name="gpt-5-mini-fallback",
                    litellm_params={"model": "gpt-5-mini-2025-08-07"},
                ),
            ],
            main_model_group="vertex-gemini-3.0-flash",
        )
        model_name = self._extract_model_name(router_config, "GEMINI_3_0_FLASH_WITH_FALLBACK")
        # Should extract the correct model name with -preview suffix
        assert model_name == "gemini-3-flash-preview", (
            f"Expected 'gemini-3-flash-preview' but got '{model_name}'. "
            "The router config should extract from model_list, not infer from llm_key."
        )

    def test_direct_vertex_config_extracts_correctly(self):
        """Direct VERTEX_GEMINI_3.0_FLASH should extract correctly."""
        direct_config = MockLLMConfig(
            model_name="vertex_ai/gemini-3-flash-preview",
        )
        model_name = self._extract_model_name(direct_config, "VERTEX_GEMINI_3.0_FLASH")
        assert model_name == "gemini-3-flash-preview"

    def test_router_config_extracts_gemini_2_5_flash(self):
        """GEMINI_2_5_FLASH_WITH_FALLBACK should extract 'gemini-2.5-flash'."""
        router_config = MockLLMRouterConfig(
            model_name="gemini-2.5-flash-gpt-5-mini-fallback-router",
            model_list=[
                MockLLMRouterModelConfig(
                    model_name="vertex-gemini-2.5-flash",
                    litellm_params={"model": "vertex_ai/gemini-2.5-flash"},
                ),
                MockLLMRouterModelConfig(
                    model_name="gpt-5-mini-fallback",
                    litellm_params={"model": "gpt-5-mini-2025-08-07"},
                ),
            ],
            main_model_group="vertex-gemini-2.5-flash",
        )
        model_name = self._extract_model_name(router_config, "GEMINI_2_5_FLASH_WITH_FALLBACK")
        assert model_name == "gemini-2.5-flash"

    def test_fallback_to_llm_key_inference_when_no_model_list(self):
        """When there's no model_list, should fall back to llm_key inference."""
        # A config that doesn't have model_list (not a router config)
        simple_config = MockLLMConfig(
            model_name="some-unrelated-name",
        )
        model_name = self._extract_model_name(simple_config, "GEMINI_2_5_FLASH")
        # Should fall back to inference from llm_key
        assert model_name == "gemini-2.5-flash"

View File

@@ -0,0 +1,627 @@
"""Tests for workflow parameter key and block label validation.
These tests ensure that parameter keys and block labels are valid Python/Jinja2 identifiers,
preventing runtime errors like "'State_' is undefined" when using keys like "State_/_Province".
"""
import pytest
from pydantic import ValidationError
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType
from skyvern.schemas.workflows import (
TaskBlockYAML,
WorkflowParameterYAML,
sanitize_block_label,
sanitize_parameter_key,
sanitize_workflow_yaml_with_references,
)
from skyvern.utils.templating import replace_jinja_reference
class TestParameterKeyValidation:
"""Tests for parameter key validation."""
def test_valid_parameter_key_simple(self) -> None:
"""Test that simple valid keys are accepted."""
param = WorkflowParameterYAML(
key="my_parameter",
workflow_parameter_type=WorkflowParameterType.STRING,
)
assert param.key == "my_parameter"
def test_valid_parameter_key_with_numbers(self) -> None:
"""Test that keys with numbers (not at start) are accepted."""
param = WorkflowParameterYAML(
key="param123",
workflow_parameter_type=WorkflowParameterType.STRING,
)
assert param.key == "param123"
def test_valid_parameter_key_underscore_prefix(self) -> None:
"""Test that keys starting with underscore are accepted."""
param = WorkflowParameterYAML(
key="_private_param",
workflow_parameter_type=WorkflowParameterType.STRING,
)
assert param.key == "_private_param"
def test_valid_parameter_key_single_letter(self) -> None:
"""Test that single letter keys are accepted."""
param = WorkflowParameterYAML(
key="x",
workflow_parameter_type=WorkflowParameterType.STRING,
)
assert param.key == "x"
def test_invalid_parameter_key_with_slash(self) -> None:
"""Test that keys with '/' are rejected (the main bug case from SKY-7356)."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="State_/_Province",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "not a valid parameter name" in error_msg
def test_invalid_parameter_key_with_hyphen(self) -> None:
"""Test that keys with '-' are rejected."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="state-or-province",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "not a valid parameter name" in error_msg
def test_invalid_parameter_key_with_dot(self) -> None:
"""Test that keys with '.' are rejected."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="some.property",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "not a valid parameter name" in error_msg
def test_invalid_parameter_key_starts_with_digit(self) -> None:
"""Test that keys starting with a digit are rejected."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="123param",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "not a valid parameter name" in error_msg
def test_invalid_parameter_key_with_space(self) -> None:
"""Test that keys with spaces are rejected."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="my parameter",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "whitespace" in error_msg
def test_invalid_parameter_key_with_tab(self) -> None:
"""Test that keys with tabs are rejected."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="my\tparameter",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "whitespace" in error_msg
def test_invalid_parameter_key_with_asterisk(self) -> None:
"""Test that keys with '*' are rejected."""
with pytest.raises(ValidationError) as exc_info:
WorkflowParameterYAML(
key="param*value",
workflow_parameter_type=WorkflowParameterType.STRING,
)
error_msg = str(exc_info.value)
assert "not a valid parameter name" in error_msg
class TestBlockLabelValidation:
    """Validation behavior for block labels on TaskBlockYAML."""

    def test_valid_block_label_simple(self) -> None:
        """A plain snake_case label is accepted."""
        task_block = TaskBlockYAML(label="my_task", url="https://example.com")
        assert task_block.label == "my_task"

    def test_valid_block_label_with_numbers(self) -> None:
        """Digits are fine as long as they are not the first character."""
        task_block = TaskBlockYAML(label="task123", url="https://example.com")
        assert task_block.label == "task123"

    def test_valid_block_label_underscore_prefix(self) -> None:
        """A leading underscore is a legal label start."""
        task_block = TaskBlockYAML(label="_private_task", url="https://example.com")
        assert task_block.label == "_private_task"

    def test_invalid_block_label_with_slash(self) -> None:
        """A '/' inside the label must be rejected."""
        with pytest.raises(ValidationError) as exc_info:
            TaskBlockYAML(label="task/block", url="https://example.com")
        assert "not a valid label" in str(exc_info.value)

    def test_invalid_block_label_with_hyphen(self) -> None:
        """A '-' inside the label must be rejected."""
        with pytest.raises(ValidationError) as exc_info:
            TaskBlockYAML(label="task-block", url="https://example.com")
        assert "not a valid label" in str(exc_info.value)

    def test_invalid_block_label_starts_with_digit(self) -> None:
        """A label whose first character is a digit must be rejected."""
        with pytest.raises(ValidationError) as exc_info:
            TaskBlockYAML(label="123task", url="https://example.com")
        assert "not a valid label" in str(exc_info.value)

    def test_invalid_block_label_empty(self) -> None:
        """An empty label must be rejected with an 'empty' error."""
        with pytest.raises(ValidationError) as exc_info:
            TaskBlockYAML(label="", url="https://example.com")
        assert "empty" in str(exc_info.value).lower()

    def test_invalid_block_label_whitespace_only(self) -> None:
        """A label made only of whitespace must be rejected with an 'empty' error."""
        with pytest.raises(ValidationError) as exc_info:
            TaskBlockYAML(label="   ", url="https://example.com")
        assert "empty" in str(exc_info.value).lower()

    def test_invalid_block_label_with_space(self) -> None:
        """A space inside the label must be rejected."""
        with pytest.raises(ValidationError) as exc_info:
            TaskBlockYAML(label="my task", url="https://example.com")
        assert "not a valid label" in str(exc_info.value)
class TestSanitizeBlockLabel:
    """Behavioral tests for the sanitize_block_label helper."""

    def test_sanitize_slash(self) -> None:
        """A forward slash is rewritten to an underscore."""
        cleaned = sanitize_block_label("State/Province")
        assert cleaned == "State_Province"

    def test_sanitize_hyphen(self) -> None:
        """A hyphen is rewritten to an underscore."""
        cleaned = sanitize_block_label("my-block")
        assert cleaned == "my_block"

    def test_sanitize_dot(self) -> None:
        """A dot is rewritten to an underscore."""
        cleaned = sanitize_block_label("block.name")
        assert cleaned == "block_name"

    def test_sanitize_multiple_special_chars(self) -> None:
        """Several adjacent special characters collapse into one underscore."""
        cleaned = sanitize_block_label("State_/_Province")
        assert cleaned == "State_Province"

    def test_sanitize_consecutive_underscores(self) -> None:
        """Runs of underscores collapse down to a single one."""
        cleaned = sanitize_block_label("a__b___c")
        assert cleaned == "a_b_c"

    def test_sanitize_leading_trailing_underscores(self) -> None:
        """Leading and trailing underscores are stripped."""
        cleaned = sanitize_block_label("_my_block_")
        assert cleaned == "my_block"

    def test_sanitize_digit_prefix(self) -> None:
        """A label starting with a digit gains a leading underscore."""
        cleaned = sanitize_block_label("123abc")
        assert cleaned == "_123abc"

    def test_sanitize_digit_prefix_after_strip(self) -> None:
        """The digit-prefix rule is applied after underscore stripping."""
        cleaned = sanitize_block_label("_123abc")
        assert cleaned == "_123abc"

    def test_sanitize_all_invalid_chars(self) -> None:
        """A label made only of invalid characters falls back to the default."""
        cleaned = sanitize_block_label("///")
        assert cleaned == "block"

    def test_sanitize_empty_string(self) -> None:
        """An empty label falls back to the default."""
        cleaned = sanitize_block_label("")
        assert cleaned == "block"

    def test_sanitize_valid_label_unchanged(self) -> None:
        """An already-valid label passes through untouched."""
        cleaned = sanitize_block_label("my_valid_label")
        assert cleaned == "my_valid_label"

    def test_sanitize_spaces(self) -> None:
        """Spaces are rewritten to underscores."""
        cleaned = sanitize_block_label("my block name")
        assert cleaned == "my_block_name"
class TestSanitizeParameterKey:
    """Behavioral tests for the sanitize_parameter_key helper."""

    def test_sanitize_slash(self) -> None:
        """A forward slash is rewritten to an underscore."""
        cleaned = sanitize_parameter_key("State/Province")
        assert cleaned == "State_Province"

    def test_sanitize_hyphen(self) -> None:
        """A hyphen is rewritten to an underscore."""
        cleaned = sanitize_parameter_key("my-param")
        assert cleaned == "my_param"

    def test_sanitize_dot(self) -> None:
        """A dot is rewritten to an underscore."""
        cleaned = sanitize_parameter_key("param.name")
        assert cleaned == "param_name"

    def test_sanitize_all_invalid_chars(self) -> None:
        """A key made only of invalid characters falls back to the default."""
        cleaned = sanitize_parameter_key("///")
        assert cleaned == "parameter"

    def test_sanitize_empty_string(self) -> None:
        """An empty key falls back to the default."""
        cleaned = sanitize_parameter_key("")
        assert cleaned == "parameter"

    def test_sanitize_valid_key_unchanged(self) -> None:
        """An already-valid key passes through untouched."""
        cleaned = sanitize_parameter_key("my_valid_key")
        assert cleaned == "my_valid_key"
class TestReplaceJinjaReference:
    """Behavioral tests for the replace_jinja_reference helper."""

    def test_replace_simple_reference(self) -> None:
        """A plain spaced reference is renamed."""
        template = "Value is {{ old_key }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "Value is {{ new_key }}"

    def test_replace_reference_no_spaces(self) -> None:
        """A reference without surrounding spaces is renamed too."""
        template = "Value is {{old_key}}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "Value is {{new_key}}"

    def test_replace_reference_with_attribute(self) -> None:
        """Attribute access after the key is preserved."""
        template = "Value is {{ old_key.field }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "Value is {{ new_key.field }}"

    def test_replace_reference_with_filter(self) -> None:
        """A Jinja filter pipeline after the key is preserved."""
        template = "Value is {{ old_key | default('') }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "Value is {{ new_key | default('') }}"

    def test_replace_reference_with_index(self) -> None:
        """Index access after the key is preserved."""
        template = "Value is {{ old_key[0] }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "Value is {{ new_key[0] }}"

    def test_replace_multiple_references(self) -> None:
        """Every occurrence of the key is replaced, not just the first."""
        template = "{{ old_key }} and {{ old_key.field }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "{{ new_key }} and {{ new_key.field }}"

    def test_no_replace_partial_match(self) -> None:
        """A key that merely prefixes a longer identifier is left alone."""
        template = "{{ old_key_extended }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "{{ old_key_extended }}"

    def test_no_replace_different_key(self) -> None:
        """References to unrelated keys are untouched."""
        template = "{{ other_key }}"
        updated = replace_jinja_reference(template, "old_key", "new_key")
        assert updated == "{{ other_key }}"
class TestSanitizeWorkflowYamlWithReferences:
    """Tests for the sanitize_workflow_yaml_with_references function.

    These tests cover both the in-place sanitization of block labels /
    parameter keys and the rewriting of everything that references them
    (Jinja templates, next/finally labels, parameter_keys arrays).
    """

    def test_sanitize_simple_block_label(self) -> None:
        """Test sanitizing a simple block label."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {"parameters": [], "blocks": [{"label": "State/Province", "block_type": "task"}]},
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "State_Province"

    def test_sanitize_updates_output_references(self) -> None:
        """Test that output references are updated when label is sanitized."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    {"label": "my-block", "block_type": "task"},
                    {
                        "label": "second_block",
                        "block_type": "task",
                        # References the first block's output via the `<label>_output` convention.
                        "navigation_goal": "Use {{ my-block_output }} value",
                    },
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "my_block"
        assert "{{ my_block_output }}" in result["workflow_definition"]["blocks"][1]["navigation_goal"]

    def test_sanitize_updates_next_block_label(self) -> None:
        """Test that next_block_label is updated when label is sanitized."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    {"label": "block-1", "block_type": "task", "next_block_label": "block-2"},
                    {"label": "block-2", "block_type": "task"},
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "block_1"
        # The pointer to the second block must track its sanitized label.
        assert result["workflow_definition"]["blocks"][0]["next_block_label"] == "block_2"
        assert result["workflow_definition"]["blocks"][1]["label"] == "block_2"

    def test_sanitize_updates_finally_block_label(self) -> None:
        """Test that finally_block_label is updated when referenced label is sanitized."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [{"label": "cleanup-block", "block_type": "task"}],
                "finally_block_label": "cleanup-block",
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "cleanup_block"
        assert result["workflow_definition"]["finally_block_label"] == "cleanup_block"

    def test_sanitize_nested_loop_blocks(self) -> None:
        """Test that nested blocks in for_loop are sanitized."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    {
                        "label": "my_loop",
                        "block_type": "for_loop",
                        # Child blocks live under loop_blocks and must be sanitized recursively.
                        "loop_blocks": [{"label": "inner-block", "block_type": "task"}],
                    }
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["loop_blocks"][0]["label"] == "inner_block"

    def test_sanitize_no_changes_needed(self) -> None:
        """Test that valid labels are unchanged."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {"parameters": [], "blocks": [{"label": "valid_label", "block_type": "task"}]},
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "valid_label"

    def test_sanitize_empty_workflow_definition(self) -> None:
        """Test handling of missing workflow_definition."""
        workflow_yaml = {"title": "Test Workflow"}
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        # With nothing to sanitize the input should come back unchanged.
        assert result == workflow_yaml

    def test_sanitize_updates_parameter_references(self) -> None:
        """Test that parameter references are updated."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [
                    # A context parameter sourced from a block output key that will be sanitized.
                    {"key": "my_param", "parameter_type": "context", "source_parameter_key": "block-1_output"}
                ],
                "blocks": [{"label": "block-1", "block_type": "task"}],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "block_1"
        assert result["workflow_definition"]["parameters"][0]["source_parameter_key"] == "block_1_output"

    def test_sanitize_parameter_key(self) -> None:
        """Test that parameter keys with invalid characters are sanitized."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [
                    {
                        "key": "State/Province",
                        "parameter_type": "workflow",
                    }
                ],
                "blocks": [],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["parameters"][0]["key"] == "State_Province"

    def test_sanitize_parameter_key_updates_jinja_references(self) -> None:
        """Test that Jinja references to sanitized parameter keys are updated."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [
                    {
                        "key": "user-input",
                        "parameter_type": "workflow",
                    }
                ],
                "blocks": [
                    {"label": "my_task", "block_type": "task", "navigation_goal": "Enter {{ user-input }} in the form"}
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["parameters"][0]["key"] == "user_input"
        assert "{{ user_input }}" in result["workflow_definition"]["blocks"][0]["navigation_goal"]

    def test_sanitize_parameter_key_updates_parameter_keys_array(self) -> None:
        """Test that parameter_keys arrays in blocks are updated."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [
                    {
                        "key": "my-param",
                        "parameter_type": "workflow",
                    }
                ],
                "blocks": [{"label": "my_task", "block_type": "task", "parameter_keys": ["my-param", "other_param"]}],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["parameters"][0]["key"] == "my_param"
        # Only the sanitized key changes; unrelated entries are preserved in order.
        assert result["workflow_definition"]["blocks"][0]["parameter_keys"] == ["my_param", "other_param"]

    def test_sanitize_both_labels_and_parameter_keys(self) -> None:
        """Test that both block labels and parameter keys are sanitized together."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [
                    {
                        "key": "user/input",
                        "parameter_type": "workflow",
                    }
                ],
                "blocks": [
                    {
                        "label": "task-1",
                        "block_type": "task",
                        # Mixes a parameter reference and a block-output reference in one template.
                        "navigation_goal": "Use {{ user/input }} and {{ task-1_output }}",
                    }
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["parameters"][0]["key"] == "user_input"
        assert result["workflow_definition"]["blocks"][0]["label"] == "task_1"
        nav_goal = result["workflow_definition"]["blocks"][0]["navigation_goal"]
        assert "{{ user_input }}" in nav_goal
        assert "{{ task_1_output }}" in nav_goal

    def test_sanitize_block_label_collision(self) -> None:
        """Test that block labels that sanitize to the same value get unique suffixes."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    # Both of these sanitize to "state_province".
                    {"label": "state/province", "block_type": "task"},
                    {"label": "state-province", "block_type": "task"},
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        labels = [b["label"] for b in result["workflow_definition"]["blocks"]]
        assert labels[0] == "state_province"
        assert labels[1] == "state_province_2"
        # Ensure they are unique
        assert len(set(labels)) == len(labels)

    def test_sanitize_parameter_key_collision(self) -> None:
        """Test that parameter keys that sanitize to the same value get unique suffixes."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [
                    # Both of these sanitize to "user_input".
                    {"key": "user/input", "parameter_type": "workflow"},
                    {"key": "user-input", "parameter_type": "workflow"},
                ],
                "blocks": [
                    {
                        "label": "my_task",
                        "block_type": "task",
                        "navigation_goal": "Use {{ user/input }} and {{ user-input }}",
                    }
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        keys = [p["key"] for p in result["workflow_definition"]["parameters"]]
        assert keys[0] == "user_input"
        assert keys[1] == "user_input_2"
        # Ensure references are updated correctly
        nav_goal = result["workflow_definition"]["blocks"][0]["navigation_goal"]
        assert "{{ user_input }}" in nav_goal
        assert "{{ user_input_2 }}" in nav_goal

    def test_sanitize_collision_with_existing_valid_label(self) -> None:
        """Test that sanitized labels don't collide with already-valid labels."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    {"label": "my_block", "block_type": "task"},
                    # Sanitizes to "my_block", which is already taken above.
                    {"label": "my-block", "block_type": "task"},
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        labels = [b["label"] for b in result["workflow_definition"]["blocks"]]
        assert labels[0] == "my_block"
        assert labels[1] == "my_block_2"

    def test_sanitize_shorthand_block_label_references(self) -> None:
        """Test that shorthand block label references ({{ label }} without _output) are also updated."""
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    {
                        "label": "extract/block",
                        "block_type": "extraction",
                    },
                    {
                        "label": "send_block",
                        "block_type": "send_email",
                        # Both shorthand {{ label }} and full {{ label_output }} patterns
                        "body": "Data: {{ extract/block.extracted_information }} and {{ extract/block_output.status }}",
                    },
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        assert result["workflow_definition"]["blocks"][0]["label"] == "extract_block"
        body = result["workflow_definition"]["blocks"][1]["body"]
        # Shorthand reference should be updated
        assert "{{ extract_block.extracted_information }}" in body
        # Full _output reference should also be updated
        assert "{{ extract_block_output.status }}" in body

    def test_sanitize_label_shorthand_does_not_corrupt_output_ref(self) -> None:
        """Ensure shorthand label replacement does not corrupt _output references.

        When a label like 'block-1' is sanitized to 'block_1', both the shorthand
        {{ block-1 }} and output {{ block-1_output }} patterns must be updated
        independently without the shorthand replacement corrupting the _output form.
        """
        workflow_yaml = {
            "title": "Test Workflow",
            "workflow_definition": {
                "parameters": [],
                "blocks": [
                    {
                        "label": "block-1",
                        "block_type": "task",
                        "navigation_goal": "{{ block-1 }} and {{ block-1_output }}",
                    }
                ],
            },
        }
        result = sanitize_workflow_yaml_with_references(workflow_yaml)
        goal = result["workflow_definition"]["blocks"][0]["navigation_goal"]
        assert "{{ block_1 }}" in goal
        assert "{{ block_1_output }}" in goal

View File

@@ -0,0 +1,600 @@
"""
Tests for workflow schema field name preservation (SKY-7434).
When a workflow is regenerated (e.g., after adding a new block), the LLM should
preserve field names for unchanged blocks to prevent schema mismatches with
cached block code.
"""
from __future__ import annotations
import os
from unittest.mock import AsyncMock
import pytest
from dotenv import load_dotenv
from skyvern.core.script_generations.generate_script import (
ScriptBlockSource,
_build_existing_field_assignments,
)
from skyvern.core.script_generations.generate_workflow_parameters import (
generate_workflow_parameters_schema,
)
from skyvern.forge.forge_app_initializer import start_forge_app
from tests.unit.force_stub_app import start_forge_stub_app
# Pull API keys and related settings from a local .env so the real-LLM
# integration tests below can authenticate.
load_dotenv()

# Real LLM calls are opt-in: export RUN_LLM_TESTS=1 to enable them.
SKIP_LLM_TESTS = os.getenv("RUN_LLM_TESTS", "0") != "1"
class TestBuildExistingFieldAssignments:
    """Test the helper function that builds existing field assignments from cached blocks.

    The helper maps 1-based global action indices to the field names that the
    cached (already-generated) block code uses, so a schema regeneration can
    keep those names stable.
    """

    def test_returns_empty_dict_when_no_cached_blocks(self):
        """When there are no cached blocks, should return empty dict."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
        ]
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
            ]
        }
        # No cache at all -> nothing can be preserved.
        cached_blocks: dict[str, ScriptBlockSource] = {}
        updated_block_labels: set[str] = set()
        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)
        assert result == {}

    def test_returns_empty_dict_when_all_blocks_updated(self):
        """When all blocks are in updated_block_labels, should return empty dict."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
        ]
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
            ]
        }
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="async def login_block(): ...",
                run_signature=None,
                workflow_run_id=None,
                workflow_run_block_id=None,
                input_fields=["username"],
            )
        }
        updated_block_labels = {"login_block"}  # Block is updated, should not preserve
        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)
        assert result == {}

    def test_preserves_field_names_for_unchanged_blocks(self):
        """Unchanged blocks with input_fields should have their field names preserved."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
        ]
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
                {"action_type": "input_text", "text": "pass123", "intention": "Enter password"},
            ]
        }
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="async def login_block(): ...",
                run_signature=None,
                workflow_run_id=None,
                workflow_run_block_id=None,
                # Cached field names, in the same order as the block's input actions.
                input_fields=["user_full_name", "user_password"],
            )
        }
        updated_block_labels: set[str] = set()  # No blocks updated
        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)
        # Action 1 -> user_full_name, Action 2 -> user_password
        assert result == {1: "user_full_name", 2: "user_password"}

    def test_preserves_fields_for_multiple_unchanged_blocks(self):
        """Multiple unchanged blocks should each have their fields preserved."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
            {"block_type": "task", "label": "form_block", "task_id": "task_2"},
        ]
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
            ],
            "task_2": [
                {"action_type": "input_text", "text": "Acme Inc", "intention": "Enter company"},
            ],
        }
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="...",
                run_signature=None,
                workflow_run_id=None,
                workflow_run_block_id=None,
                input_fields=["username"],
            ),
            "form_block": ScriptBlockSource(
                label="form_block",
                code="...",
                run_signature=None,
                workflow_run_id=None,
                workflow_run_block_id=None,
                input_fields=["company_name"],
            ),
        }
        updated_block_labels: set[str] = set()
        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)
        # Action 1 (task_1) -> username, Action 2 (task_2) -> company_name
        # (the action index keeps counting across tasks).
        assert result == {1: "username", 2: "company_name"}

    def test_mixed_updated_and_unchanged_blocks(self):
        """Only unchanged blocks should have their fields preserved."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
            {"block_type": "task", "label": "new_block", "task_id": "task_2"},
        ]
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
            ],
            "task_2": [
                {"action_type": "input_text", "text": "new value", "intention": "Enter something"},
            ],
        }
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="...",
                run_signature=None,
                workflow_run_id=None,
                workflow_run_block_id=None,
                input_fields=["username"],
            ),
            # new_block is not in cached_blocks (it's new)
        }
        updated_block_labels: set[str] = set()
        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)
        # Only action 1 should be preserved, action 2 is from a new block
        assert result == {1: "username"}

    def test_skips_non_custom_field_actions(self):
        """Actions that aren't INPUT_TEXT, UPLOAD_FILE, or SELECT_OPTION should be skipped."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
        ]
        actions_by_task = {
            "task_1": [
                {"action_type": "click", "intention": "Click button"},  # Not a custom field action
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
            ]
        }
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="...",
                run_signature=None,
                workflow_run_id=None,
                workflow_run_block_id=None,
                input_fields=["username"],  # Only one input field
            )
        }
        updated_block_labels: set[str] = set()
        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)
        # The click action is skipped, so input_text is action 1
        assert result == {1: "username"}
class TestGenerateWorkflowParametersSchemaWithExistingFields:
    """Test that the LLM receives existing field names when generating schema.

    The LLM handler is replaced with an AsyncMock so we can both capture the
    prompt text and return a deterministic schema response.
    """

    @pytest.fixture(autouse=True)
    def setup_stub_app(self):
        """Set up stub app for all tests in this class."""
        # Stub forge app so generate_workflow_parameters_schema has an
        # LLM handler attribute we can override per test.
        self.stub_app = start_forge_stub_app()

    @pytest.mark.asyncio
    async def test_llm_receives_existing_field_names_in_prompt(self):
        """The LLM should receive existing field names to preserve in the prompt."""
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username", "action_id": "act_1"},
                {"action_type": "input_text", "text": "pass", "intention": "Enter password", "action_id": "act_2"},
            ],
            "task_2": [
                {"action_type": "input_text", "text": "new", "intention": "Enter new field", "action_id": "act_3"},
            ],
        }
        existing_field_assignments = {
            1: "preserved_username",
            2: "preserved_password",
            # Action 3 has no existing field - needs new name
        }
        # Mock the LLM response
        mock_llm_response = {
            "field_mappings": {
                "action_index_1": "preserved_username",
                "action_index_2": "preserved_password",
                "action_index_3": "new_field_name",
            },
            "schema_fields": {
                "preserved_username": {"type": "str", "description": "Username"},
                "preserved_password": {"type": "str", "description": "Password"},
                "new_field_name": {"type": "str", "description": "New field"},
            },
        }
        # Capture what the production code actually sends to the LLM.
        captured_prompt = {}

        async def mock_llm_handler(prompt, prompt_name):
            captured_prompt["prompt"] = prompt
            captured_prompt["prompt_name"] = prompt_name
            return mock_llm_response

        self.stub_app.SCRIPT_GENERATION_LLM_API_HANDLER = AsyncMock(side_effect=mock_llm_handler)
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )
        # Verify the prompt contains the existing field names
        prompt = captured_prompt["prompt"]
        assert "preserved_username" in prompt
        assert "preserved_password" in prompt
        assert "MUST PRESERVE" in prompt or "EXISTING FIELD NAME" in prompt
        # Verify the returned field mappings include preserved names
        assert field_mappings["task_1:act_1"] == "preserved_username"
        assert field_mappings["task_1:act_2"] == "preserved_password"
        assert field_mappings["task_2:act_3"] == "new_field_name"

    @pytest.mark.asyncio
    async def test_no_existing_fields_works_normally(self):
        """When there are no existing fields, schema generation should work normally."""
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username", "action_id": "act_1"},
            ],
        }
        existing_field_assignments: dict[int, str] = {}  # No existing fields
        mock_llm_response = {
            "field_mappings": {
                "action_index_1": "username",
            },
            "schema_fields": {
                "username": {"type": "str", "description": "Username field"},
            },
        }
        captured_prompt = {}

        async def mock_llm_handler(prompt, prompt_name):
            captured_prompt["prompt"] = prompt
            return mock_llm_response

        self.stub_app.SCRIPT_GENERATION_LLM_API_HANDLER = AsyncMock(side_effect=mock_llm_handler)
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )
        # Should not contain preservation instructions when no existing fields
        prompt = captured_prompt["prompt"]
        # The CRITICAL rule only appears when has_existing_fields is True
        assert "CRITICAL" not in prompt
        # Should still return valid mappings
        assert field_mappings["task_1:act_1"] == "username"

    @pytest.mark.asyncio
    async def test_schema_code_includes_preserved_field_names(self):
        """The generated schema code should include the preserved field names."""
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username", "action_id": "act_1"},
            ],
        }
        existing_field_assignments = {1: "user_full_name"}
        mock_llm_response = {
            "field_mappings": {
                "action_index_1": "user_full_name",
            },
            "schema_fields": {
                "user_full_name": {"type": "str", "description": "The user's full name"},
            },
        }

        async def mock_llm_handler(prompt, prompt_name):
            return mock_llm_response

        self.stub_app.SCRIPT_GENERATION_LLM_API_HANDLER = AsyncMock(side_effect=mock_llm_handler)
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )
        # Schema code should include the preserved field name
        assert "user_full_name" in schema_code
        assert "str" in schema_code
class TestEndToEndFieldPreservation:
    """
    End-to-end test simulating the real scenario:
    1. Workflow has a login block with cached code using field names
    2. User adds a new block
    3. Schema is regenerated
    4. Login block's field names should be preserved
    """

    @pytest.fixture(autouse=True)
    def setup_stub_app(self):
        """Set up stub app for all tests in this class."""
        self.stub_app = start_forge_stub_app()

    @pytest.mark.asyncio
    async def test_adding_new_block_preserves_existing_block_field_names(self):
        """
        Simulates: User has workflow with login block, adds a new block.
        The login block's field names should be preserved in the regenerated schema.
        """
        # Existing blocks (login was already there)
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
            {"block_type": "task", "label": "new_block", "task_id": "task_2"},  # Newly added
        ]
        # Actions from both blocks
        actions_by_task = {
            "task_1": [
                {
                    "action_type": "input_text",
                    "text": "john@example.com",
                    "intention": "Enter email",
                    "action_id": "act_1",
                },
                {"action_type": "input_text", "text": "secret123", "intention": "Enter password", "action_id": "act_2"},
            ],
            "task_2": [
                {
                    "action_type": "input_text",
                    "text": "Acme Inc",
                    "intention": "Enter company name",
                    "action_id": "act_3",
                },
            ],
        }
        # Cached blocks - login_block has existing field names
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="""
@skyvern.cached(cache_key='login_block')
async def login_block(page: SkyvernPage, context: RunContext):
    await page.fill(
        selector='xpath=//input[@id="email"]',
        value=context.parameters['user_email'],
    )
    await page.fill(
        selector='xpath=//input[@id="password"]',
        value=context.parameters['user_password'],
    )
""",
                run_signature="await skyvern.login(...)",
                workflow_run_id="wr_123",
                workflow_run_block_id="wrb_123",
                input_fields=["user_email", "user_password"],  # These must be preserved!
            ),
            # new_block is not in cached_blocks - it's brand new
        }
        # Only the new block is "updated" (actually new)
        updated_block_labels: set[str] = set()  # login_block is NOT updated
        # Step 1: Build existing field assignments
        existing_field_assignments = _build_existing_field_assignments(
            blocks, actions_by_task, cached_blocks, updated_block_labels
        )
        # Verify login block fields are identified for preservation
        assert existing_field_assignments == {
            1: "user_email",
            2: "user_password",
            # Action 3 has no existing field (new block)
        }
        # Step 2: Mock LLM that respects the preservation instructions
        mock_llm_response = {
            "field_mappings": {
                "action_index_1": "user_email",  # Preserved
                "action_index_2": "user_password",  # Preserved
                "action_index_3": "company_name",  # New field for new block
            },
            "schema_fields": {
                "user_email": {"type": "str", "description": "User's email address"},
                "user_password": {"type": "str", "description": "User's password"},
                "company_name": {"type": "str", "description": "Company name"},
            },
        }
        captured_prompt = {}

        async def mock_llm_handler(prompt, prompt_name):
            captured_prompt["prompt"] = prompt
            return mock_llm_response

        self.stub_app.SCRIPT_GENERATION_LLM_API_HANDLER = AsyncMock(side_effect=mock_llm_handler)
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )
        # Verify the prompt contains preservation instructions
        prompt = captured_prompt["prompt"]
        assert "user_email" in prompt, "Prompt should contain existing field name 'user_email'"
        assert "user_password" in prompt, "Prompt should contain existing field name 'user_password'"
        assert "MUST PRESERVE" in prompt or "EXISTING FIELD NAME" in prompt
        # Verify field mappings preserve the original names
        assert field_mappings["task_1:act_1"] == "user_email", "Login block email field should be preserved"
        assert field_mappings["task_1:act_2"] == "user_password", "Login block password field should be preserved"
        assert field_mappings["task_2:act_3"] == "company_name", "New block should get new field name"
        # Verify schema code contains preserved field names
        assert "user_email" in schema_code
        assert "user_password" in schema_code
        assert "company_name" in schema_code
        # The cached login block code references context.parameters['user_email']
        # and context.parameters['user_password'], which now match the schema!
        cached_code = cached_blocks["login_block"].code
        assert "user_email" in cached_code
        assert "user_password" in cached_code
@pytest.mark.skipif(SKIP_LLM_TESTS, reason="Real LLM test - set RUN_LLM_TESTS=1 to enable")
class TestRealLLMFieldPreservation:
    """
    Integration tests that exercise a real LLM call to verify field preservation.

    Environment requirements (typically supplied via a .env file):
    - SCRIPT_GENERATION_LLM_KEY or SECONDARY_LLM_KEY
    - API keys appropriate for the configured LLM provider

    Run with:
        RUN_LLM_TESTS=1 pytest tests/unit/test_workflow_schema_field_preservation.py::TestRealLLMFieldPreservation -v -s

    Skipped by default because real LLM calls cost money.
    """

    @pytest.fixture(scope="class", autouse=True)
    def setup_real_app(self):
        """Bring up the real Forge app so the LLM handlers are wired in."""
        start_forge_app()
        yield

    @pytest.mark.asyncio
    async def test_llm_preserves_existing_field_names(self):
        """
        A real LLM must echo back the field names marked for preservation.

        The prompt carries pre-assigned field names flagged as "MUST PRESERVE";
        the response must reuse those exact names while inventing a fresh one
        for the action that has no assignment.
        """

        def _input_action(text: str, intention: str, action_id: str) -> dict:
            # Every action in this scenario is a text input.
            return {
                "action_type": "input_text",
                "text": text,
                "intention": intention,
                "action_id": action_id,
            }

        actions_by_task = {
            "task_1": [
                _input_action("john.doe@example.com", "Enter the user's email address for login", "act_1"),
                _input_action("secretpassword123", "Enter the user's password", "act_2"),
            ],
            "task_2": [
                _input_action("Acme Corporation", "Enter the company name", "act_3"),
            ],
        }
        # Deliberately unusual names so an accidental LLM match is implausible.
        # Action 3 has no assignment on purpose: the LLM must generate a name.
        existing_field_assignments = {
            1: "preserved_login_email_xyz",
            2: "preserved_login_password_abc",
        }
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )
        # The pre-assigned names must come back verbatim.
        assert field_mappings["task_1:act_1"] == "preserved_login_email_xyz", (
            f"LLM should have preserved 'preserved_login_email_xyz' but got '{field_mappings.get('task_1:act_1')}'"
        )
        assert field_mappings["task_1:act_2"] == "preserved_login_password_abc", (
            f"LLM should have preserved 'preserved_login_password_abc' but got '{field_mappings.get('task_1:act_2')}'"
        )
        # The unassigned action must receive a brand-new field name.
        action_3_field = field_mappings.get("task_2:act_3")
        assert action_3_field is not None, "LLM should have generated a field name for action 3"
        assert action_3_field not in ["preserved_login_email_xyz", "preserved_login_password_abc"], (
            f"Action 3 should have a new field name, not a preserved one. Got: {action_3_field}"
        )
        # Every mapped field must also appear in the generated schema code.
        assert "preserved_login_email_xyz" in schema_code, "Schema should contain preserved email field"
        assert "preserved_login_password_abc" in schema_code, "Schema should contain preserved password field"
        assert action_3_field in schema_code, f"Schema should contain new field '{action_3_field}'"
        print("\n✅ LLM preserved field names correctly!")
        print(" - Action 1: preserved_login_email_xyz ✓")
        print(" - Action 2: preserved_login_password_abc ✓")
        print(f" - Action 3: {action_3_field} (newly generated) ✓")

    @pytest.mark.asyncio
    async def test_llm_generates_all_new_names_when_no_existing_fields(self):
        """
        Baseline: with no pre-assigned fields the LLM invents sensible names.

        Confirms the plain LLM call path works end to end.
        """
        actions_by_task = {
            "task_1": [
                {
                    "action_type": "input_text",
                    "text": "test@example.com",
                    "intention": "Enter email address",
                    "action_id": "act_1",
                },
            ],
        }
        # No existing assignments at all.
        existing_field_assignments: dict[int, str] = {}
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )
        assert "task_1:act_1" in field_mappings, "Should have a field mapping for the action"
        field_name = field_mappings["task_1:act_1"]
        assert field_name, "Field name should not be empty"
        assert field_name in schema_code, f"Schema should contain the generated field name '{field_name}'"
        print(f"\n✅ LLM generated new field name: {field_name}")

View File

@@ -0,0 +1,152 @@
"""
Tests for workflow cache invalidation logic (SKY-7016).
Verifies that changes to the model field (both at workflow settings level and block level)
do not trigger cache invalidation.
"""
from datetime import datetime, timezone
from skyvern.forge.sdk.workflow.models.block import BlockType, TaskBlock
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition
from skyvern.forge.sdk.workflow.service import _get_workflow_definition_core_data
def make_output_parameter(key: str) -> OutputParameter:
    """Create a test output parameter."""
    fields = {
        "parameter_type": ParameterType.OUTPUT,
        "key": key,
        "description": "Test output parameter",
        "output_parameter_id": "test-output-id",
        "workflow_id": "test-workflow-id",
        "created_at": datetime.now(timezone.utc),
        "modified_at": datetime.now(timezone.utc),
    }
    return OutputParameter(**fields)
def make_task_block(label: str, model: dict | None = None) -> TaskBlock:
    """Create a test task block with optional model configuration."""
    return TaskBlock(
        block_type=BlockType.TASK,
        label=label,
        title="Test Task",
        url="https://example.com",
        navigation_goal="Complete the task",
        model=model,
        output_parameter=make_output_parameter(f"{label}_output"),
    )
class TestCacheInvalidation:
    """Tests for the _get_workflow_definition_core_data function."""

    def test_model_field_excluded_from_block_comparison(self) -> None:
        """
        SKY-7016: a block-level model change must not invalidate the cache,
        so the model field has to be stripped from the comparison payload.
        """
        # Two blocks that differ only in their model configuration.
        plain = make_task_block("task1", model=None)
        configured = make_task_block("task1", model={"model_name": "gpt-4o"})
        core_plain = _get_workflow_definition_core_data(
            WorkflowDefinition(parameters=[], blocks=[plain])
        )
        core_configured = _get_workflow_definition_core_data(
            WorkflowDefinition(parameters=[], blocks=[configured])
        )
        assert core_plain == core_configured, (
            "Model field should be excluded from comparison. "
            "Changing block-level model should not trigger cache invalidation."
        )

    def test_model_field_not_in_core_data(self) -> None:
        """The model key must be absent from every block in the core data."""
        definition = WorkflowDefinition(
            parameters=[],
            blocks=[make_task_block("task1", model={"model_name": "claude-3-sonnet"})],
        )
        core_data = _get_workflow_definition_core_data(definition)
        for block_data in core_data.get("blocks", []):
            assert "model" not in block_data, "Model field should be removed from block data"

    def test_other_block_changes_still_detected(self) -> None:
        """Non-model edits (here: navigation_goal) must still change the core data."""
        block_a = make_task_block("task1")
        block_a.navigation_goal = "Goal A"
        block_b = make_task_block("task1")
        block_b.navigation_goal = "Goal B"
        core_a = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[block_a]))
        core_b = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[block_b]))
        assert core_a != core_b, "Non-model changes should still be detected for cache invalidation"

    def test_different_models_same_core_data(self) -> None:
        """Switching between arbitrary model configs yields identical core data."""
        models = [
            None,
            {"model_name": "gpt-4o"},
            {"model_name": "claude-3-opus"},
            {"model_name": "gemini-pro", "extra_param": "value"},
        ]
        core_payloads = [
            _get_workflow_definition_core_data(
                WorkflowDefinition(parameters=[], blocks=[make_task_block("task1", model=model)])
            )
            for model in models
        ]
        baseline = core_payloads[0]
        for i, payload in enumerate(core_payloads[1:], start=1):
            assert baseline == payload, (
                f"Core data should be identical regardless of model. Definition 0 vs {i} differ."
            )

    def test_timestamps_excluded_from_comparison(self) -> None:
        """IDs and timestamps on output parameters must not affect the core data."""
        block1 = make_task_block("task1")
        block2 = make_task_block("task1")
        # Same logical parameter, but with different IDs and timestamps.
        block2.output_parameter = OutputParameter(
            parameter_type=ParameterType.OUTPUT,
            key="task1_output",
            description="Test output parameter",
            output_parameter_id="different-output-id",
            workflow_id="different-workflow-id",
            created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
            modified_at=datetime(2024, 6, 1, tzinfo=timezone.utc),
        )
        core1 = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[block1]))
        core2 = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[block2]))
        assert core1 == core2, "Timestamps and IDs should be excluded from comparison"

View File

@@ -0,0 +1,232 @@
"""
Tests for continue_on_failure behavior with caching.
Verifies that:
1. When a block with continue_on_failure=True fails, it's not cached (existing behavior)
2. When a cached block with continue_on_failure=True fails during cached execution,
it's marked for regeneration so the next run uses AI execution
"""
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
from skyvern.forge.sdk.workflow.models.block import (
BlockResult,
BlockType,
NavigationBlock,
)
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
from skyvern.schemas.workflows import BlockStatus
def _output_parameter(key: str) -> OutputParameter:
    """Build a minimal output parameter keyed off *key*."""
    moment = datetime.now(UTC)
    return OutputParameter(
        output_parameter_id=f"{key}_id",
        key=key,
        workflow_id="wf",
        created_at=moment,
        modified_at=moment,
    )
def _navigation_block(
    label: str,
    continue_on_failure: bool = False,
    next_block_label: str | None = None,
) -> NavigationBlock:
    """Build a navigation block; continue_on_failure defaults to False."""
    return NavigationBlock(
        label=label,
        title=label,
        url="https://example.com",
        navigation_goal="goal",
        continue_on_failure=continue_on_failure,
        next_block_label=next_block_label,
        output_parameter=_output_parameter(f"{label}_output"),
    )
class TestContinueOnFailureWithCache:
    """Tests for cache invalidation when continue_on_failure blocks fail.

    The invalidation predicate and the BlockResult construction were previously
    copy-pasted into every test; they are factored into `_should_invalidate` and
    `_block_result` so each test states only its input combination.
    """

    @staticmethod
    def _should_invalidate(block: NavigationBlock, result: BlockResult, script_blocks_by_label: dict) -> bool:
        """Mirror of the cache-invalidation condition from service.py.

        True when a labeled, cached, cacheable block with continue_on_failure
        finished in any non-completed status.
        """
        return bool(
            block.label
            and block.continue_on_failure
            and result.status != BlockStatus.completed
            and block.block_type in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
            and block.label in script_blocks_by_label
        )

    @staticmethod
    def _block_result(block: NavigationBlock, *, status: BlockStatus, failure_reason: str | None) -> BlockResult:
        """Build a BlockResult for *block* with the given terminal status."""
        success = status == BlockStatus.completed
        return BlockResult(
            success=success,
            failure_reason=failure_reason,
            output_parameter=block.output_parameter,
            output_parameter_value={"result": "success"} if success else None,
            status=status,
            workflow_run_block_id="wrb-1",
        )

    def test_navigation_block_is_cacheable(self) -> None:
        """Verify NavigationBlock is in the cacheable block types."""
        assert BlockType.NAVIGATION in BLOCK_TYPES_THAT_SHOULD_BE_CACHED

    def test_failed_block_without_continue_on_failure_not_added_to_update(self) -> None:
        """
        A cached block that fails with continue_on_failure=False must not be
        invalidated (the workflow would stop instead of continuing).
        """
        block = _navigation_block("nav1", continue_on_failure=False)
        cached = {"nav1": MagicMock()}  # block IS cached
        result = self._block_result(block, status=BlockStatus.failed, failure_reason="Block failed")
        assert not self._should_invalidate(block, result, cached)

    def test_failed_block_with_continue_on_failure_and_cached_added_to_update(self) -> None:
        """
        A cached block with continue_on_failure=True that fails must be marked
        for regeneration.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        cached = {"nav1": MagicMock()}  # block IS cached
        result = self._block_result(block, status=BlockStatus.failed, failure_reason="Block failed")
        assert self._should_invalidate(block, result, cached)

    def test_failed_uncached_block_with_continue_on_failure_not_added_to_update(self) -> None:
        """
        An uncached block with continue_on_failure=True that fails has nothing
        to invalidate.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        result = self._block_result(block, status=BlockStatus.failed, failure_reason="Block failed")
        # Empty mapping: the block is NOT cached.
        assert not self._should_invalidate(block, result, {})

    def test_successful_block_with_continue_on_failure_not_added_to_update_for_invalidation(self) -> None:
        """
        A successful cached block with continue_on_failure=True must not be
        invalidated.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        cached = {"nav1": MagicMock()}  # block IS cached
        result = self._block_result(block, status=BlockStatus.completed, failure_reason=None)
        assert not self._should_invalidate(block, result, cached)

    @pytest.mark.parametrize(
        "status",
        [BlockStatus.failed, BlockStatus.terminated, BlockStatus.timed_out],
    )
    def test_all_failure_statuses_trigger_cache_invalidation(self, status: BlockStatus) -> None:
        """
        Every non-completed status (failed, terminated, timed_out) triggers
        cache invalidation when continue_on_failure=True.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        cached = {"nav1": MagicMock()}  # block IS cached
        result = self._block_result(block, status=status, failure_reason=f"Block {status.value}")
        assert self._should_invalidate(block, result, cached), f"Status {status} should trigger cache invalidation"

View File

@@ -0,0 +1,249 @@
from datetime import UTC, datetime
from unittest.mock import AsyncMock
import pytest
from skyvern.forge import app
from skyvern.forge.sdk.workflow.exceptions import InvalidWorkflowDefinition
from skyvern.forge.sdk.workflow.models.block import (
BranchCondition,
ConditionalBlock,
ExtractionBlock,
JinjaBranchCriteria,
NavigationBlock,
PromptBranchCriteria,
)
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
from skyvern.forge.sdk.workflow.service import WorkflowService
from skyvern.schemas.workflows import BlockStatus
def _output_parameter(key: str) -> OutputParameter:
    """Build a minimal output parameter for graph/conditional-block tests."""
    timestamp = datetime.now(UTC)
    return OutputParameter(
        output_parameter_id=f"{key}_id",
        key=key,
        workflow_id="wf",
        created_at=timestamp,
        modified_at=timestamp,
    )
def _navigation_block(label: str, next_block_label: str | None = None) -> NavigationBlock:
    """Build a navigation block whose title mirrors its label."""
    return NavigationBlock(
        label=label,
        title=label,
        url="https://example.com",
        navigation_goal="goal",
        next_block_label=next_block_label,
        output_parameter=_output_parameter(f"{label}_output"),
    )
def _extraction_block(label: str, next_block_label: str | None = None) -> ExtractionBlock:
    """Build an extraction block whose title mirrors its label."""
    return ExtractionBlock(
        label=label,
        title=label,
        url="https://example.com",
        data_extraction_goal="extract data",
        next_block_label=next_block_label,
        output_parameter=_output_parameter(f"{label}_output"),
    )
def _conditional_block(
    label: str, branch_conditions: list[BranchCondition], next_block_label: str | None = None
) -> ConditionalBlock:
    """Build a conditional block with the supplied branch conditions."""
    return ConditionalBlock(
        label=label,
        branch_conditions=branch_conditions,
        next_block_label=next_block_label,
        output_parameter=_output_parameter(f"{label}_output"),
    )
class DummyContext:
    """Minimal stand-in for a workflow run context used by block.execute()."""

    def __init__(self, workflow_run_id: str) -> None:
        self.blocks_metadata: dict[str, dict] = {}
        self.values: dict[str, object] = {}
        self.secrets: dict[str, object] = {}
        self.parameters: dict[str, object] = {}
        self.workflow_run_outputs: dict[str, object] = {}
        self.include_secrets_in_templates = False
        self.workflow_title = "test"
        self.workflow_id = "wf"
        self.workflow_permanent_id = "wf-perm"
        self.workflow_run_id = workflow_run_id

    def update_block_metadata(self, label: str, metadata: dict) -> None:
        # Last write for a label wins.
        self.blocks_metadata[label] = metadata

    def get_block_metadata(self, label: str | None) -> dict:
        return {} if label is None else self.blocks_metadata.get(label, {})

    def mask_secrets_in_data(self, data: object) -> object:
        """Mock method - returns data as-is since no secrets in tests."""
        return data

    async def register_output_parameter_value_post_execution(self, parameter: OutputParameter, value: object) -> None:  # noqa: ARG002
        return None

    def build_workflow_run_summary(self) -> dict:
        return {}
def test_build_workflow_graph_infers_default_edges() -> None:
    """Blocks without explicit next labels are chained in array order, ending at None."""
    service = WorkflowService()
    blocks = [_navigation_block("first"), _navigation_block("second")]
    start_label, label_to_block, default_next_map = service._build_workflow_graph(blocks)
    assert start_label == "first"
    assert set(label_to_block) == {"first", "second"}
    assert default_next_map["first"] == "second"
    assert default_next_map["second"] is None
def test_build_workflow_graph_rejects_cycles() -> None:
    """A two-block loop must be rejected as an invalid workflow definition."""
    service = WorkflowService()
    looping_blocks = [
        _navigation_block("first", next_block_label="second"),
        _navigation_block("second", next_block_label="first"),
    ]
    with pytest.raises(InvalidWorkflowDefinition):
        service._build_workflow_graph(looping_blocks)
def test_build_workflow_graph_requires_single_root() -> None:
    """A graph with more than one entry point must be rejected."""
    service = WorkflowService()
    blocks = [
        _navigation_block("first"),
        _navigation_block("second"),
        _navigation_block("third", next_block_label="second"),
    ]
    with pytest.raises(InvalidWorkflowDefinition):
        service._build_workflow_graph(blocks)
def test_build_workflow_graph_conditional_blocks_no_sequential_defaulting() -> None:
    """
    Workflows containing conditional blocks must not apply sequential defaulting.

    If a terminal block appears before the branch targets in the blocks array,
    sequential defaulting would rewire the terminal block to the next array
    entry and create a spurious cycle. Execution order here is:
    start -> extract -> conditional -> (branch_a OR branch_b) -> terminal,
    while the array order is [start, extract, conditional, terminal, branch_a, branch_b].
    """
    service = WorkflowService()
    branch_conditions = [
        BranchCondition(
            criteria=JinjaBranchCriteria(expression="{{ true }}"), next_block_label="branch_a", is_default=False
        ),
        BranchCondition(criteria=None, next_block_label="branch_b", is_default=True),
    ]
    blocks = [
        _navigation_block("start", next_block_label="extract"),
        _extraction_block("extract", next_block_label="conditional"),
        _conditional_block(
            "conditional",
            branch_conditions=branch_conditions,
            next_block_label="terminal",  # should be ignored for conditional blocks
        ),
        _extraction_block("terminal", next_block_label=None),  # explicit terminal block
        _navigation_block("branch_a", next_block_label="terminal"),
        _navigation_block("branch_b", next_block_label="terminal"),
    ]
    # Must build successfully, i.e. without detecting a (spurious) cycle.
    start_label, label_to_block, default_next_map = service._build_workflow_graph(blocks)
    assert start_label == "start"
    assert set(label_to_block) == {"start", "extract", "conditional", "terminal", "branch_a", "branch_b"}
    # Sequential defaulting must NOT have been applied: terminal stays None
    # instead of being defaulted to branch_a.
    assert default_next_map["terminal"] is None
    assert default_next_map["branch_a"] == "terminal"
    assert default_next_map["branch_b"] == "terminal"
@pytest.mark.asyncio
async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
    """Executing a conditional block records the taken branch in metadata and the DB row."""
    block = ConditionalBlock(
        label="cond",
        output_parameter=_output_parameter("conditional_output"),
        branch_conditions=[
            BranchCondition(criteria=JinjaBranchCriteria(expression="{{ flag }}"), next_block_label="next"),
            BranchCondition(is_default=True, next_block_label=None),
        ],
    )
    ctx = DummyContext(workflow_run_id="run-1")
    ctx.values["flag"] = True
    monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: ctx)
    app.DATABASE.update_workflow_run_block.reset_mock()
    app.DATABASE.create_or_update_workflow_run_output_parameter.reset_mock()

    result = await block.execute(
        workflow_run_id="run-1",
        workflow_run_block_id="wrb-1",
        organization_id="org-1",
    )

    assert result.status == BlockStatus.completed
    metadata = result.output_parameter_value
    assert metadata["branch_taken"] == "next"
    assert metadata["next_block_label"] == "next"
    assert ctx.blocks_metadata["cond"]["branch_taken"] == "next"
    # Inspect what was persisted to the workflow-run-block row.
    kwargs = app.DATABASE.update_workflow_run_block.call_args.kwargs
    assert kwargs["workflow_run_block_id"] == "wrb-1"
    assert kwargs["output"] == metadata
    assert kwargs["status"] == BlockStatus.completed
    assert kwargs["failure_reason"] is None
    assert kwargs["organization_id"] == "org-1"
    # Branch-execution tracking fields.
    assert kwargs["executed_branch_expression"] == "{{ flag }}"
    assert kwargs["executed_branch_result"] is True
    assert kwargs["executed_branch_next_block"] == "next"
    # executed_branch_id should be a UUID string
    assert isinstance(kwargs["executed_branch_id"], str)
@pytest.mark.asyncio
async def test_prompt_branch_uses_batched_evaluation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Prompt-based criteria are resolved through the batched helper exactly once."""
    block = ConditionalBlock(
        label="cond_prompt",
        output_parameter=_output_parameter("conditional_output_prompt"),
        branch_conditions=[
            BranchCondition(criteria=PromptBranchCriteria(expression="Check if urgent"), next_block_label="next"),
            BranchCondition(is_default=True, next_block_label=None),
        ],
    )
    ctx = DummyContext(workflow_run_id="run-2")
    monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: ctx)
    # Return tuple shape: (results, rendered_expressions, extraction_goal, llm_response)
    prompt_eval_mock = AsyncMock(return_value=([True], ["Check if urgent"], "test prompt", None))
    monkeypatch.setattr(ConditionalBlock, "_evaluate_prompt_branches", prompt_eval_mock)

    result = await block.execute(
        workflow_run_id="run-2",
        workflow_run_block_id="wrb-2",
        organization_id="org-2",
    )

    assert result.status == BlockStatus.completed
    metadata = result.output_parameter_value
    assert metadata["branch_taken"] == "next"
    assert metadata["criteria_type"] == "prompt"
    prompt_eval_mock.assert_awaited_once()

View File

@@ -0,0 +1,232 @@
"""
Tests for FileParserBlock DOCX support.
Covers file type detection, validation, text extraction (paragraphs + tables),
token truncation, and error handling for DOCX files.
"""
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
import docx
import pytest
from skyvern.forge.sdk.workflow.exceptions import InvalidFileType
from skyvern.forge.sdk.workflow.models.block import BlockType, FileParserBlock
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
from skyvern.schemas.workflows import FileType
def _make_output_parameter(key: str) -> OutputParameter:
    """Build a throwaway output parameter for a file-parser block."""
    timestamp = datetime.now(timezone.utc)
    return OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key=key,
        description="test",
        output_parameter_id="test-output-id",
        workflow_id="test-workflow-id",
        created_at=timestamp,
        modified_at=timestamp,
    )
def _make_file_parser_block(file_url: str, file_type: FileType) -> FileParserBlock:
    """Build a FileParserBlock pointed at *file_url* with the given type."""
    return FileParserBlock(
        label="test_file_parser",
        block_type=BlockType.FILE_URL_PARSER,
        file_url=file_url,
        file_type=file_type,
        output_parameter=_make_output_parameter("test_output"),
    )
def _create_docx(
    path: Path,
    paragraphs: list[str] | None = None,
    table_rows: list[list[str]] | None = None,
) -> Path:
    """Create a DOCX file with optional paragraphs and tables."""
    document = docx.Document()
    for text in paragraphs or []:
        document.add_paragraph(text)
    if table_rows:
        # All rows share the column count of the first row.
        table = document.add_table(rows=len(table_rows), cols=len(table_rows[0]))
        for row, row_data in zip(table.rows, table_rows):
            for cell, cell_text in zip(row.cells, row_data):
                cell.text = cell_text
    document.save(str(path))
    return path
class TestDetectFileTypeFromUrl:
    """Tests for _detect_file_type_from_url with DOCX extensions."""

    def _detect(self, url: str) -> FileType:
        # file_type=CSV is a placeholder; detection only inspects the URL.
        return _make_file_parser_block(url, FileType.CSV)._detect_file_type_from_url(url)

    def test_docx_extension(self) -> None:
        assert self._detect("https://example.com/file.docx") == FileType.DOCX

    def test_doc_extension_raises_error(self) -> None:
        # Legacy .doc (Word 97-2003) is not supported by python-docx
        with pytest.raises(InvalidFileType, match="Legacy .doc format"):
            self._detect("https://example.com/file.doc")

    def test_docx_with_query_params(self) -> None:
        assert self._detect("https://example.com/file.docx?token=abc&v=1") == FileType.DOCX

    def test_docx_case_insensitive(self) -> None:
        assert self._detect("https://example.com/file.DOCX") == FileType.DOCX

    def test_other_extensions_unchanged(self) -> None:
        expected = {
            "https://example.com/file.pdf": FileType.PDF,
            "https://example.com/file.xlsx": FileType.EXCEL,
            "https://example.com/file.csv": FileType.CSV,
            "https://example.com/file.png": FileType.IMAGE,
        }
        for url, file_type in expected.items():
            assert self._detect(url) == file_type
class TestValidateFileType:
    """Tests for validate_file_type with DOCX files."""

    def test_valid_docx(self, tmp_path: Path) -> None:
        docx_path = _create_docx(tmp_path / "valid.docx", paragraphs=["Hello"])
        block = _make_file_parser_block("https://example.com/valid.docx", FileType.DOCX)
        # A genuine DOCX must pass validation without raising.
        block.validate_file_type("https://example.com/valid.docx", str(docx_path))

    def test_plain_text_with_docx_extension(self, tmp_path: Path) -> None:
        fake_path = tmp_path / "fake.docx"
        fake_path.write_text("This is plain text, not a DOCX file.")
        block = _make_file_parser_block("https://example.com/fake.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            block.validate_file_type("https://example.com/fake.docx", str(fake_path))

    def test_empty_file(self, tmp_path: Path) -> None:
        empty_path = tmp_path / "empty.docx"
        empty_path.write_bytes(b"")
        block = _make_file_parser_block("https://example.com/empty.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            block.validate_file_type("https://example.com/empty.docx", str(empty_path))
@pytest.mark.asyncio
class TestParseDocxFile:
    """Tests for _parse_docx_file text extraction."""

    @staticmethod
    def _parse(path: Path, url: str):
        """Return the parse coroutine for a fresh DOCX block over *path*."""
        block = _make_file_parser_block(url, FileType.DOCX)
        return block._parse_docx_file(str(path))

    async def test_paragraphs_joined_by_newline(self, tmp_path: Path) -> None:
        path = _create_docx(tmp_path / "paras.docx", paragraphs=["Hello", "World"])
        assert await self._parse(path, "https://example.com/paras.docx") == "Hello\nWorld"

    async def test_empty_paragraphs_skipped(self, tmp_path: Path) -> None:
        path = _create_docx(tmp_path / "blanks.docx", paragraphs=["Hello", "", " ", "World"])
        assert await self._parse(path, "https://example.com/blanks.docx") == "Hello\nWorld"

    async def test_table_rows_formatted_with_pipe(self, tmp_path: Path) -> None:
        path = _create_docx(
            tmp_path / "table.docx",
            table_rows=[["Name", "Age"], ["Alice", "30"]],
        )
        assert await self._parse(path, "https://example.com/table.docx") == "Name | Age\nAlice | 30"

    async def test_mixed_paragraphs_and_tables(self, tmp_path: Path) -> None:
        path = _create_docx(
            tmp_path / "mixed.docx",
            paragraphs=["Intro"],
            table_rows=[["Col1", "Col2"], ["A", "B"]],
        )
        assert await self._parse(path, "https://example.com/mixed.docx") == "Intro\nCol1 | Col2\nA | B"

    async def test_empty_document(self, tmp_path: Path) -> None:
        path = _create_docx(tmp_path / "empty.docx")
        assert await self._parse(path, "https://example.com/empty.docx") == ""

    async def test_empty_table_cells_skipped(self, tmp_path: Path) -> None:
        path = _create_docx(
            tmp_path / "sparse.docx",
            table_rows=[["Name", "", "Age"], ["", "", ""]],
        )
        # First row: "Name" and "Age" (empty cell skipped), second row: all empty -> skipped
        assert await self._parse(path, "https://example.com/sparse.docx") == "Name | Age"

    async def test_multiple_tables(self, tmp_path: Path) -> None:
        document = docx.Document()
        for prefix in ("T1", "T2"):
            table = document.add_table(rows=1, cols=2)
            table.rows[0].cells[0].text = f"{prefix}C1"
            table.rows[0].cells[1].text = f"{prefix}C2"
        path = tmp_path / "multi_table.docx"
        document.save(str(path))
        assert await self._parse(path, "https://example.com/multi_table.docx") == "T1C1 | T1C2\nT2C1 | T2C2"
@pytest.mark.asyncio
class TestParseDocxFileTokenTruncation:
    """Tests for _parse_docx_file token limit enforcement."""

    async def test_paragraphs_truncated(self, tmp_path: Path) -> None:
        """With a tiny token budget only a prefix of the paragraphs is kept."""
        # Create many paragraphs that will exceed a small token limit
        source_paragraphs = [f"This is paragraph number {i} with some text content." for i in range(100)]
        docx_path = _create_docx(tmp_path / "long.docx", paragraphs=source_paragraphs)
        parser = _make_file_parser_block("https://example.com/long.docx", FileType.DOCX)
        extracted = await parser._parse_docx_file(str(docx_path), max_tokens=20)
        kept_lines = extracted.split("\n")
        assert len(kept_lines) < len(source_paragraphs)
        # Truncation happens at line granularity: every kept line is a whole paragraph
        assert all(line.startswith("This is paragraph number") for line in kept_lines)

    async def test_tables_truncated(self, tmp_path: Path) -> None:
        """With a tiny token budget only a prefix of the table rows is kept."""
        source_rows = [[f"R{i}C1", f"R{i}C2", f"R{i}C3"] for i in range(100)]
        docx_path = _create_docx(tmp_path / "big_table.docx", table_rows=source_rows)
        parser = _make_file_parser_block("https://example.com/big_table.docx", FileType.DOCX)
        extracted = await parser._parse_docx_file(str(docx_path), max_tokens=20)
        kept_lines = extracted.split("\n")
        assert len(kept_lines) < len(source_rows)

    async def test_tables_skipped_when_paragraphs_exhaust_budget(self, tmp_path: Path) -> None:
        """If paragraphs alone consume the budget, no table content is emitted."""
        source_paragraphs = [f"Long paragraph {i} with lots of content to fill tokens." for i in range(100)]
        source_rows = [["Should", "Not", "Appear"]]
        docx_path = _create_docx(tmp_path / "para_heavy.docx", paragraphs=source_paragraphs, table_rows=source_rows)
        parser = _make_file_parser_block("https://example.com/para_heavy.docx", FileType.DOCX)
        extracted = await parser._parse_docx_file(str(docx_path), max_tokens=20)
        for table_word in ("Should", "Not", "Appear"):
            assert table_word not in extracted
@pytest.mark.asyncio
class TestParseDocxFileErrorHandling:
    """Tests for _parse_docx_file error handling."""

    async def test_corrupt_file(self, tmp_path: Path) -> None:
        """Arbitrary bytes that are not a DOCX archive raise InvalidFileType."""
        corrupt_path = tmp_path / "corrupt.docx"
        corrupt_path.write_bytes(b"\x00\x01\x02\x03random bytes")
        parser = _make_file_parser_block("https://example.com/corrupt.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            await parser._parse_docx_file(str(corrupt_path))

    async def test_nonexistent_file(self, tmp_path: Path) -> None:
        """A path that does not exist on disk raises InvalidFileType."""
        parser = _make_file_parser_block("https://example.com/missing.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            await parser._parse_docx_file(str(tmp_path / "nonexistent.docx"))

# --- View File: new test module (@@ -0,0 +1,46 @@) ---
import json
class TestJsonTextParsingEquivalence:
"""Prove JSON/text parsing behavior matches aiohttp semantics.
The HttpRequestBlock parses responses using:
try:
response_body = json.loads(response_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
response_body = response_bytes.decode("utf-8", errors="replace")
This should behave equivalently to aiohttp's:
try:
response_body = await response.json()
except (aiohttp.ContentTypeError, Exception):
response_body = await response.text()
"""
def _parse_response(self, response_bytes: bytes) -> str | dict | list:
try:
return json.loads(response_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return response_bytes.decode("utf-8", errors="replace")
def test_valid_json_utf8(self) -> None:
data = {"key": "value", "number": 42, "unicode": "日本語"}
response_bytes = json.dumps(data).encode("utf-8")
result = self._parse_response(response_bytes)
assert result == data
def test_invalid_json_returns_text(self) -> None:
response_bytes = b"not json, just text"
result = self._parse_response(response_bytes)
assert result == "not json, just text"
def test_non_utf8_bytes_handled_gracefully(self) -> None:
response_bytes = "café".encode("latin-1") # b'caf\xe9'
result = self._parse_response(response_bytes)
assert "caf" in result
assert isinstance(result, str)
def test_empty_response(self) -> None:
response_bytes = b""
result = self._parse_response(response_bytes)
assert result == ""