Remove setup.sh in favor of skyvern CLI (#4737)
This commit is contained in:
0
tests/unit/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
30
tests/unit/conftest.py
Normal file
30
tests/unit/conftest.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -- begin speed up unit tests
|
||||
import pytest
|
||||
|
||||
from tests.unit.force_stub_app import start_forge_stub_app
|
||||
|
||||
# NOTE(jdo): uncomment below to run tests faster, if you're targeting something
|
||||
# that does not need the full app context
|
||||
|
||||
# import sys
|
||||
# from unittest.mock import MagicMock
|
||||
|
||||
# mock_modules = [
|
||||
# "skyvern.forge.app",
|
||||
# "skyvern.library",
|
||||
# "skyvern.core.script_generations.skyvern_page",
|
||||
# "skyvern.core.script_generations.run_initializer",
|
||||
# "skyvern.core.script_generations.workflow_wrappers",
|
||||
# "skyvern.services.script_service",
|
||||
# ]
|
||||
|
||||
# for module in mock_modules:
|
||||
# sys.modules[module] = MagicMock()
|
||||
|
||||
# -- end speed up unit tests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
def setup_forge_stub_app():
    """Install the stubbed ForgeApp once per test module (autouse).

    Module scope keeps the stub alive for every test in the module; the
    bare ``yield`` just marks the teardown point (no cleanup needed).
    """
    start_forge_stub_app()
    yield
|
||||
60
tests/unit/force_stub_app.py
Normal file
60
tests/unit/force_stub_app.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from skyvern.forge import set_force_app_instance
|
||||
from skyvern.forge.forge_app import ForgeApp
|
||||
|
||||
|
||||
def create_forge_stub_app() -> ForgeApp:
    """Build a ForgeApp whose collaborators are all mocks.

    Namespace attributes (DATABASE, STORAGE, ...) are ``_LazyNamespace``
    instances that manufacture an ``AsyncMock`` on first attribute access,
    so tests can await arbitrary methods without wiring each one up.
    The explicit ``AsyncMock`` assignments below just make the most-used
    entry points visible; duplicates from the original have been removed
    (each attribute is now assigned exactly once).
    """

    class _LazyNamespace:
        # Create-and-cache an AsyncMock per attribute on first access.
        def __getattr__(self, name):
            value = AsyncMock()
            setattr(self, name, value)
            return value

    fake_app_module = ForgeApp()
    fake_app_module.DATABASE = _LazyNamespace()
    fake_app_module.WORKFLOW_CONTEXT_MANAGER = _LazyNamespace()
    fake_app_module.WORKFLOW_SERVICE = _LazyNamespace()
    fake_app_module.BROWSER_MANAGER = _LazyNamespace()
    fake_app_module.PERSISTENT_SESSIONS_MANAGER = _LazyNamespace()
    fake_app_module.ARTIFACT_MANAGER = _LazyNamespace()
    fake_app_module.AGENT_FUNCTION = _LazyNamespace()
    fake_app_module.AGENT_FUNCTION.validate_block_execution = AsyncMock()
    fake_app_module.AGENT_FUNCTION.validate_code_block = AsyncMock()
    fake_app_module.agent = _LazyNamespace()
    # Explicit DATABASE mocks: the lazy namespace would create these on
    # demand anyway; listing them documents the expected contract.
    fake_app_module.DATABASE.update_workflow_run_block = AsyncMock()
    fake_app_module.DATABASE.create_workflow_run_block = AsyncMock()
    fake_app_module.DATABASE.create_or_update_workflow_run_output_parameter = AsyncMock()
    fake_app_module.DATABASE.get_last_task_for_workflow_run = AsyncMock()
    fake_app_module.DATABASE.get_workflow_run = AsyncMock()
    fake_app_module.DATABASE.get_workflow_run_block = AsyncMock()
    fake_app_module.DATABASE.get_task = AsyncMock()
    fake_app_module.DATABASE.update_task = AsyncMock()
    fake_app_module.DATABASE.update_task_v2 = AsyncMock()
    fake_app_module.DATABASE.get_organization = AsyncMock()
    fake_app_module.DATABASE.get_workflow = AsyncMock()
    fake_app_module.DATABASE.update_workflow_run = AsyncMock()
    fake_app_module.LLM_API_HANDLER = AsyncMock()
    fake_app_module.SECONDARY_LLM_API_HANDLER = AsyncMock()
    fake_app_module.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
    fake_app_module.CUSTOM_SELECT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.NORMAL_SELECT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.SELECT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.SINGLE_INPUT_AGENT_LLM_API_HANDLER = AsyncMock()
    fake_app_module.EXTRACTION_LLM_API_HANDLER = AsyncMock()
    fake_app_module.CHECK_USER_GOAL_LLM_API_HANDLER = AsyncMock()
    fake_app_module.EXPERIMENTATION_PROVIDER = _LazyNamespace()
    fake_app_module.STORAGE = _LazyNamespace()

    return fake_app_module
|
||||
|
||||
|
||||
def start_forge_stub_app() -> ForgeApp:
    """Create the stub ForgeApp, register it as the forced instance, and return it."""
    stub_app = create_forge_stub_app()
    set_force_app_instance(stub_app)
    return stub_app
|
||||
300
tests/unit/helpers.py
Normal file
300
tests/unit/helpers.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Iterator, Sequence
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from pytest import MonkeyPatch # type: ignore[import-not-found]
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.api.llm import api_handler_factory
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig, LLMRouterModelConfig
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
||||
|
||||
class FakeLLMResponse:
    """Minimal stand-in for a litellm completion response.

    Exposes only what the handler code reads: ``model``,
    ``choices[0].message.content``, the ``usage`` counters, and
    ``model_dump_json``. The content is always an empty-actions payload.
    """

    _EMPTY_ACTIONS = '{"actions": []}'

    def __init__(self, model: str) -> None:
        self.model = model
        self._content = self._EMPTY_ACTIONS
        message = SimpleNamespace(content=self._content)
        self.choices = [SimpleNamespace(message=message)]
        # All token counters are zero so cost math stays trivial.
        self.usage = SimpleNamespace(
            prompt_tokens=0,
            completion_tokens=0,
            completion_tokens_details=SimpleNamespace(reasoning_tokens=0),
            prompt_tokens_details=SimpleNamespace(cached_tokens=0),
            cache_read_input_tokens=0,
        )

    def model_dump_json(self, indent: int = 2) -> str:
        """Serialize the same JSON shape a real response would dump."""
        payload = {
            "model": self.model,
            "choices": [
                {"message": {"content": self._content}},
            ],
        }
        return json.dumps(payload, indent=indent)
|
||||
|
||||
|
||||
class DummyLogger:
    """Structured-log stand-in: records ``info`` calls, ignores all other levels."""

    def __init__(self) -> None:
        # (event_name, kwargs) pairs in call order.
        self.events: list[tuple[str, dict[str, Any]]] = []

    def info(self, event: str, **kwargs: dict[str, Any]) -> None:
        self.events.append((event, kwargs))

    def _ignore(self, *args, **kwargs) -> None:  # pragma: no cover
        pass

    # Non-info levels share one no-op implementation.
    warning = _ignore
    exception = _ignore
    debug = _ignore
|
||||
|
||||
|
||||
@dataclass
class RouterTestContext:
    """Bundle yielded by ``router_test_context``: the registered key, its
    router config, and the DummyLogger capturing handler log events."""

    # LLM key the router config was registered under.
    llm_key: str
    # The two-group router configuration in effect for the test.
    router_config: LLMRouterConfig
    # Captures api_handler_factory's info-level log events.
    logger: DummyLogger
|
||||
|
||||
|
||||
@contextmanager
def router_test_context(
    monkeypatch: MonkeyPatch,
    *,
    llm_key: str,
    primary_group: str,
    fallback_group: str,
    routing_strategy: str = "simple-shuffle",
) -> Iterator[RouterTestContext]:
    """Register a two-group router config under *llm_key* and patch
    ``api_handler_factory`` so handler calls run without real LLM traffic.

    Yields a RouterTestContext (key, config, captured logger). The config
    is always unregistered on exit; monkeypatch undoes the setattr patches.
    """
    router_config = LLMRouterConfig(
        model_name="test-router",
        required_env_vars=[],
        supports_vision=False,
        add_assistant_prefix=False,
        model_list=[
            LLMRouterModelConfig(model_name=primary_group, litellm_params={"model": primary_group}),
            LLMRouterModelConfig(model_name=fallback_group, litellm_params={"model": fallback_group}),
        ],
        redis_host="localhost",
        redis_port=6379,
        redis_password="",
        main_model_group=primary_group,
        fallback_model_group=fallback_group,
        routing_strategy=routing_strategy,
        num_retries=0,
        disable_cooldowns=True,
        temperature=None,
    )

    # Drop any stale registration from a previous test before re-registering.
    LLMConfigRegistry._configs.pop(llm_key, None)  # type: ignore[attr-defined]
    LLMConfigRegistry.register_config(llm_key, router_config)

    logger = DummyLogger()
    monkeypatch.setattr(api_handler_factory, "LOG", logger)

    # Bypass the real message builder (no screenshots / prefix handling needed).
    async def fake_llm_messages_builder(prompt: str, screenshots, add_assistant_prefix: bool) -> list[dict[str, str]]:
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", fake_llm_messages_builder)
    monkeypatch.setattr(api_handler_factory.skyvern_context, "current", lambda: None)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda completion_response: 0.0)

    try:
        yield RouterTestContext(llm_key=llm_key, router_config=router_config, logger=logger)
    finally:
        # Unregister even on failure so the key never leaks between tests.
        LLMConfigRegistry._configs.pop(llm_key, None)  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def make_organization(now: datetime) -> Organization:
    """Return a minimal Organization fixture; every optional field is unset."""
    fields: dict[str, Any] = dict(
        organization_id="org-123",
        organization_name="Org",
        webhook_callback_url=None,
        max_steps_per_run=None,
        max_retries_per_step=None,
        domain=None,
        bw_organization_id=None,
        bw_collection_ids=None,
        created_at=now,
        modified_at=now,
    )
    return Organization(**fields)
|
||||
|
||||
|
||||
def make_task(now: datetime, organization: Organization, **overrides: Any) -> Task:
    """Build a running Task fixture owned by *organization*.

    Defaults describe a generic navigation + extraction task; keyword
    *overrides* replace any default field before the Task is constructed.
    """
    base: dict[str, Any] = {
        # Task definition fields.
        "title": "Task",
        "url": "https://example.com",
        "webhook_callback_url": None,
        "webhook_failure_reason": None,
        "totp_verification_url": None,
        "totp_identifier": None,
        "navigation_goal": "Find the quote",
        "data_extraction_goal": "Extract the quote",
        "navigation_payload": None,
        "error_code_mapping": None,
        "proxy_location": None,
        "extracted_information_schema": None,
        "extra_http_headers": None,
        "complete_criterion": None,
        "terminate_criterion": None,
        "task_type": TaskType.general,
        "application": None,
        "include_action_history_in_verification": False,
        "max_screenshot_scrolls": None,
        "browser_address": None,
        "download_timeout": None,
        # Runtime / lifecycle state.
        "created_at": now,
        "modified_at": now,
        "task_id": "task-123",
        "status": TaskStatus.running,
        "extracted_information": None,
        "failure_reason": None,
        "organization_id": organization.organization_id,
        "workflow_run_id": None,
        "workflow_permanent_id": None,
        "browser_session_id": None,
        "order": 0,
        "retry": 0,
        "max_steps_per_run": None,
        "errors": [],
        "model": None,
        "queued_at": now,
        "started_at": now,
        "finished_at": None,
    }
    # Overrides win over defaults.
    base.update(overrides)
    return Task(**base)
|
||||
|
||||
|
||||
def make_step(
    now: datetime,
    task: Task,
    *,
    step_id: str,
    status: StepStatus,
    order: int,
    output,
    is_last: bool = False,
    retry_index: int = 0,
    organization_id: str | None = None,
    **overrides: Any,
) -> Step:
    """Build a Step attached to *task*; keyword *overrides* win over defaults."""
    defaults: dict[str, Any] = {
        "created_at": now,
        "modified_at": now,
        "task_id": task.task_id,
        "step_id": step_id,
        "status": status,
        "output": output,
        "order": order,
        "is_last": is_last,
        "retry_index": retry_index,
        # Fall back to the owning task's organization when none is given.
        "organization_id": organization_id or task.organization_id,
    }
    return Step(**{**defaults, **overrides})
|
||||
|
||||
|
||||
@dataclass
class ParallelVerificationMocks:
    """All mocks installed by ``setup_parallel_verification_mocks``,
    bundled so tests can assert on call counts and arguments."""

    create_step: AsyncMock
    get_task_steps: AsyncMock
    sleep: AsyncMock
    check_user_goal_complete: AsyncMock
    handle_action: AsyncMock
    # None when the test did not supply an extract_action.
    create_extract_action: AsyncMock | None
    speculate_next_step_plan: AsyncMock
    persist_speculative_metadata: AsyncMock
    cancel_speculative_step: AsyncMock
    record_artifacts_after_action: AsyncMock
    update_step: AsyncMock
    update_task: AsyncMock
|
||||
|
||||
|
||||
def setup_parallel_verification_mocks(
    agent: ForgeAgent,
    *,
    step: Step,
    task: Task,
    monkeypatch: MonkeyPatch,
    next_step: Step | None,
    complete_action,
    handle_action_responses: Sequence[Any],
    extract_action: Any | None = None,
) -> ParallelVerificationMocks:
    """Patch every collaborator the agent's parallel-verification path touches.

    Returns the created mocks bundled in a ParallelVerificationMocks so the
    caller can drive the agent and then assert on calls. All patches are
    undone by *monkeypatch* at test teardown.
    """
    create_step_mock = AsyncMock(return_value=next_step)
    monkeypatch.setattr(app.DATABASE, "create_step", create_step_mock)

    get_task_steps_mock = AsyncMock(return_value=[step])
    monkeypatch.setattr(app.DATABASE, "get_task_steps", get_task_steps_mock)

    # No real waiting during tests.
    sleep_mock = AsyncMock(return_value=None)
    monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", sleep_mock)

    check_user_goal_complete_mock = AsyncMock(return_value=complete_action)
    monkeypatch.setattr(agent, "check_user_goal_complete", check_user_goal_complete_mock)

    # One response per handled action, consumed in order.
    handle_action_mock = AsyncMock(side_effect=handle_action_responses)
    monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", handle_action_mock)

    speculate_mock = AsyncMock(return_value=None)
    monkeypatch.setattr(agent, "_speculate_next_step_plan", speculate_mock)

    persist_mock = AsyncMock()
    monkeypatch.setattr(agent, "_persist_speculative_metadata_for_discarded_plan", persist_mock)

    cancel_mock = AsyncMock()
    monkeypatch.setattr(agent, "_cancel_speculative_step", cancel_mock)

    record_artifacts_mock = AsyncMock()
    monkeypatch.setattr(agent, "record_artifacts_after_action", record_artifacts_mock)

    update_step_mock = AsyncMock(return_value=step)
    monkeypatch.setattr(agent, "update_step", update_step_mock)

    update_task_mock = AsyncMock(return_value=task)
    monkeypatch.setattr(agent, "update_task", update_task_mock)

    # Only patch extract-action creation when the test supplies one.
    if extract_action is not None:
        create_extract_action_mock = AsyncMock(return_value=extract_action)
        monkeypatch.setattr(agent, "create_extract_action", create_extract_action_mock)
    else:
        create_extract_action_mock = None

    return ParallelVerificationMocks(
        create_step=create_step_mock,
        get_task_steps=get_task_steps_mock,
        sleep=sleep_mock,
        check_user_goal_complete=check_user_goal_complete_mock,
        handle_action=handle_action_mock,
        create_extract_action=create_extract_action_mock,
        speculate_next_step_plan=speculate_mock,
        persist_speculative_metadata=persist_mock,
        cancel_speculative_step=cancel_mock,
        record_artifacts_after_action=record_artifacts_mock,
        update_step=update_step_mock,
        update_task=update_task_mock,
    )
|
||||
|
||||
|
||||
def make_browser_state() -> tuple[MagicMock, MagicMock, MagicMock]:
    """Return a triple of independent MagicMock objects."""
    return tuple(MagicMock() for _ in range(3))
|
||||
0
tests/unit/services/conftest.py
Normal file
0
tests/unit/services/conftest.py
Normal file
94
tests/unit/services/test_browser_recording.py
Normal file
94
tests/unit/services/test_browser_recording.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Just an example unit test for now. Will expand later.
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from skyvern.services.browser_recording.service import Processor
|
||||
from skyvern.services.browser_recording.types import (
|
||||
ExfiltratedConsoleEvent,
|
||||
)
|
||||
|
||||
ORG_ID = "org_123"
|
||||
PBS_ID = "pbs_123"
|
||||
WP_ID = "wpid_123"
|
||||
|
||||
|
||||
def make_console_event(
    params: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    """Wrap *params* in a user-interaction console event, filling in page defaults.

    Caller-supplied keys in *params* override the defaults.
    """
    merged: dict[str, t.Any] = {
        "url": "https://example.com",
        "activeElement": {
            "tagName": "BUTTON",
        },
        "window": {
            "height": 800,
            "width": 1200,
            "scrollX": 0,
            "scrollY": 0,
        },
        "mousePosition": {"xp": 0.5, "yp": 0.5},
    }
    merged.update(params)

    return ExfiltratedConsoleEvent(
        kind="exfiltrated-event",
        source="console",
        event_name="user-interaction",
        params=merged,
        timestamp=timestamp,
    )
|
||||
|
||||
|
||||
def make_mouseenter_event(
    target: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    """Console event for the mouse entering *target* at *timestamp*."""
    return make_console_event(
        params={"type": "mouseenter", "target": target, "timestamp": timestamp},
        timestamp=timestamp,
    )
|
||||
|
||||
|
||||
def make_mouseleave_event(
    target: dict[str, t.Any],
    timestamp: float,
) -> ExfiltratedConsoleEvent:
    """Console event for the mouse leaving *target* at *timestamp*."""
    return make_console_event(
        params={"type": "mouseleave", "target": target, "timestamp": timestamp},
        timestamp=timestamp,
    )
|
||||
|
||||
|
||||
def test_hover() -> None:
    """A mouseenter/mouseleave pair over one element yields exactly one action."""
    target = {"id": "button-1", "skyId": "sky-123", "text": ["Click me"]}

    enter = make_mouseenter_event(target=target, timestamp=1000.0)
    leave = make_mouseleave_event(target=target, timestamp=4000.0)

    processor = Processor(PBS_ID, ORG_ID, WP_ID)
    actions = processor.events_to_actions([enter, leave])

    assert len(actions) == 1
|
||||
115
tests/unit/test_actions.py
Normal file
115
tests/unit/test_actions.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from skyvern.webeye.actions.actions import Action, KeypressAction, NullAction, WebAction
|
||||
from skyvern.webeye.actions.parse_actions import parse_action
|
||||
|
||||
|
||||
def _mock_scraped_page() -> MagicMock:
|
||||
page = MagicMock()
|
||||
page.id_to_element_hash = {}
|
||||
page.id_to_element_dict = {}
|
||||
return page
|
||||
|
||||
|
||||
def test_action_parse__no_element_id() -> None:
    """A base Action validates without an element_id (defaults to None)."""
    payload = {"action_type": "click"}
    parsed = Action.model_validate(payload)
    assert parsed.action_type == "click"
    assert parsed.element_id is None
|
||||
|
||||
|
||||
def test_action_parse__with_element_id() -> None:
    """Action keeps string element ids as-is and coerces integer ids to str."""
    for raw_id, expected in (("element_id", "element_id"), (1, "1")):
        parsed = Action.model_validate({"action_type": "click", "element_id": raw_id})
        assert parsed.action_type == "click"
        assert parsed.element_id == expected
|
||||
|
||||
|
||||
def test_web_action_parse__no_element_id() -> None:
    """WebAction requires element_id, so validation must fail without one."""
    with pytest.raises(ValidationError):
        WebAction.model_validate({"action_type": "click"})
|
||||
|
||||
|
||||
def test_web_action_parse__with_element_id() -> None:
    """WebAction keeps string element ids and coerces integer ids to str."""
    for raw_id, expected in (("element_id", "element_id"), (1, "1")):
        parsed = WebAction.model_validate({"action_type": "click", "element_id": raw_id})
        assert parsed.action_type == "click"
        assert parsed.element_id == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("key", ["Enter", "Tab", "Escape", "ArrowDown", "ArrowUp"])
def test_parse_keypress_valid_keys(key: str) -> None:
    """Each supported key parses into a KeypressAction with no element binding."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "key": key, "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == [key]
    assert parsed.element_id is None
    assert parsed.skyvern_element_hash is None
    assert parsed.skyvern_element_data is None
|
||||
|
||||
|
||||
def test_parse_keypress_invalid_key_returns_null_action() -> None:
    """An unsupported key parses to a NullAction rather than raising."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "key": "Delete", "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, NullAction)
|
||||
|
||||
|
||||
def test_parse_keypress_backward_compat_press_enter() -> None:
    """Legacy PRESS_ENTER payloads still parse into a KeypressAction."""
    parsed = parse_action(
        action={"action_type": "PRESS_ENTER", "key": "Enter", "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == ["Enter"]
|
||||
|
||||
|
||||
def test_parse_keypress_keys_list() -> None:
    """A pre-built ``keys`` list is accepted directly."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "keys": ["Enter"], "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == ["Enter"]
|
||||
|
||||
|
||||
def test_parse_keypress_no_key_defaults_to_enter() -> None:
    """A KEYPRESS with no key at all defaults to Enter."""
    parsed = parse_action(
        action={"action_type": "KEYPRESS", "reasoning": "test"},
        scraped_page=_mock_scraped_page(),
    )
    assert isinstance(parsed, KeypressAction)
    assert parsed.keys == ["Enter"]
|
||||
165
tests/unit/test_ai_click_empty_actions.py
Normal file
165
tests/unit/test_ai_click_empty_actions.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Tests for ai_click behavior when LLM returns empty actions.
|
||||
|
||||
This tests the fix for SKY-7577 where cached click actions were succeeding
|
||||
even when the target element didn't exist on the page.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.core.script_generations.real_skyvern_page_ai import RealSkyvernPageAi
|
||||
|
||||
|
||||
@pytest.fixture
def mock_page():
    """Create a mock Playwright page."""
    fake_page = MagicMock()
    fake_page.url = "https://example.com"
    # locator(...) returns a mock whose click() is awaitable.
    clickable = MagicMock()
    clickable.click = AsyncMock()
    fake_page.locator = MagicMock(return_value=clickable)
    return fake_page
|
||||
|
||||
|
||||
@pytest.fixture
def mock_scraped_page():
    """Create a mock ScrapedPage that properly supports async methods."""
    stub = MagicMock()
    stub.build_element_tree = MagicMock(return_value="<element_tree>")
    # The generate_scraped_page method is async and returns self
    stub.generate_scraped_page = AsyncMock(return_value=stub)
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Create a mock skyvern context."""
    ctx = MagicMock()
    ctx.organization_id = "org_123"
    ctx.task_id = "task_123"
    ctx.step_id = "step_123"
    ctx.prompt = "Test prompt"
    ctx.tz_info = None
    return ctx
|
||||
|
||||
|
||||
@pytest.fixture
def mock_app():
    """Create a mock app with SINGLE_CLICK_AGENT_LLM_API_HANDLER."""
    app_stub = MagicMock()
    app_stub.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})
    app_stub.DATABASE = MagicMock()
    app_stub.DATABASE.get_step = AsyncMock(return_value=MagicMock())
    return app_stub
|
||||
|
||||
|
||||
class TestAiClickEmptyActions:
    """Test that ai_click properly fails when LLM returns no actions."""

    @pytest.mark.asyncio
    async def test_ai_click_raises_when_llm_returns_empty_actions_no_selector(
        self, mock_page, mock_scraped_page, mock_context, mock_app
    ):
        """
        When the LLM returns no actions (element doesn't exist on page) and
        there's no selector to fall back to, ai_click should raise an exception.
        """
        real_skyvern_page_ai = RealSkyvernPageAi(mock_scraped_page, mock_page)

        # LLM reports it found nothing clickable.
        mock_app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})

        # Patch out page scraping, context, the app singleton, and prompt loading
        # so ai_click exercises only its own decision logic.
        with (
            patch.object(real_skyvern_page_ai, "_refresh_scraped_page", new_callable=AsyncMock),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.skyvern_context.ensure_context",
                return_value=mock_context,
            ),
            patch("skyvern.core.script_generations.real_skyvern_page_ai.app", mock_app),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.prompt_engine.load_prompt",
                return_value="mock_prompt",
            ),
        ):
            with pytest.raises(Exception) as exc_info:
                await real_skyvern_page_ai.ai_click(
                    selector=None,  # No fallback selector
                    intention="Click the download button",
                )

            # Should raise because no actions and no fallback
            assert "AI click failed" in str(exc_info.value) or "AI could not find" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_ai_click_raises_when_llm_call_fails_no_selector(
        self, mock_page, mock_scraped_page, mock_context, mock_app
    ):
        """
        When AI fails (exception) and there's no selector to fall back to,
        ai_click should raise an exception.
        """
        real_skyvern_page_ai = RealSkyvernPageAi(mock_scraped_page, mock_page)

        # LLM call itself blows up.
        mock_app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(side_effect=Exception("LLM error"))

        with (
            patch.object(real_skyvern_page_ai, "_refresh_scraped_page", new_callable=AsyncMock),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.skyvern_context.ensure_context",
                return_value=mock_context,
            ),
            patch("skyvern.core.script_generations.real_skyvern_page_ai.app", mock_app),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.prompt_engine.load_prompt",
                return_value="mock_prompt",
            ),
        ):
            with pytest.raises(Exception) as exc_info:
                await real_skyvern_page_ai.ai_click(
                    selector=None,  # No fallback selector
                    intention="Click the download button",
                )

            assert "AI click failed" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_ai_click_falls_back_to_selector_when_llm_returns_empty(
        self, mock_page, mock_scraped_page, mock_context, mock_app
    ):
        """
        When AI returns empty actions but there IS a selector to fall back to,
        ai_click should use the selector and succeed.
        """
        # Set up the locator mock properly with AsyncMock for click
        mock_locator = MagicMock()
        mock_locator.click = AsyncMock()
        mock_page.locator = MagicMock(return_value=mock_locator)

        real_skyvern_page_ai = RealSkyvernPageAi(mock_scraped_page, mock_page)

        mock_app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})

        with (
            patch.object(real_skyvern_page_ai, "_refresh_scraped_page", new_callable=AsyncMock),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.skyvern_context.ensure_context",
                return_value=mock_context,
            ),
            patch("skyvern.core.script_generations.real_skyvern_page_ai.app", mock_app),
            patch(
                "skyvern.core.script_generations.real_skyvern_page_ai.prompt_engine.load_prompt",
                return_value="mock_prompt",
            ),
        ):
            # Should NOT raise because we have a fallback selector
            result = await real_skyvern_page_ai.ai_click(
                selector="xpath=//button[@id='download']",  # Has fallback
                intention="Click the download button",
            )

            # Should have used the fallback selector
            mock_page.locator.assert_called_once_with("xpath=//button[@id='download']")
            assert result == "xpath=//button[@id='download']"
|
||||
1110
tests/unit/test_aiohttp_helper.py
Normal file
1110
tests/unit/test_aiohttp_helper.py
Normal file
File diff suppressed because it is too large
Load Diff
91
tests/unit/test_api_handler_cached_content_fix.py
Normal file
91
tests/unit/test_api_handler_cached_content_fix.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import skyvern.forge as forge_module
|
||||
import skyvern.forge.sdk.core.skyvern_context as skyvern_context_module
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
|
||||
# Replace the forge app holder with a MagicMock so test imports don't require a fully
|
||||
# initialised ForgeApp instance.
|
||||
forge_module.set_force_app_instance(MagicMock())
|
||||
forge_module.app.EXPERIMENTATION_PROVIDER = MagicMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cached_content_removed_from_non_extract_prompts() -> None:
    """The Gemini handler must attach vertex ``cached_content`` only for
    extract-actions prompts, and strip it (even if user-supplied via
    ``parameters``) from every other prompt."""
    mock_config = MagicMock(spec=LLMConfig)
    mock_config.model_name = "gemini-2.5-pro"
    mock_config.litellm_params = {}
    mock_config.supports_vision = False
    mock_config.add_assistant_prefix = False
    mock_config.max_completion_tokens = 100
    mock_config.max_tokens = None
    mock_config.temperature = 0.0
    mock_config.reasoning_effort = None
    mock_config.disable_cooldowns = True

    # Minimal response shape the handler reads back.
    mock_response = MagicMock()
    mock_response.model_dump_json.return_value = "{}"
    mock_response.choices = [MagicMock(message=MagicMock(content="test"))]
    mock_response.usage = MagicMock(
        prompt_tokens=10,
        completion_tokens=10,
        completion_tokens_details=None,
        prompt_tokens_details=None,
        cache_read_input_tokens=0,
    )

    # Ensure app dependencies referenced inside the handler resolve to async mocks.
    forge_module.app.ARTIFACT_MANAGER = MagicMock()
    forge_module.app.DATABASE = MagicMock()
    forge_module.app.DATABASE.update_step = AsyncMock()
    forge_module.app.DATABASE.update_thought = AsyncMock()

    with (
        patch("skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", return_value=mock_config),
        patch(
            "skyvern.forge.sdk.api.llm.api_handler_factory.litellm.acompletion",
            new_callable=AsyncMock,
        ) as mock_acompletion,
        patch(
            "skyvern.forge.sdk.api.llm.api_handler_factory.litellm.completion_cost",
            return_value=0.001,
        ),
        patch(
            "skyvern.forge.sdk.api.llm.api_handler_factory.llm_messages_builder", new_callable=AsyncMock
        ) as mock_builder,
        patch("skyvern.forge.sdk.api.llm.api_handler_factory.parse_api_response", return_value={}),
    ):
        mock_builder.return_value = [{"role": "user", "content": "test"}]
        mock_acompletion.return_value = mock_response

        handler = LLMAPIHandlerFactory.get_llm_api_handler("gemini-2.5-pro")

        # Context with prompt caching enabled and a vertex cache name set.
        context = SkyvernContext()
        context.cached_static_prompt = "some static prompt"
        context.use_prompt_caching = True
        context.vertex_cache_name = "projects/123/locations/global/cachedContents/demo"

        token = skyvern_context_module._context.set(context)
        try:
            # Extract actions attaches cached_content.
            await handler(prompt="test", prompt_name="extract-actions")
            args, kwargs = mock_acompletion.call_args
            assert kwargs.get("cached_content") == "projects/123/locations/global/cachedContents/demo"

            # Non-extract prompt should not include cached_content.
            mock_acompletion.reset_mock()
            await handler(prompt="test", prompt_name="check-user-goal")
            _, kwargs = mock_acompletion.call_args
            assert "cached_content" not in kwargs

            # Even if user supplied cached_content manually, it must be stripped.
            mock_acompletion.reset_mock()
            await handler(prompt="test", prompt_name="check-user-goal", parameters={"cached_content": "leaked"})
            _, kwargs = mock_acompletion.call_args
            assert "cached_content" not in kwargs
        finally:
            # Always restore the previous context var state.
            skyvern_context_module._context.reset(token)
|
||||
226
tests/unit/test_api_handler_factory.py
Normal file
226
tests/unit/test_api_handler_factory.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
|
||||
from skyvern.forge.sdk.api.llm import api_handler_factory
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import (
|
||||
EXTRACT_ACTION_PROMPT_NAME,
|
||||
LLMAPIHandlerFactory,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||
from tests.unit.helpers import FakeLLMResponse
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cached_content_not_added_for_non_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that cached_content is NOT added to non-Gemini models like GPT-4."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = "projects/123/locations/us-central1/cachedContents/456"
    context.use_prompt_caching = True
    context.cached_static_prompt = "some static prompt"
    context.hashed_href_map = {}

    # Setup non-Gemini config
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    # Route every registry/context lookup inside the handler to the fakes above.
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    # Mock litellm.acompletion to capture the parameters
    completion_params: dict = {}

    async def mock_acompletion(*args, **kwargs):
        # Record whatever kwargs the handler forwards to litellm.
        completion_params.update(kwargs)
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # Verify cached_content was NOT passed
    assert "cached_content" not in completion_params
    assert completion_params["model"] == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cached_content_added_for_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that cached_content IS added for Gemini models."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = "projects/123/locations/us-central1/cachedContents/456"
    context.use_prompt_caching = True
    context.cached_static_prompt = "some static prompt"
    context.hashed_href_map = {}

    # Setup Gemini config
    llm_config = LLMConfig(
        model_name="gemini-1.5-pro",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    # Route every registry/context lookup inside the handler to the fakes above.
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    # Mock litellm.acompletion to capture the parameters
    completion_params: dict = {}

    async def mock_acompletion(*args, **kwargs):
        # Record whatever kwargs the handler forwards to litellm.
        completion_params.update(kwargs)
        return FakeLLMResponse("gemini-1.5-pro")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gemini-1.5-pro")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # Verify cached_content WAS passed
    assert "cached_content" in completion_params
    assert completion_params["cached_content"] == "projects/123/locations/us-central1/cachedContents/456"
    assert completion_params["model"] == "gemini-1.5-pro"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_caching_not_injected_for_check_user_goal(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that OpenAI context caching system message is NOT injected for check-user-goal prompts.

    This is a regression test for a bug where the extract-action-static.j2 prompt was being
    injected as a system message for ALL prompts on OpenAI models, causing the LLM to return
    CLICK actions when running check-user-goal (which should only return COMPLETE/TERMINATE).
    """
    # Setup context with caching enabled (simulating state after extract-action ran)
    context = MagicMock()
    context.vertex_cache_name = None
    context.use_prompt_caching = True
    context.cached_static_prompt = "This is the extract-action-static prompt content"
    context.hashed_href_map = {}

    # Setup OpenAI config (GPT-4)
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    # Route every registry/context lookup inside the handler to the fakes above.
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)

    # Capture messages passed to LLM
    captured_messages: list = []

    async def mock_llm_messages_builder(prompt, screenshots, add_assistant_prefix):
        # Build a bare user message so any system message seen later must have
        # been injected by the handler's caching logic.
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", mock_llm_messages_builder)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    async def mock_acompletion(*args, **kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it with check-user-goal prompt (NOT extract-actions)
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="check-user-goal prompt content", prompt_name="check-user-goal")

    # Verify the cached_static_prompt was NOT injected as a system message
    # There should only be the user message, no system message with the cached content
    system_messages = [m for m in captured_messages if m.get("role") == "system"]
    assert len(system_messages) == 0, (
        f"Expected no system messages with cached content for check-user-goal, but found: {system_messages}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_caching_injected_for_extract_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that OpenAI context caching system message IS injected for extract-actions prompts."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = None
    context.use_prompt_caching = True
    context.cached_static_prompt = "This is the extract-action-static prompt content"
    context.hashed_href_map = {}

    # Setup OpenAI config (GPT-4)
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    # Route every registry/context lookup inside the handler to the fakes above.
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)

    # Capture messages passed to LLM
    captured_messages: list = []

    async def mock_llm_messages_builder(prompt, screenshots, add_assistant_prefix):
        # Build a bare user message so any system message seen later must have
        # been injected by the handler's caching logic.
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", mock_llm_messages_builder)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    async def mock_acompletion(*args, **kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it with extract-actions prompt
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="extract-actions prompt content", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # Verify the cached_static_prompt WAS injected as a system message
    system_messages = [m for m in captured_messages if m.get("role") == "system"]
    assert len(system_messages) == 1, (
        f"Expected 1 system message with cached content for extract-actions, "
        f"but found {len(system_messages)}: {system_messages}"
    )
    # Check the system message contains the cached content
    system_content = system_messages[0].get("content", [])
    assert any(part.get("text") == "This is the extract-action-static prompt content" for part in system_content), (
        f"System message should contain cached_static_prompt, got: {system_content}"
    )
|
||||
524
tests/unit/test_auto_completion_location.py
Normal file
524
tests/unit/test_auto_completion_location.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""Tests for the location auto-completion fast-path optimisation.
|
||||
|
||||
When the user types an address into a location field and exactly one autocomplete
|
||||
suggestion appears, we skip the LLM call and click the suggestion directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.constants import SKYVERN_ID_ATTR
|
||||
from skyvern.forge.sdk.models import StepStatus
|
||||
from skyvern.webeye.actions.actions import InputOrSelectContext
|
||||
from skyvern.webeye.actions.handler import (
|
||||
AutoCompletionResult,
|
||||
choose_auto_completion_dropdown,
|
||||
input_or_auto_complete_input,
|
||||
)
|
||||
from skyvern.webeye.actions.responses import ActionSuccess
|
||||
from tests.unit.helpers import make_organization, make_step, make_task
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Shared fixtures, built once at import time and reused by every test below.
_NOW = datetime.now(UTC)
_ORG = make_organization(_NOW)
_TASK = make_task(_NOW, _ORG, navigation_payload={"address": "123 Main St"})
_STEP = make_step(_NOW, _TASK, step_id="stp-1", status=StepStatus.created, order=0, output=None)

# Scraped dropdown option fixtures: exactly one matching suggestion vs. two
# ambiguous ones (same street, different state).
SINGLE_ELEMENT = [{"id": "AAAA", "tag": "div", "text": "123 Main St, Springfield, IL"}]
MULTI_ELEMENTS = [
    {"id": "AAAA", "tag": "div", "text": "123 Main St, Springfield, IL"},
    {"id": "AAAB", "tag": "div", "text": "123 Main St, Springfield, MO"},
]
|
||||
|
||||
|
||||
def _make_location_context(**overrides: object) -> InputOrSelectContext:
    """Build an InputOrSelectContext for a location ("Address") field.

    Keyword overrides replace the matching default before construction.
    """
    base: dict[str, object] = {
        "field": "Address",
        "is_location_input": True,
        "is_search_bar": False,
    }
    # Right-hand side of the union wins, so overrides take precedence.
    return InputOrSelectContext(**(base | overrides))
|
||||
|
||||
|
||||
def _make_non_location_context(**overrides: object) -> InputOrSelectContext:
    """Build an InputOrSelectContext for a plain ("Search") field.

    Keyword overrides replace the matching default before construction.
    """
    base: dict[str, object] = {
        "field": "Search",
        "is_location_input": False,
        "is_search_bar": False,
    }
    # Right-hand side of the union wins, so overrides take precedence.
    return InputOrSelectContext(**(base | overrides))
|
||||
|
||||
|
||||
def _mock_skyvern_element(frame: MagicMock | None = None) -> MagicMock:
|
||||
"""Return a mock SkyvernElement whose helpers are async-safe."""
|
||||
el = MagicMock()
|
||||
el.get_id.return_value = "elem-1"
|
||||
el.get_frame.return_value = frame or _mock_frame()
|
||||
el.get_frame_id.return_value = "frame-1"
|
||||
el.is_interactable.return_value = True
|
||||
el.press_fill = AsyncMock()
|
||||
el.input_clear = AsyncMock()
|
||||
el.is_visible = AsyncMock(return_value=True)
|
||||
el.get_element_handler = AsyncMock(return_value=MagicMock())
|
||||
return el
|
||||
|
||||
|
||||
def _mock_frame(locator_count: int = 1) -> MagicMock:
|
||||
"""Return a mock Playwright Frame with a configurable locator."""
|
||||
frame = MagicMock()
|
||||
locator = MagicMock()
|
||||
locator.count = AsyncMock(return_value=locator_count)
|
||||
locator.click = AsyncMock()
|
||||
frame.locator.return_value = locator
|
||||
return frame
|
||||
|
||||
|
||||
def _mock_incremental_scrape(elements: list[dict]) -> MagicMock:
|
||||
"""Return a mock IncrementalScrapePage that yields *elements*."""
|
||||
inc = MagicMock()
|
||||
inc.start_listen_dom_increment = AsyncMock()
|
||||
inc.stop_listen_dom_increment = AsyncMock()
|
||||
inc.get_incremental_element_tree = AsyncMock(return_value=copy.deepcopy(elements))
|
||||
inc.build_html_tree.return_value = "<div>mocked</div>"
|
||||
return inc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for choose_auto_completion_dropdown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_single_option_skips_llm() -> None:
    """When is_location_input=True and exactly 1 option appears, the LLM must NOT be called."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()

        result = await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        # The LLM should never have been called
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_not_called()

        # The locator should have been clicked
        frame.locator.assert_called_with(f'[{SKYVERN_ID_ATTR}="AAAA"]')
        frame.locator.return_value.click.assert_awaited_once()

        # Result should indicate success
        assert isinstance(result.action_result, ActionSuccess)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_whitespace_normalized_still_matches() -> None:
    """Input with extra whitespace should still match after normalization.

    The typed text deliberately carries a doubled space; the scraped option has
    single spaces, so only whitespace normalization can make them match.
    """
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    # Option has single spaces, input will have double spaces
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()

        result = await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123  Main St",  # Double spaces - should still match after normalization
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        # LLM should NOT be called - whitespace normalization should make it match
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_not_called()
        assert isinstance(result.action_result, ActionSuccess)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_multiple_options_calls_llm() -> None:
    """When is_location_input=True but multiple options appear, the LLM IS called."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(MULTI_ELEMENTS)

    # Canned LLM verdict: pick the first option with high relevance.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.95,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "First option matches",
    }

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)

        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        # LLM should have been called because there are 2 options
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_location_single_option_calls_llm() -> None:
    """When is_location_input=False, even a single option goes through the LLM path."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)

    # Canned LLM verdict: accept the only option.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.95,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "Matches",
    }

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)

        await choose_auto_completion_dropdown(
            context=_make_non_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="some search",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=False,
        )

        # LLM should be called — no fast-path for non-location inputs
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_fast_path_returns_action_success() -> None:
    """The fast-path must set action_result to ActionSuccess on the result object."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()

        result = await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        # Fast-path result type and success payload.
        assert isinstance(result, AutoCompletionResult)
        assert isinstance(result.action_result, ActionSuccess)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_fast_path_element_not_in_dom_falls_through() -> None:
    """If the single element's locator has count 0, the fast-path is skipped."""
    frame = _mock_frame(locator_count=0)  # element not found in DOM
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(SINGLE_ELEMENT)

    # Canned LLM verdict for the fallback path.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.95,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "Matches",
    }

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)

        # Should fall through to LLM path because locator.count() == 0
        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for input_or_auto_complete_input flag propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_or_auto_complete_passes_is_location_input() -> None:
    """input_or_auto_complete_input must forward is_location_input to choose_auto_completion_dropdown."""
    location_context = _make_location_context()

    dropdown_stub = AsyncMock(return_value=AutoCompletionResult(action_result=ActionSuccess()))
    with patch(
        "skyvern.webeye.actions.handler.choose_auto_completion_dropdown",
        new=dropdown_stub,
    ):
        result = await input_or_auto_complete_input(
            input_or_select_context=location_context,
            scraped_page=MagicMock(),
            page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=_mock_skyvern_element(),
            step=_STEP,
            task=_TASK,
        )

    # The wrapper surfaces the dropdown helper's ActionSuccess unchanged.
    assert isinstance(result, ActionSuccess)
    # The location flag must have been forwarded as a keyword argument.
    assert dropdown_stub.call_args.kwargs["is_location_input"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_or_auto_complete_passes_false_for_non_location() -> None:
    """When is_location_input is None/False, the flag should be passed as False."""
    plain_context = _make_non_location_context()

    dropdown_stub = AsyncMock(return_value=AutoCompletionResult(action_result=ActionSuccess()))
    with patch(
        "skyvern.webeye.actions.handler.choose_auto_completion_dropdown",
        new=dropdown_stub,
    ):
        result = await input_or_auto_complete_input(
            input_or_select_context=plain_context,
            scraped_page=MagicMock(),
            page=MagicMock(),
            dom=MagicMock(),
            text="some query",
            skyvern_element=_mock_skyvern_element(),
            step=_STEP,
            task=_TASK,
        )

    # The wrapper surfaces the dropdown helper's ActionSuccess unchanged.
    assert isinstance(result, ActionSuccess)
    # The location flag must have been forwarded as False, not omitted.
    assert dropdown_stub.call_args.kwargs["is_location_input"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: options that don't contain the input fall through to LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Dropdown options that do NOT contain the typed address, so the single-option
# fast-path must not fire and the LLM fallback is exercised instead.
NO_RESULT_ELEMENTS = [{"id": "AAAA", "tag": "div", "text": "No results"}]
UNRELATED_ELEMENTS = [{"id": "AAAA", "tag": "div", "text": "Something completely different"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_no_results_option_falls_through_to_llm() -> None:
    """When the single option doesn't contain the input text, fall through to LLM."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(NO_RESULT_ELEMENTS)

    # Canned LLM verdict: nothing usable, keep searching directly.
    llm_response = {
        "auto_completion_attempt": False,
        "relevance_float": 0.0,
        "id": "",
        "direct_searching": True,
        "reasoning": "No results shown",
    }

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)

        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        # LLM should be called because "No results" doesn't contain "123 Main St"
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_location_unrelated_option_falls_through_to_llm() -> None:
    """When the single option doesn't contain the input text, fall through to LLM."""
    frame = _mock_frame(locator_count=1)
    skyvern_el = _mock_skyvern_element(frame)
    inc_scrape = _mock_incremental_scrape(UNRELATED_ELEMENTS)

    # Canned LLM verdict: low-confidence pick of the only option.
    llm_response = {
        "auto_completion_attempt": True,
        "relevance_float": 0.5,
        "id": "AAAA",
        "direct_searching": False,
        "reasoning": "Only option available",
    }

    with (
        patch(
            "skyvern.webeye.actions.handler.SkyvernFrame.create_instance",
            new=AsyncMock(return_value=MagicMock(safe_wait_for_animation_end=AsyncMock())),
        ),
        patch(
            "skyvern.webeye.actions.handler.IncrementalScrapePage",
            return_value=inc_scrape,
        ),
        patch("skyvern.webeye.actions.handler.app") as mock_app,
        patch("skyvern.webeye.actions.handler.prompt_engine") as mock_prompt,
        patch("skyvern.webeye.actions.handler.skyvern_context") as mock_ctx,
    ):
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        mock_app.AGENT_FUNCTION = MagicMock()
        mock_prompt.load_prompt.return_value = "mocked prompt"
        mock_ctx.ensure_context.return_value = MagicMock(tz_info=UTC)

        await choose_auto_completion_dropdown(
            context=_make_location_context(),
            page=MagicMock(),
            scraped_page=MagicMock(),
            dom=MagicMock(),
            text="123 Main St",
            skyvern_element=skyvern_el,
            step=_STEP,
            task=_TASK,
            is_location_input=True,
        )

        # LLM should be called because option doesn't contain the input
        mock_app.AUTO_COMPLETION_LLM_API_HANDLER.assert_awaited_once()
|
||||
265
tests/unit/test_batch_action_queries.py
Normal file
265
tests/unit/test_batch_action_queries.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Tests for batch action query correctness in transform_workflow_run.py.
|
||||
|
||||
Verifies that the transform layer produces chronologically ordered actions
|
||||
per task for script generation, even though get_tasks_actions returns
|
||||
descending order (for the timeline UI).
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
|
||||
from skyvern.webeye.actions.actions import ClickAction, ExtractAction, InputTextAction
|
||||
|
||||
|
||||
def _make_action(
    action_cls: type,
    action_id: str,
    task_id: str,
    element_id: str | None = None,
    **kwargs: object,
) -> object:
    """Create a real Action instance for use in tests.

    Args:
        action_cls: Concrete Action subclass to instantiate.
        action_id: Unique action id; also seeds the synthetic element id.
        task_id: Owning task id.
        element_id: Explicit element id. When omitted, "elem_<action_id>" is
            used — except for ExtractAction, which targets no element.
        **kwargs: Extra constructor arguments; they take precedence over the
            defaults injected here (e.g. InputTextAction's ``text``).
    """
    # ExtractAction has no target element; every other action gets a synthetic id.
    if element_id is None and action_cls is not ExtractAction:
        element_id = f"elem_{action_id}"

    extra: dict[str, object] = dict(kwargs)
    if action_cls is InputTextAction:
        # setdefault (rather than splatting a literal alongside **kwargs) lets
        # callers override "text" without a duplicate-keyword TypeError.
        extra.setdefault("text", "hello")

    return action_cls(
        action_id=action_id,
        task_id=task_id,
        element_id=element_id,
        **extra,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_actions_preserve_per_task_ordering() -> None:
    """
    Regression test: transform_workflow_run must produce actions in ascending
    chronological order per task for script generation.

    get_tasks_actions returns DESC order (for timeline UI). The transform
    layer reverses to ASC. This test mocks DESC input and verifies ASC output.
    """
    # Minimal workflow-run response stub: only the fields the transform reads.
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test", "parameters": {}}
    )

    # Two task-type definition blocks in the workflow definition.
    def_block_a = MagicMock()
    def_block_a.block_type = "task"
    def_block_a.label = "block_a"
    def_block_a.model_dump = MagicMock(return_value={"block_type": "task", "label": "block_a"})

    def_block_b = MagicMock()
    def_block_b.block_type = "task"
    def_block_b.label = "block_b"
    def_block_b.model_dump = MagicMock(return_value={"block_type": "task", "label": "block_b"})

    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_1"})
    mock_workflow.workflow_definition.blocks = [def_block_a, def_block_b]

    # Two completed run blocks, one per task.
    run_block_a = MagicMock()
    run_block_a.workflow_run_block_id = "wfrb_a"
    run_block_a.parent_workflow_run_block_id = None
    run_block_a.block_type = "task"
    run_block_a.label = "block_a"
    run_block_a.task_id = "task_a"
    run_block_a.status = "completed"
    run_block_a.output = {}
    run_block_a.created_at = 1

    run_block_b = MagicMock()
    run_block_b.workflow_run_block_id = "wfrb_b"
    run_block_b.parent_workflow_run_block_id = None
    run_block_b.block_type = "task"
    run_block_b.label = "block_b"
    run_block_b.task_id = "task_b"
    run_block_b.status = "completed"
    run_block_b.output = {}
    run_block_b.created_at = 2

    mock_task_a = MagicMock()
    mock_task_a.task_id = "task_a"
    mock_task_a.model_dump = MagicMock(return_value={"task_id": "task_a"})

    mock_task_b = MagicMock()
    mock_task_b.task_id = "task_b"
    mock_task_b.model_dump = MagicMock(return_value={"task_id": "task_b"})

    # Actions in chronological order:
    # task_a: click (t=1), input_text (t=3)
    # task_b: click (t=2), extract (t=4)
    action_a_click = _make_action(ClickAction, action_id="a_click", task_id="task_a", element_id="el_1")
    action_b_click = _make_action(ClickAction, action_id="b_click", task_id="task_b", element_id="el_2")
    action_a_input = _make_action(InputTextAction, action_id="a_input", task_id="task_a", element_id="el_3")
    action_b_extract = _make_action(ExtractAction, action_id="b_extract", task_id="task_b", element_id=None)

    # get_tasks_actions returns DESC order (newest first) — matching real DB behavior
    all_actions_descending = [action_b_extract, action_a_input, action_b_click, action_a_click]

    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block_a, run_block_b])
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task_a, mock_task_b])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=all_actions_descending)

        result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")

        # After reverse, task_a actions must be in chronological order: click then input_text
        task_a_actions = result.actions_by_task["task_a"]
        task_a_ids = [a["action_id"] for a in task_a_actions]
        assert task_a_ids == ["a_click", "a_input"], f"task_a actions out of order: {task_a_ids}"
        assert task_a_actions[0]["action_type"] == "click"
        assert task_a_actions[1]["action_type"] == "input_text"

        # task_b actions must be in chronological order: click then extract
        task_b_actions = result.actions_by_task["task_b"]
        task_b_ids = [a["action_id"] for a in task_b_actions]
        assert task_b_ids == ["b_click", "b_extract"], f"task_b actions out of order: {task_b_ids}"
        assert task_b_actions[0]["action_type"] == "click"
        assert task_b_actions[1]["action_type"] == "extract"

        # No cross-contamination between tasks
        assert set(task_a_ids) == {"a_click", "a_input"}
        assert set(task_b_ids) == {"b_click", "b_extract"}
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_actions_without_reverse_would_be_wrong() -> None:
    """
    Prove that without the reverse() call, DESC input from get_tasks_actions
    would produce wrong ordering in script generation output.

    If someone removes the reverse(), this test catches it.
    """
    # Minimal workflow-run response stub: only the fields the transform reads.
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test", "parameters": {}}
    )

    # Single task-type block; the test cares only about intra-task ordering.
    def_block = MagicMock()
    def_block.block_type = "task"
    def_block.label = "my_block"
    def_block.model_dump = MagicMock(return_value={"block_type": "task", "label": "my_block"})

    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_1"})
    mock_workflow.workflow_definition.blocks = [def_block]

    run_block = MagicMock()
    run_block.workflow_run_block_id = "wfrb_1"
    run_block.parent_workflow_run_block_id = None
    run_block.block_type = "task"
    run_block.label = "my_block"
    run_block.task_id = "task_1"
    run_block.status = "completed"
    run_block.output = {}
    run_block.created_at = 1

    mock_task = MagicMock()
    mock_task.task_id = "task_1"
    mock_task.model_dump = MagicMock(return_value={"task_id": "task_1"})

    # Chronological order: click -> input -> extract
    # DB returns DESC: extract -> input -> click
    action_click = _make_action(ClickAction, action_id="act_1_click", task_id="task_1", element_id="el_1")
    action_input = _make_action(InputTextAction, action_id="act_2_input", task_id="task_1", element_id="el_2")
    action_extract = _make_action(ExtractAction, action_id="act_3_extract", task_id="task_1", element_id=None)

    # DESC order from DB (newest first)
    actions_descending = [action_extract, action_input, action_click]

    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=actions_descending)

        result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")

        # After reverse, output must be chronological: click, input, extract
        actions = result.actions_by_task["task_1"]
        action_ids = [a["action_id"] for a in actions]
        assert action_ids == ["act_1_click", "act_2_input", "act_3_extract"], (
            f"Actions should be in chronological order after reverse, got: {action_ids}"
        )
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_actions_preserve_none_element_id() -> None:
    """
    Regression test: hydrate_action must be called WITHOUT empty_element_id=True,
    so that None element_ids remain None (matching get_task_actions_hydrated behavior).

    Previously get_tasks_actions used hydrate_action(action, empty_element_id=True)
    which silently converted None element_ids to empty strings.
    """
    # Minimal workflow-run response stub: only the fields the transform reads.
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test", "parameters": {}}
    )

    # Extraction block — its actions legitimately carry element_id=None.
    def_block = MagicMock()
    def_block.block_type = "extraction"
    def_block.label = "extract_block"
    def_block.model_dump = MagicMock(return_value={"block_type": "extraction", "label": "extract_block"})

    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_1"})
    mock_workflow.workflow_definition.blocks = [def_block]

    run_block = MagicMock()
    run_block.workflow_run_block_id = "wfrb_1"
    run_block.parent_workflow_run_block_id = None
    run_block.block_type = "extraction"
    run_block.label = "extract_block"
    run_block.task_id = "task_1"
    run_block.status = "completed"
    run_block.output = {}
    run_block.created_at = 1

    mock_task = MagicMock()
    mock_task.task_id = "task_1"
    mock_task.model_dump = MagicMock(return_value={"task_id": "task_1"})

    # ExtractAction has element_id=None (extracts don't target a specific element)
    action_extract = _make_action(ExtractAction, action_id="act_extract", task_id="task_1", element_id=None)
    assert action_extract.element_id is None

    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[action_extract])

        result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")

        actions = result.actions_by_task["task_1"]
        assert len(actions) == 1
        # element_id must remain None, NOT converted to ""
        assert actions[0]["element_id"] is None, (
            f"element_id should be None but got {actions[0]['element_id']!r}. "
            "This indicates hydrate_action was called with empty_element_id=True"
        )
||||
153
tests/unit/test_branch_criteria.py
Normal file
153
tests/unit/test_branch_criteria.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.workflow.exceptions import FailedToFormatJinjaStyleParameter, MissingJinjaVariables
|
||||
from skyvern.forge.sdk.workflow.models.block import BranchEvaluationContext, JinjaBranchCriteria
|
||||
|
||||
|
||||
class FakeWorkflowRunContext:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
values: dict,
|
||||
secrets: dict | None = None,
|
||||
include_secrets_in_templates: bool = False,
|
||||
block_metadata: dict[str, dict] | None = None,
|
||||
) -> None:
|
||||
self.values = dict(values)
|
||||
self.secrets = secrets or {}
|
||||
self.include_secrets_in_templates = include_secrets_in_templates
|
||||
self._block_metadata = block_metadata or {}
|
||||
|
||||
# Minimal workflow identifiers
|
||||
self.workflow_title = "wf-title"
|
||||
self.workflow_id = "wf-id"
|
||||
self.workflow_permanent_id = "wf-perm-id"
|
||||
self.workflow_run_id = "wf-run-id"
|
||||
|
||||
def get_block_metadata(self, label: str) -> dict:
|
||||
return dict(self._block_metadata.get(label, {}))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_jinja_branch_criteria_evaluates_truthy_with_workflow_context():
    """A Jinja expression over workflow params and block metadata evaluates truthy."""
    run_context = FakeWorkflowRunContext(
        values={"params": {"foo": "bar"}, "extra": "value"},
        block_metadata={"conditional": {"current_index": 1, "custom": "meta"}},
    )
    evaluation_context = BranchEvaluationContext(
        # workflow_run_context ensures template_data matches block parameter rendering
        workflow_run_context=run_context,
        block_label="conditional",
    )
    branch_criteria = JinjaBranchCriteria(expression="{{ params.foo == 'bar' and current_index == 1 }}")

    outcome = await branch_criteria.evaluate(evaluation_context)
    assert outcome is True
|
||||
|
||||
@pytest.mark.asyncio
async def test_jinja_branch_criteria_raises_on_missing_variable_strict(monkeypatch):
    """In strict templating mode, an undefined template variable raises."""
    monkeypatch.setattr(settings, "WORKFLOW_TEMPLATING_STRICTNESS", "strict")
    branch_criteria = JinjaBranchCriteria(expression="{{ missing_value }}")
    evaluation_context = BranchEvaluationContext()

    with pytest.raises(MissingJinjaVariables):
        await branch_criteria.evaluate(evaluation_context)
|
||||
|
||||
@pytest.mark.asyncio
async def test_jinja_branch_criteria_raises_on_template_error():
    """Syntactically invalid Jinja surfaces as FailedToFormatJinjaStyleParameter."""
    branch_criteria = JinjaBranchCriteria(expression="{% for %}")  # invalid Jinja syntax
    evaluation_context = BranchEvaluationContext()

    with pytest.raises(FailedToFormatJinjaStyleParameter):
        await branch_criteria.evaluate(evaluation_context)
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "expression,expected",
    [
        # Boolean-like strings (case insensitive)
        ("{{ 'true' }}", True),
        ("{{ 'True' }}", True),
        ("{{ 'TRUE' }}", True),
        ("{{ 'false' }}", False),
        ("{{ 'False' }}", False),
        ("{{ 'FALSE' }}", False),
        # Numeric strings
        ("{{ '1' }}", True),
        ("{{ '0' }}", False),
        ("{{ '42' }}", True),
        ("{{ '-1' }}", True),
        ("{{ '0.0' }}", False),
        ("{{ '0.1' }}", True),
        ("{{ '-0.5' }}", True),
        # Yes/No variants
        ("{{ 'yes' }}", True),
        ("{{ 'Yes' }}", True),
        ("{{ 'YES' }}", True),
        ("{{ 'y' }}", True),
        ("{{ 'Y' }}", True),
        ("{{ 'no' }}", False),
        ("{{ 'No' }}", False),
        ("{{ 'NO' }}", False),
        ("{{ 'n' }}", False),
        ("{{ 'N' }}", False),
        # On/Off
        ("{{ 'on' }}", True),
        ("{{ 'ON' }}", True),
        ("{{ 'off' }}", False),
        ("{{ 'OFF' }}", False),
        # Null variants
        ("{{ 'null' }}", False),
        ("{{ 'Null' }}", False),
        ("{{ 'NULL' }}", False),
        ("{{ 'none' }}", False),
        ("{{ 'None' }}", False),
        # Empty and whitespace
        ("{{ '' }}", False),
        ("{{ ' ' }}", False),
        ("{{ '\t\n' }}", False),
        # Arbitrary strings (non-empty = truthy)
        ("{{ 'some text' }}", True),
        ("{{ 'anything' }}", True),
        # Direct boolean comparisons (common use case)
        ("{{ 5 > 3 }}", True),
        ("{{ 1 == 0 }}", False),
    ],
)
async def test_jinja_branch_criteria_truthy_falsy_evaluation(expression: str, expected: bool):
    """Test that rendered template strings are properly evaluated as boolean.

    Jinja rendering yields a string; the criteria must coerce that string to a
    bool, treating boolean-like/numeric/yes-no/on-off/null spellings specially
    and falling back to non-empty-string truthiness otherwise.
    """
    fake_ctx = FakeWorkflowRunContext(values={})
    branch_ctx = BranchEvaluationContext(workflow_run_context=fake_ctx, block_label="test")
    criteria = JinjaBranchCriteria(expression=expression)

    result = await criteria.evaluate(branch_ctx)
    assert result is expected, f"Expression {expression} should evaluate to {expected}, got {result}"
|
||||
|
||||
@pytest.mark.asyncio
async def test_jinja_branch_criteria_with_variable_comparison():
    """Test realistic scenario with variable comparisons."""
    run_context = FakeWorkflowRunContext(
        values={
            "comment_count": 150,
            "threshold": 100,
            "status": "active",
        }
    )
    evaluation_context = BranchEvaluationContext(workflow_run_context=run_context, block_label="test")

    truthy_expressions = (
        "{{ comment_count > threshold }}",  # numeric comparison
        "{{ status == 'active' }}",  # string comparison
        "{{ comment_count > threshold and status == 'active' }}",  # combined logic
    )
    for expression in truthy_expressions:
        branch_criteria = JinjaBranchCriteria(expression=expression)
        assert await branch_criteria.evaluate(evaluation_context) is True
||||
119
tests/unit/test_bulk_artifact_creation.py
Normal file
119
tests/unit/test_bulk_artifact_creation.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Unit tests for bulk artifact creation functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.artifact.manager import ArtifactBatchData, BulkArtifactCreationRequest
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.db.models import ArtifactModel
|
||||
|
||||
|
||||
def test_artifact_batch_data_with_data():
    """ArtifactBatchData constructed with in-memory bytes keeps data and leaves path unset."""
    artifact = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test",
        organization_id="org-1",
    )

    batch_entry = ArtifactBatchData(artifact_model=artifact, data=b"test data")

    assert batch_entry.artifact_model == artifact
    assert batch_entry.data == b"test data"
    assert batch_entry.path is None
|
||||
|
||||
def test_artifact_batch_data_with_path():
    """ArtifactBatchData constructed with a filesystem path keeps path and leaves data unset."""
    artifact = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test",
        organization_id="org-1",
    )

    batch_entry = ArtifactBatchData(artifact_model=artifact, path="/tmp/test.png")

    assert batch_entry.artifact_model == artifact
    assert batch_entry.data is None
    assert batch_entry.path == "/tmp/test.png"
|
||||
|
||||
def test_artifact_batch_data_with_both_raises_error():
    """Supplying both data and path to ArtifactBatchData is rejected with ValueError."""
    artifact = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test",
        organization_id="org-1",
    )

    with pytest.raises(ValueError, match="Cannot specify both data and path"):
        ArtifactBatchData(artifact_model=artifact, data=b"test data", path="/tmp/test.png")
|
||||
|
||||
def test_bulk_artifact_creation_request():
    """BulkArtifactCreationRequest preserves artifact order and the primary key."""
    first_artifact = ArtifactModel(
        artifact_id="test-1",
        artifact_type=ArtifactType.LLM_PROMPT,
        uri="s3://bucket/test1",
        organization_id="org-1",
    )
    second_artifact = ArtifactModel(
        artifact_id="test-2",
        artifact_type=ArtifactType.SCREENSHOT_LLM,
        uri="s3://bucket/test2",
        organization_id="org-1",
    )

    request = BulkArtifactCreationRequest(
        artifacts=[
            ArtifactBatchData(artifact_model=first_artifact, data=b"data1"),
            ArtifactBatchData(artifact_model=second_artifact, data=b"data2"),
        ],
        primary_key="task-123",
    )

    assert request.primary_key == "task-123"
    batch_ids = [entry.artifact_model.artifact_id for entry in request.artifacts]
    assert batch_ids == ["test-1", "test-2"]
|
||||
|
||||
def test_bulk_artifact_creation_performance_benefit():
    """
    Conceptual test documenting why bulk creation is faster: N individual
    INSERT calls collapse into a single bulk INSERT.
    """
    num_artifacts = 10

    # Old approach issues one INSERT per artifact; new approach issues one bulk INSERT.
    individual_inserts = num_artifacts
    bulk_inserts = 1

    # Bulk insert must strictly beat per-artifact inserts ...
    assert bulk_inserts < individual_inserts
    # ... and the call-count reduction scales with the number of artifacts.
    assert individual_inserts / bulk_inserts == num_artifacts
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (outside the pytest CLI).
    pytest.main([__file__, "-v"])
||||
343
tests/unit/test_click_prompt_parameterization.py
Normal file
343
tests/unit/test_click_prompt_parameterization.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
Tests for click prompt parameterization in cached script generation.
|
||||
|
||||
When generating cached scripts, click action prompts (intention/reasoning) should
|
||||
replace literal parameter values with f-string references to context.parameters[...],
|
||||
so that re-runs with different values produce correct behavior.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from skyvern.core.script_generations.generate_script import (
|
||||
MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB,
|
||||
_action_to_stmt,
|
||||
_build_parameterized_prompt_cst,
|
||||
_build_value_to_param_lookup,
|
||||
)
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
|
||||
|
||||
def _make_action(
|
||||
action_type: str,
|
||||
field_name: str | None = None,
|
||||
text: str = "",
|
||||
option: str = "",
|
||||
file_url: str = "",
|
||||
) -> dict[str, Any]:
|
||||
action: dict[str, Any] = {"action_type": action_type}
|
||||
if field_name:
|
||||
action["field_name"] = field_name
|
||||
if text:
|
||||
action["text"] = text
|
||||
if option:
|
||||
action["option"] = option
|
||||
if file_url:
|
||||
action["file_url"] = file_url
|
||||
return action
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_value_to_param_lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildValueToParamLookup:
    """Unit tests for _build_value_to_param_lookup (literal value -> parameter name map)."""

    def test_collects_input_text_values(self) -> None:
        tasks = {"task-1": [_make_action(ActionType.INPUT_TEXT, field_name="patient_id", text="542-641-668")]}
        assert _build_value_to_param_lookup(tasks) == {"542-641-668": "patient_id"}

    def test_collects_select_option_values(self) -> None:
        tasks = {"task-1": [_make_action(ActionType.SELECT_OPTION, field_name="state", option="California")]}
        assert _build_value_to_param_lookup(tasks) == {"California": "state"}

    def test_collects_upload_file_values(self) -> None:
        upload = _make_action(
            ActionType.UPLOAD_FILE,
            field_name="document",
            file_url="https://example.com/report.pdf",
        )
        tasks = {"task-1": [upload]}
        assert _build_value_to_param_lookup(tasks) == {"https://example.com/report.pdf": "document"}

    def test_skips_actions_without_field_name(self) -> None:
        tasks = {"task-1": [_make_action(ActionType.INPUT_TEXT, text="some value without field name")]}
        assert _build_value_to_param_lookup(tasks) == {}

    def test_skips_short_values(self) -> None:
        # All values below the minimum substitution length are ignored.
        tasks = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="flag", text="No"),
                _make_action(ActionType.INPUT_TEXT, field_name="code", text="CA"),
                _make_action(ActionType.INPUT_TEXT, field_name="num", text="1"),
            ]
        }
        assert _build_value_to_param_lookup(tasks) == {}

    def test_boundary_value_at_min_length(self) -> None:
        """Values at exactly MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB should be included."""
        boundary_value = "x" * MIN_PARAM_VALUE_LENGTH_FOR_PROMPT_SUB
        tasks = {"task-1": [_make_action(ActionType.INPUT_TEXT, field_name="field", text=boundary_value)]}
        assert boundary_value in _build_value_to_param_lookup(tasks)

    def test_sorted_by_descending_length(self) -> None:
        tasks = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="short_field", text="abcd"),
                _make_action(ActionType.INPUT_TEXT, field_name="long_field", text="abcdefghij"),
                _make_action(ActionType.INPUT_TEXT, field_name="mid_field", text="abcdef"),
            ]
        }
        # Longest values come first so longer matches win during substitution.
        assert list(_build_value_to_param_lookup(tasks)) == ["abcdefghij", "abcdef", "abcd"]

    def test_first_writer_wins_on_duplicate_values(self) -> None:
        tasks = {
            "task-1": [
                _make_action(ActionType.INPUT_TEXT, field_name="first_field", text="same-value"),
                _make_action(ActionType.INPUT_TEXT, field_name="second_field", text="same-value"),
            ]
        }
        assert _build_value_to_param_lookup(tasks)["same-value"] == "first_field"

    def test_skips_click_actions(self) -> None:
        click_action = {"action_type": ActionType.CLICK, "field_name": "click_field", "text": "some text"}
        assert _build_value_to_param_lookup({"task-1": [click_action]}) == {}

    def test_multiple_tasks(self) -> None:
        tasks = {
            "task-1": [_make_action(ActionType.INPUT_TEXT, field_name="patient_id", text="542-641-668")],
            "task-2": [_make_action(ActionType.INPUT_TEXT, field_name="doctor_name", text="Dr. Smith")],
        }
        assert _build_value_to_param_lookup(tasks) == {"542-641-668": "patient_id", "Dr. Smith": "doctor_name"}

    def test_empty_actions(self) -> None:
        assert _build_value_to_param_lookup({}) == {}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_parameterized_prompt_cst
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildParameterizedPromptCst:
    """Unit tests for _build_parameterized_prompt_cst.

    The helper rewrites literal parameter values inside a prompt string into an
    f-string CST node referencing context.parameters[...], or returns None when
    no known value occurs in the prompt.
    """

    def test_returns_none_when_no_matches(self) -> None:
        # No known value appears in the prompt -> nothing to parameterize.
        result = _build_parameterized_prompt_cst(
            "Click the submit button",
            {"542-641-668": "patient_id"},
        )
        assert result is None

    def test_single_substitution(self) -> None:
        # The literal is replaced by a context.parameters reference.
        result = _build_parameterized_prompt_cst(
            "Which card corresponds to the referral for ID 542-641-668?",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        assert isinstance(result, cst.FormattedString)
        code = cst.Module(body=[]).code_for_node(result)
        assert "context.parameters" in code
        assert "patient_id" in code
        assert "542-641-668" not in code

    def test_multiple_substitutions(self) -> None:
        # Every distinct known value occurring in the prompt gets replaced.
        result = _build_parameterized_prompt_cst(
            "Find patient 542-641-668 with doctor Dr. Smith",
            {"542-641-668": "patient_id", "Dr. Smith": "doctor_name"},
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        assert "patient_id" in code
        assert "doctor_name" in code
        assert "542-641-668" not in code
        assert "Dr. Smith" not in code

    def test_substitution_at_start(self) -> None:
        result = _build_parameterized_prompt_cst(
            "542-641-668 is the patient ID to search for",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        parts = result.parts
        # First part should be the expression (substitution at start)
        assert isinstance(parts[0], cst.FormattedStringExpression)

    def test_substitution_at_end(self) -> None:
        result = _build_parameterized_prompt_cst(
            "Search for patient 542-641-668",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        parts = result.parts
        # Last part should be the expression (substitution at end)
        assert isinstance(parts[-1], cst.FormattedStringExpression)

    def test_empty_intention(self) -> None:
        # Empty prompt text -> nothing to substitute.
        result = _build_parameterized_prompt_cst("", {"542-641-668": "patient_id"})
        assert result is None

    def test_empty_lookup(self) -> None:
        # No known values at all -> no substitution possible.
        result = _build_parameterized_prompt_cst(
            "Which card corresponds to the referral for ID 542-641-668?",
            {},
        )
        assert result is None

    def test_longer_match_preferred_over_shorter(self) -> None:
        """When values overlap, the longer value (sorted first) takes precedence."""
        result = _build_parameterized_prompt_cst(
            "Enter 542-641-668-999 here",
            {
                "542-641-668-999": "full_id",
                "542-641-668": "partial_id",
            },
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        assert "full_id" in code
        assert "partial_id" not in code

    def test_generates_valid_fstring_syntax(self) -> None:
        """The generated f-string should be parseable Python."""
        result = _build_parameterized_prompt_cst(
            "Which card area corresponds to ID 542-641-668?",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        # Should be a valid f-string — verify it starts with f" or f'
        assert code.startswith("f'") or code.startswith('f"')
        # The full expression should be compilable
        compile(code, "<test>", "eval")

    def test_repeated_value_in_intention(self) -> None:
        """If the same value appears twice, both occurrences should be replaced."""
        result = _build_parameterized_prompt_cst(
            "Compare 542-641-668 with 542-641-668",
            {"542-641-668": "patient_id"},
        )
        assert result is not None
        code = cst.Module(body=[]).code_for_node(result)
        # The literal should not appear at all
        assert "542-641-668" not in code
        # context.parameters should appear twice (once per occurrence)
        assert code.count("context.parameters") == 2
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: _action_to_stmt with value_to_param
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestActionToStmtClickParameterization:
    """End-to-end tests exercising _action_to_stmt for click actions."""

    def _render(self, stmt: cst.BaseStatement) -> str:
        """Render a single statement node back into source text."""
        module = cst.Module(body=[stmt])
        return module.code

    def test_click_prompt_parameterized(self) -> None:
        """Click action with matching value in intention gets an f-string prompt."""
        action: dict[str, Any] = {
            "action_type": "click",
            "xpath": "//div[@class='card']",
            "intention": "Which card corresponds to the referral for ID 542-641-668?",
        }
        lookup = {"542-641-668": "patient_id"}

        rendered = self._render(_action_to_stmt(action, {}, value_to_param=lookup))

        assert "context.parameters" in rendered
        assert "patient_id" in rendered
        assert "542-641-668" not in rendered
        # The prompt must have been emitted as an f-string.
        assert "f'" in rendered or 'f"' in rendered

    def test_click_prompt_literal_when_no_lookup(self) -> None:
        """Click action without value_to_param produces a plain string prompt."""
        action: dict[str, Any] = {
            "action_type": "click",
            "xpath": "//div[@class='card']",
            "intention": "Click the submit button",
        }

        rendered = self._render(_action_to_stmt(action, {}, value_to_param=None))

        assert "Click the submit button" in rendered
        assert "context.parameters" not in rendered

    def test_click_prompt_literal_when_no_match(self) -> None:
        """Click action with non-matching lookup produces a plain string prompt."""
        action: dict[str, Any] = {
            "action_type": "click",
            "xpath": "//button",
            "intention": "Click the submit button",
        }
        lookup = {"542-641-668": "patient_id"}

        rendered = self._render(_action_to_stmt(action, {}, value_to_param=lookup))

        assert "Click the submit button" in rendered
        assert "context.parameters" not in rendered

    def test_fill_action_unaffected_by_value_to_param(self) -> None:
        """Fill actions should still use the field_name mechanism, not value_to_param."""
        action: dict[str, Any] = {
            "action_type": "input_text",
            "xpath": "//input[@name='search']",
            "text": "542-641-668",
            "field_name": "patient_id",
        }
        lookup = {"542-641-668": "patient_id"}

        rendered = self._render(_action_to_stmt(action, {}, value_to_param=lookup))

        # Parameterization comes from field_name, not from the value lookup.
        assert "context.parameters" in rendered
        assert "patient_id" in rendered
|
||||
545
tests/unit/test_compute_conditional_scopes.py
Normal file
545
tests/unit/test_compute_conditional_scopes.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""Tests for compute_conditional_scopes() function.
|
||||
|
||||
This function maps each block label to the conditional block label whose scope it belongs to.
|
||||
It handles merge-point detection, nested conditionals, and deduplication of branch targets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
BranchCondition,
|
||||
ConditionalBlock,
|
||||
HttpRequestBlock,
|
||||
TaskBlock,
|
||||
compute_conditional_scopes,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||
|
||||
|
||||
def _make_output_parameter(key: str) -> OutputParameter:
    """Build a minimal OutputParameter fixture keyed by *key*."""
    timestamp = datetime.now(tz=timezone.utc)
    fields = {
        "key": key,
        "parameter_type": "output",
        "output_parameter_id": f"op_{key}",
        "workflow_id": "wf_test",
        "created_at": timestamp,
        "modified_at": timestamp,
    }
    return OutputParameter(**fields)
|
||||
|
||||
|
||||
def _make_task_block(label: str, *, next_block_label: str | None = None) -> TaskBlock:
    """Build a TaskBlock fixture labelled *label*, optionally chained to a next block."""
    output = _make_output_parameter(label)
    return TaskBlock(
        label=label,
        url="https://example.com",
        output_parameter=output,
        next_block_label=next_block_label,
    )
|
||||
|
||||
|
||||
def _make_http_block(label: str, *, next_block_label: str | None = None) -> HttpRequestBlock:
    """Build an HttpRequestBlock fixture labelled *label* that issues a GET request."""
    output = _make_output_parameter(label)
    return HttpRequestBlock(
        label=label,
        url="https://example.com",
        method="GET",
        output_parameter=output,
        next_block_label=next_block_label,
    )
|
||||
|
||||
|
||||
def _make_conditional_block(
    label: str,
    branches: list[tuple[str | None, bool]],
    *,
    next_block_label: str | None = None,
) -> ConditionalBlock:
    """Create a conditional block with the given branches.

    Args:
        label: Block label.
        branches: List of (next_block_label, is_default) tuples.
        next_block_label: Default next block for the conditional itself (usually None).
    """
    # Default branches carry is_default=True; non-default ones get an
    # always-true jinja2 criteria so the block validates.
    branch_conditions = [
        BranchCondition(next_block_label=target, is_default=True)
        if is_default
        else BranchCondition(
            next_block_label=target,
            criteria={"criteria_type": "jinja2_template", "expression": "{{ true }}"},
        )
        for target, is_default in branches
    ]
    return ConditionalBlock(
        label=label,
        output_parameter=_make_output_parameter(label),
        branch_conditions=branch_conditions,
        next_block_label=next_block_label,
    )
|
||||
|
||||
|
||||
class TestComputeConditionalScopes:
    """Tests for compute_conditional_scopes().

    Each test builds `label_to_block` (label -> block instance) and
    `default_next_map` (label -> default next label) by hand and checks the
    returned scope mapping (block label -> owning conditional's label).
    """

    def test_simple_two_branch_conditional_with_merge(self) -> None:
        """Test a simple conditional with two branches that merge.

        Workflow:
            Conditional(C) -> Branch1 -> A -> MergePoint(M)
                           -> Branch2 -> B -> M

        Expected: A and B are scoped to C, M is NOT scoped (merge point).
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("B", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "M": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        assert scopes == {"A": "C", "B": "C"}
        assert "M" not in scopes  # M is a merge point

    def test_conditional_with_chain_before_merge(self) -> None:
        """Test branches with multiple blocks before merge point.

        Workflow:
            Conditional(C) -> Branch1 -> A -> B -> MergePoint(M)
                           -> Branch2 -> D -> M

        Expected: A, B, D are scoped to C. M is NOT scoped.
        """
        block_a = _make_task_block("A", next_block_label="B")
        block_b = _make_task_block("B", next_block_label="M")
        block_d = _make_task_block("D", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("D", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "D": block_d,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "B",
            "B": "M",
            "D": "M",
            "M": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        assert scopes == {"A": "C", "B": "C", "D": "C"}
        assert "M" not in scopes

    def test_conditional_with_terminal_branches(self) -> None:
        """Test branches that don't merge (terminate independently).

        Workflow:
            Conditional(C) -> Branch1 -> A (terminal)
                           -> Branch2 -> B (terminal)

        Expected: A and B are scoped to C since they don't appear in all branches.
        """
        block_a = _make_task_block("A")
        block_b = _make_task_block("B")
        cond = _make_conditional_block("C", [("A", False), ("B", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
        }
        default_next_map = {
            "C": None,
            "A": None,
            "B": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        assert scopes == {"A": "C", "B": "C"}

    def test_conditional_all_branches_terminal_none(self) -> None:
        """Test when all branches have None as target (no blocks to scope).

        Workflow:
            Conditional(C) -> Branch1 -> None
                           -> Branch2 -> None

        Expected: No scopes (no blocks in the branches).
        """
        cond = _make_conditional_block("C", [(None, False), (None, True)])

        label_to_block = {"C": cond}
        default_next_map = {"C": None}

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        assert scopes == {}

    def test_multiple_branches_same_target_deduplication(self) -> None:
        """Test that duplicate branch targets are deduplicated.

        Workflow:
            Conditional(C) -> Branch1 -> A -> M
                           -> Branch2 -> A -> M (same as Branch1)
                           -> Branch3 -> B -> M

        With deduplication, unique targets are [A, B], so num_branches = 2.
        Both chains go to M, so M is a merge point.
        A appears in only one chain (after dedup), B in another.
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_m = _make_task_block("M")
        # Two branches deliberately share the same target "A".
        cond = _make_conditional_block("C", [("A", False), ("A", False), ("B", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "M": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # A and B are scoped to C, M is the merge point
        assert scopes == {"A": "C", "B": "C"}
        assert "M" not in scopes

    def test_nested_conditionals(self) -> None:
        """Test nested conditionals (conditional inside another's branch).

        Workflow:
            OuterCond(C1) -> Branch1 -> InnerCond(C2) -> BranchA -> X
                                                      -> BranchB -> Y
                          -> Branch2 -> Z -> MergePoint(M)

        Expected:
        - C2 is scoped to C1 (it's in C1's branch)
        - X and Y are scoped to C2 (inner conditional handles its own branches)
        - Z is scoped to C1
        - M might or might not be scoped depending on structure
        """
        block_x = _make_task_block("X")
        block_y = _make_task_block("Y")
        block_z = _make_task_block("Z", next_block_label="M")
        block_m = _make_task_block("M")
        inner_cond = _make_conditional_block("C2", [("X", False), ("Y", True)])
        outer_cond = _make_conditional_block("C1", [("C2", False), ("Z", True)])

        label_to_block = {
            "C1": outer_cond,
            "C2": inner_cond,
            "X": block_x,
            "Y": block_y,
            "Z": block_z,
            "M": block_m,
        }
        default_next_map = {
            "C1": None,
            "C2": None,  # Inner conditional doesn't have a default next
            "X": None,
            "Y": None,
            "Z": "M",
            "M": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # C2 is scoped to C1 (it's in C1's branch, and tracing stops at C2)
        assert scopes.get("C2") == "C1"
        # X and Y are scoped to C2 (inner conditional)
        assert scopes.get("X") == "C2"
        assert scopes.get("Y") == "C2"
        # Z is scoped to C1
        assert scopes.get("Z") == "C1"

    def test_no_conditionals_in_workflow(self) -> None:
        """Test workflow with no conditional blocks.

        Workflow:
            A -> B -> C

        Expected: No scopes.
        """
        block_a = _make_task_block("A", next_block_label="B")
        block_b = _make_task_block("B", next_block_label="C")
        block_c = _make_task_block("C")

        label_to_block = {
            "A": block_a,
            "B": block_b,
            "C": block_c,
        }
        default_next_map = {
            "A": "B",
            "B": "C",
            "C": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        assert scopes == {}

    def test_conditional_with_single_branch(self) -> None:
        """Test conditional with effectively one unique branch target.

        Workflow:
            Conditional(C) -> Branch1 -> A
                           -> Branch2 -> A (same target, deduplicated)

        After deduplication, num_branches = 1, and A appears in 1/1 chains,
        making it a "merge point" (appears in all branches).
        """
        block_a = _make_task_block("A")
        cond = _make_conditional_block("C", [("A", False), ("A", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
        }
        default_next_map = {
            "C": None,
            "A": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # A appears in all (1) branch chains, so it's treated as a merge point
        assert scopes == {}

    def test_three_branch_conditional_partial_merge(self) -> None:
        """Test three branches where only some merge.

        Workflow:
            Conditional(C) -> Branch1 -> A -> M
                           -> Branch2 -> B -> M
                           -> Branch3 -> D (terminal, no merge)

        M appears in 2/3 branches, so it's NOT a merge point.
        All of A, B, D, M should be scoped to C.
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_d = _make_task_block("D")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("B", False), ("D", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "D": block_d,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "D": None,
            "M": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # M only appears in 2/3 branches, so it's still inside the conditional scope
        assert scopes == {"A": "C", "B": "C", "D": "C", "M": "C"}

    def test_merge_point_with_blocks_after(self) -> None:
        """Test that blocks after the merge point are not scoped.

        Workflow:
            Conditional(C) -> Branch1 -> A -> M -> X -> Y
                           -> Branch2 -> B -> M

        M is the merge point (appears in both chains).
        X and Y come after M and should NOT be scoped.
        """
        block_a = _make_task_block("A", next_block_label="M")
        block_b = _make_task_block("B", next_block_label="M")
        block_m = _make_task_block("M", next_block_label="X")
        block_x = _make_task_block("X", next_block_label="Y")
        block_y = _make_task_block("Y")
        cond = _make_conditional_block("C", [("A", False), ("B", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "M": block_m,
            "X": block_x,
            "Y": block_y,
        }
        default_next_map = {
            "C": None,
            "A": "M",
            "B": "M",
            "M": "X",
            "X": "Y",
            "Y": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # A and B are scoped, M and everything after is NOT
        assert scopes == {"A": "C", "B": "C"}
        assert "M" not in scopes
        assert "X" not in scopes
        assert "Y" not in scopes

    def test_branch_to_nonexistent_block(self) -> None:
        """Test graceful handling when branch targets a non-existent block.

        This shouldn't happen in practice (validation catches it), but the
        function should handle it gracefully.
        """
        cond = _make_conditional_block("C", [("MISSING", False), ("A", True)])
        block_a = _make_task_block("A")

        label_to_block = {
            "C": cond,
            "A": block_a,
        }
        default_next_map = {
            "C": None,
            "A": None,
        }

        # Should not raise, MISSING just won't be in the results
        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # Only A is scoped (MISSING is not in label_to_block)
        assert scopes == {"A": "C"}

    def test_empty_workflow(self) -> None:
        """Test with empty inputs."""
        scopes = compute_conditional_scopes({}, {})
        assert scopes == {}

    def test_conditional_only_no_other_blocks(self) -> None:
        """Test with only a conditional block and no branch targets.

        Workflow:
            Conditional(C) -> Branch1 -> None
                           -> Branch2 -> None
        """
        cond = _make_conditional_block("C", [(None, False), (None, True)])

        label_to_block = {"C": cond}
        default_next_map = {"C": None}

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        assert scopes == {}

    def test_asymmetric_branch_lengths(self) -> None:
        """Test branches with significantly different chain lengths.

        Workflow:
            Conditional(C) -> Branch1 -> A -> B -> C2 -> D -> M
                           -> Branch2 -> M

        Branch1 has a long chain, Branch2 goes directly to M.
        M is the only block in both chains, so it's the merge point.
        """
        block_a = _make_task_block("A", next_block_label="B")
        block_b = _make_task_block("B", next_block_label="C2")
        block_c2 = _make_task_block("C2", next_block_label="D")
        block_d = _make_task_block("D", next_block_label="M")
        block_m = _make_task_block("M")
        cond = _make_conditional_block("C", [("A", False), ("M", True)])

        label_to_block = {
            "C": cond,
            "A": block_a,
            "B": block_b,
            "C2": block_c2,
            "D": block_d,
            "M": block_m,
        }
        default_next_map = {
            "C": None,
            "A": "B",
            "B": "C2",
            "C2": "D",
            "D": "M",
            "M": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # A, B, C2, D are in Branch1 only, so they're scoped
        # M appears in both branches, so it's the merge point
        assert scopes == {"A": "C", "B": "C", "C2": "C", "D": "C"}
        assert "M" not in scopes

    def test_multiple_independent_conditionals(self) -> None:
        """Test multiple conditionals at the same level (not nested).

        Workflow:
            C1 -> Branch1 -> A
               -> Branch2 -> B
            (after C1) -> C2 -> Branch1 -> X
                             -> Branch2 -> Y
        """
        block_a = _make_task_block("A", next_block_label="C2")
        block_b = _make_task_block("B", next_block_label="C2")
        block_x = _make_task_block("X")
        block_y = _make_task_block("Y")
        cond1 = _make_conditional_block("C1", [("A", False), ("B", True)])
        cond2 = _make_conditional_block("C2", [("X", False), ("Y", True)])

        label_to_block = {
            "C1": cond1,
            "C2": cond2,
            "A": block_a,
            "B": block_b,
            "X": block_x,
            "Y": block_y,
        }
        default_next_map = {
            "C1": None,
            "C2": None,
            "A": "C2",
            "B": "C2",
            "X": None,
            "Y": None,
        }

        scopes = compute_conditional_scopes(label_to_block, default_next_map)

        # A and B are scoped to C1
        # C2 is the merge point for C1 (appears in both A and B chains)
        # X and Y are scoped to C2
        assert scopes.get("A") == "C1"
        assert scopes.get("B") == "C1"
        assert "C2" not in scopes  # C2 is a merge point for C1
        assert scopes.get("X") == "C2"
        assert scopes.get("Y") == "C2"
|
||||
811
tests/unit/test_conditional_script_caching.py
Normal file
811
tests/unit/test_conditional_script_caching.py
Normal file
@@ -0,0 +1,811 @@
|
||||
"""
|
||||
Tests for conditional block script caching support.
|
||||
|
||||
This test file verifies that:
|
||||
1. Workflows with conditional blocks can have scripts generated for cacheable blocks
|
||||
2. The regeneration logic doesn't trigger unnecessary regeneration for unexecuted branches
|
||||
3. Progressive caching works correctly across multiple runs
|
||||
4. Cached blocks from unexecuted branches are preserved during script regeneration (SKY-7815)
|
||||
|
||||
Key bugs this tests against:
|
||||
- Previously, the regeneration check compared cached blocks against ALL blocks in the workflow
|
||||
definition, causing "missing" blocks from unexecuted branches to trigger regeneration
|
||||
on EVERY run, flooding the database with redundant script operations.
|
||||
- (SKY-7815) When regeneration was triggered for a legitimate reason, cached blocks from
|
||||
unexecuted conditional branches were DROPPED because generate_workflow_script_python_code()
|
||||
only iterated blocks from the transform output (executed blocks). This caused a regeneration
|
||||
loop where blocks kept getting dropped and re-added.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.core.script_generations.generate_script import ScriptBlockSource, generate_workflow_script_python_code
|
||||
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
|
||||
from skyvern.schemas.workflows import BlockType
|
||||
from skyvern.services.workflow_script_service import workflow_has_conditionals
|
||||
|
||||
|
||||
class TestConditionalBlockDetection:
    """Tests for workflow_has_conditionals() function."""

    @staticmethod
    def _make_workflow(block_types: list[BlockType]):
        """Build a minimal mock workflow with one block per entry in *block_types*.

        Previously each test defined identical MockBlock/MockWorkflowDefinition/
        MockWorkflow classes inline; this helper removes that duplication. The
        mocks mirror only the attributes the tests rely on:
        workflow.workflow_definition.blocks[i].block_type (plus label and
        workflow_id).
        """

        class MockBlock:
            def __init__(self, block_type: BlockType):
                self.block_type = block_type
                self.label = f"block_{block_type.value}"

        class MockWorkflowDefinition:
            def __init__(self, blocks: list):
                self.blocks = blocks

        class MockWorkflow:
            def __init__(self, blocks: list):
                self.workflow_definition = MockWorkflowDefinition(blocks)
                self.workflow_id = "test_workflow"

        return MockWorkflow([MockBlock(block_type) for block_type in block_types])

    def test_workflow_without_conditionals(self) -> None:
        """Workflows without conditional blocks should return False."""
        # Workflow with only navigation and extraction blocks
        workflow = self._make_workflow([BlockType.NAVIGATION, BlockType.EXTRACTION])

        assert workflow_has_conditionals(workflow) is False

    def test_workflow_with_conditionals(self) -> None:
        """Workflows with conditional blocks should return True."""
        # Workflow with a conditional block between two cacheable blocks
        workflow = self._make_workflow(
            [BlockType.NAVIGATION, BlockType.CONDITIONAL, BlockType.EXTRACTION]
        )

        assert workflow_has_conditionals(workflow) is True
|
||||
|
||||
|
||||
class TestConditionalBlockNotCached:
    """Tests verifying conditional blocks are not in BLOCK_TYPES_THAT_SHOULD_BE_CACHED."""

    def test_conditional_not_in_cached_types(self) -> None:
        """Conditional blocks must be excluded from the cacheable block types."""
        assert BlockType.CONDITIONAL not in BLOCK_TYPES_THAT_SHOULD_BE_CACHED

    def test_cacheable_types_exist(self) -> None:
        """The cacheable set must include the expected core block types."""
        # Every one of these must be present for caching to work at all.
        for expected_type in (BlockType.NAVIGATION, BlockType.EXTRACTION, BlockType.TASK):
            assert expected_type in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
|
||||
|
||||
|
||||
class TestRegenerationLogicForConditionals:
    """
    Tests for the regeneration decision logic when conditionals are present.

    The key fix: for workflows WITH conditionals, labels missing because a
    branch never executed must NOT trigger regeneration. This prevents the
    database-flooding bug where every run caused unnecessary script
    regeneration.
    """

    def test_missing_labels_computation(self) -> None:
        """
        Verify the missing-labels computation and the conditional guard.

        For a workflow with branches A and B:
        - should_cache_block_labels = {A, B, START}
        - cached_block_labels = {A, START} (only A executed)
        - missing_labels = {B}

        Without the fix: missing_labels triggers regeneration every time.
        With the fix: missing_labels is ignored for workflows with conditionals.
        """
        expected_labels = {"branch_a_extract", "branch_b_extract", "WORKFLOW_START_BLOCK"}
        already_cached = {"branch_a_extract", "WORKFLOW_START_BLOCK"}

        missing = expected_labels.difference(already_cached)
        assert missing == {"branch_b_extract"}

        # With conditionals present, the missing labels must not be scheduled.
        workflow_has_conditional_blocks = True
        pending_updates: set[str] = set()
        if missing and not workflow_has_conditional_blocks:
            pending_updates |= missing

        # Nothing was scheduled because the workflow has conditionals.
        assert len(pending_updates) == 0

    def test_regeneration_triggered_without_conditionals(self) -> None:
        """
        Without conditionals, missing labels SHOULD schedule regeneration.

        This is the expected behavior for regular workflows where every block
        should eventually end up cached.
        """
        expected_labels = {"block_1", "block_2", "WORKFLOW_START_BLOCK"}
        already_cached = {"block_1", "WORKFLOW_START_BLOCK"}

        missing = expected_labels.difference(already_cached)
        assert missing == {"block_2"}

        # No conditionals, so the guard lets the missing labels through.
        workflow_has_conditional_blocks = False
        pending_updates: set[str] = set()
        if missing and not workflow_has_conditional_blocks:
            pending_updates |= missing

        assert "block_2" in pending_updates

    def test_explicit_updates_still_work_with_conditionals(self) -> None:
        """
        Explicit blocks_to_update from the caller must still force regeneration,
        even when the workflow contains conditionals — real changes to executed
        blocks are always processed.
        """
        pending_updates: set[str] = {"explicitly_updated_block"}  # From caller

        # Any non-empty update set means regeneration must happen.
        should_regenerate = bool(pending_updates)
        assert should_regenerate is True
|
||||
|
||||
|
||||
class TestProgressiveCachingConcept:
    """
    Tests documenting the progressive caching concept for conditional workflows.

    Progressive caching means:
    1. Run 1 takes branch A -> caches blocks from A
    2. Run 2 takes branch B -> caches blocks from B (preserves A's cache)
    3. Eventually all branches have cached blocks

    The key insight: we must NOT regenerate merely because some branches have
    not executed yet.
    """

    def test_progressive_caching_scenario(self) -> None:
        """
        Simulate multiple runs taking different branches.

        Run 1: Branch A executes.
        Run 2: Branch A executes again (nothing new — no regeneration).
        Run 3: Branch B executes (B is cached, A's cache preserved).
        """
        cache: set[str] = set()

        # Run 1: branch A executes and its blocks land in the cache.
        run1_blocks = {"nav_block", "branch_a_extract"}
        cache |= run1_blocks
        assert cache == {"nav_block", "branch_a_extract"}

        # Run 2: branch A executes again — nothing new, so no regeneration.
        run2_blocks = {"nav_block", "branch_a_extract"}
        newly_seen_run2 = run2_blocks.difference(cache)
        assert len(newly_seen_run2) == 0  # Nothing new to cache

        # Run 3: branch B executes for the first time.
        run3_blocks = {"nav_block", "branch_b_extract"}
        newly_seen_run3 = run3_blocks.difference(cache)
        assert newly_seen_run3 == {"branch_b_extract"}  # New block to cache

        # After run 3 the cache covers both branches.
        cache |= run3_blocks
        assert cache == {"nav_block", "branch_a_extract", "branch_b_extract"}
|
||||
|
||||
|
||||
class TestConditionalBlockCodeGeneration:
    """Tests for conditional block handling in code generation."""

    def test_conditional_block_type_string(self) -> None:
        """The conditional block type must serialize to the string "conditional"."""
        expected_value = "conditional"
        assert BlockType.CONDITIONAL.value == expected_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SKY-7815: Tests for cached block preservation during regeneration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCachedBlockPreservationDuringRegeneration:
    """
    Tests verifying that cached blocks from unexecuted conditional branches
    are preserved when generate_workflow_script_python_code() regenerates a script.

    Bug (SKY-7815):
    When a workflow has conditional branches A and B:
    - Run 1 executes branch A → script has blocks from A
    - Run 2 executes branch B → regeneration triggered → transform only returns B's blocks
    - generate_workflow_script_python_code() only iterates transform output (B's blocks)
    - Cached blocks from A are loaded into cached_blocks dict but NEVER iterated
    - Result: A's blocks are DROPPED from the new script → regeneration loop

    Fix: After processing all blocks from the transform output, iterate remaining
    cached_blocks entries and preserve them in both the DB and script output.
    """

    @pytest.mark.asyncio
    async def test_cached_blocks_from_unexecuted_branch_are_preserved(self) -> None:
        """
        Core test: when only branch B's blocks are in the transform output,
        branch A's cached blocks should still appear in the generated script.
        """
        # Branch A's cached block (from a previous run)
        branch_a_code = (
            "async def branch_a_extract(page: SkyvernPage, context: RunContext) -> None:\n"
            " await skyvern.extract(page, \"//div[@id='result']\")\n"
        )
        cached_blocks = {
            "branch_a_extract": ScriptBlockSource(
                label="branch_a_extract",
                code=branch_a_code,
                run_signature="await branch_a_extract(page, context)",
                workflow_run_id="wr_run1",
                workflow_run_block_id="wfrb_a",
                input_fields=None,
            ),
        }

        # Transform output only has branch B's block (branch B executed this run)
        blocks = [
            {
                "block_type": "navigation",
                "label": "branch_b_navigate",
                "task_id": "task_b",
                "navigation_goal": "Go to page B",
                "url": "https://example.com/b",
                "workflow_run_id": "wr_run2",
                "workflow_run_block_id": "wfrb_b",
            },
        ]

        actions_by_task = {
            "task_b": [
                {
                    "action_type": "click",
                    "action_id": "action_b1",
                    "xpath": "//button[@id='submit']",
                    "element_id": "submit",
                    "reasoning": "Click submit",
                    "intention": "Submit the form",
                    "confidence_float": 0.95,
                    "has_mini_agent": False,
                },
            ],
        }

        workflow = {
            "workflow_id": "wf_test",
            "workflow_permanent_id": "wpid_test",
            "title": "Test Conditional Workflow",
            "workflow_definition": {
                "parameters": [
                    {"parameter_type": "workflow", "key": "url", "default_value": "https://example.com"},
                ],
            },
        }

        workflow_run_request = {
            "workflow_id": "wpid_test",
            "parameters": {"url": "https://example.com"},
        }

        # Patch at the use site inside generate_script so the generator runs
        # without the real schema generation or DB writes.
        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request=workflow_run_request,
                workflow=workflow,
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"branch_b_navigate", "__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )

        # The output should contain branch A's cached code
        assert "branch_a_extract" in result, (
            "Cached block from unexecuted branch A should be preserved in the script output"
        )

        # Verify create_or_update_script_block was called for the preserved block
        preserved_calls = [
            call
            for call in mock_create_block.call_args_list
            if call.kwargs.get("block_label") == "branch_a_extract"
        ]
        assert len(preserved_calls) == 1, (
            "create_or_update_script_block should be called for the preserved cached block"
        )
        preserved_call = preserved_calls[0]
        # The preserved DB entry must carry the ORIGINAL run's metadata, not Run 2's.
        assert preserved_call.kwargs["run_signature"] == "await branch_a_extract(page, context)"
        assert preserved_call.kwargs["workflow_run_id"] == "wr_run1"

    @pytest.mark.asyncio
    async def test_cached_blocks_without_run_signature_are_not_preserved(self) -> None:
        """Cached blocks without a run_signature should NOT be preserved."""
        cached_blocks = {
            "incomplete_block": ScriptBlockSource(
                label="incomplete_block",
                code="async def incomplete_block(): pass\n",
                run_signature=None,  # No run_signature
                workflow_run_id="wr_old",
                workflow_run_block_id="wfrb_old",
                input_fields=None,
            ),
        }

        # Empty transform output: nothing executed this run.
        blocks: list = []
        actions_by_task: dict = {}
        workflow = {
            "workflow_id": "wf_test",
            "title": "Test",
            "workflow_definition": {"parameters": []},
        }

        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow=workflow,
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )

        # Incomplete block should NOT appear in the output
        assert "incomplete_block" not in result

        # create_or_update_script_block should NOT be called for incomplete block
        incomplete_calls = [
            call
            for call in mock_create_block.call_args_list
            if call.kwargs.get("block_label") == "incomplete_block"
        ]
        assert len(incomplete_calls) == 0

    @pytest.mark.asyncio
    async def test_cached_blocks_without_code_are_not_preserved(self) -> None:
        """Cached blocks without code should NOT be preserved."""
        cached_blocks = {
            "empty_block": ScriptBlockSource(
                label="empty_block",
                code="",  # Empty code
                run_signature="await empty_block(page, context)",
                workflow_run_id="wr_old",
                workflow_run_block_id="wfrb_old",
                input_fields=None,
            ),
        }

        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            # Return value unused: only the DB-call side effects are asserted below.
            await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=[],
                actions_by_task={},
                cached_blocks=cached_blocks,
                updated_block_labels={"__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )

        # Empty block should NOT appear
        empty_calls = [
            call for call in mock_create_block.call_args_list if call.kwargs.get("block_label") == "empty_block"
        ]
        assert len(empty_calls) == 0

    @pytest.mark.asyncio
    async def test_already_processed_blocks_are_not_duplicated(self) -> None:
        """
        Blocks that appear in both the transform output AND cached_blocks
        should NOT be duplicated. The transform output processing handles them.
        """
        block_code = (
            "async def shared_block(page: SkyvernPage, context: RunContext) -> None:\n"
            ' await skyvern.click(page, "//button")\n'
        )
        cached_blocks = {
            "shared_block": ScriptBlockSource(
                label="shared_block",
                code=block_code,
                run_signature="await shared_block(page, context)",
                workflow_run_id="wr_run1",
                workflow_run_block_id="wfrb_shared",
                input_fields=None,
            ),
        }

        # Same block also appears in the transform output (it executed this run too)
        blocks = [
            {
                "block_type": "navigation",
                "label": "shared_block",
                "task_id": "task_shared",
                "navigation_goal": "Navigate somewhere",
                "url": "https://example.com",
                "workflow_run_id": "wr_run2",
                "workflow_run_block_id": "wfrb_shared_run2",
            },
        ]

        actions_by_task = {
            "task_shared": [
                {
                    "action_type": "click",
                    "action_id": "action_1",
                    "xpath": "//button",
                    "element_id": "btn",
                    "reasoning": "Click",
                    "intention": "Click",
                    "confidence_float": 0.9,
                    "has_mini_agent": False,
                },
            ],
        }

        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"shared_block", "__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )

        # The block should appear exactly once (from the transform output processing,
        # NOT duplicated by the preservation loop)
        shared_calls = [
            call for call in mock_create_block.call_args_list if call.kwargs.get("block_label") == "shared_block"
        ]
        # Should be called once from the normal task_v1 processing, NOT again from preservation
        assert len(shared_calls) == 1

    @pytest.mark.asyncio
    async def test_multiple_unexecuted_branches_all_preserved(self) -> None:
        """
        When a workflow has 3 conditional branches and only 1 executes,
        cached blocks from the other 2 branches should ALL be preserved.
        """

        def _make_cached_block(label: str) -> ScriptBlockSource:
            # Minimal but complete cached entry: non-empty code AND a run_signature,
            # both required for the preservation loop to keep the block.
            return ScriptBlockSource(
                label=label,
                code=f"async def {label}(page: SkyvernPage, context: RunContext) -> None:\n pass\n",
                run_signature=f"await {label}(page, context)",
                workflow_run_id="wr_old",
                workflow_run_block_id=f"wfrb_{label}",
                input_fields=None,
            )

        cached_blocks = {
            "branch_a_extract": _make_cached_block("branch_a_extract"),
            "branch_b_navigate": _make_cached_block("branch_b_navigate"),
            # branch_c is absent here: it executed this run, so it arrives via `blocks` below
        }

        # Only branch C's block is in the transform output
        blocks = [
            {
                "block_type": "extraction",
                "label": "branch_c_extract",
                "task_id": "task_c",
                "data_extraction_goal": "Extract C data",
                "workflow_run_id": "wr_run3",
                "workflow_run_block_id": "wfrb_c",
            },
        ]

        actions_by_task = {
            "task_c": [
                {
                    "action_type": "extract",
                    "action_id": "action_c1",
                    "xpath": "//div[@class='data']",
                    "element_id": "data",
                    "reasoning": "Extract",
                    "intention": "Extract data",
                    "confidence_float": 0.9,
                    "has_mini_agent": False,
                    "data_extraction_goal": "Extract C data",
                },
            ],
        }

        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=blocks,
                actions_by_task=actions_by_task,
                cached_blocks=cached_blocks,
                updated_block_labels={"branch_c_extract", "__start_block__"},
                script_id="script_123",
                script_revision_id="rev_123",
                organization_id="org_123",
            )

        # Both branch A and B should be preserved
        assert "branch_a_extract" in result, "Branch A cached block should be preserved"
        assert "branch_b_navigate" in result, "Branch B cached block should be preserved"
        assert "branch_c_extract" in result, "Branch C (executed) block should be present"

        # Verify DB entries were created for all 3 blocks + __start_block__
        all_labels = {call.kwargs.get("block_label") for call in mock_create_block.call_args_list}
        assert "branch_a_extract" in all_labels
        assert "branch_b_navigate" in all_labels
        assert "branch_c_extract" in all_labels
        assert "__start_block__" in all_labels

    @pytest.mark.asyncio
    async def test_preservation_without_script_context(self) -> None:
        """
        When script_id/script_revision_id/organization_id are not provided,
        cached blocks should still be added to the script output (just no DB calls).
        """
        branch_a_code = "async def branch_a(page: SkyvernPage, context: RunContext) -> None:\n pass\n"
        cached_blocks = {
            "branch_a": ScriptBlockSource(
                label="branch_a",
                code=branch_a_code,
                run_signature="await branch_a(page, context)",
                workflow_run_id="wr_old",
                workflow_run_block_id="wfrb_a",
                input_fields=None,
            ),
        }

        with (
            patch(
                "skyvern.core.script_generations.generate_script.generate_workflow_parameters_schema",
                new_callable=AsyncMock,
                return_value=("", {}),
            ),
            patch(
                "skyvern.core.script_generations.generate_script.create_or_update_script_block",
                new_callable=AsyncMock,
            ) as mock_create_block,
        ):
            result = await generate_workflow_script_python_code(
                file_name="test.py",
                workflow_run_request={"workflow_id": "wpid_test"},
                workflow={
                    "workflow_id": "wf_test",
                    "title": "Test",
                    "workflow_definition": {"parameters": []},
                },
                blocks=[],
                actions_by_task={},
                cached_blocks=cached_blocks,
                updated_block_labels={"__start_block__"},
                # No script context
                script_id=None,
                script_revision_id=None,
                organization_id=None,
            )

        # Code should still be in the output
        assert "branch_a" in result

        # But no DB calls should be made for preserved blocks
        preserved_calls = [
            call for call in mock_create_block.call_args_list if call.kwargs.get("block_label") == "branch_a"
        ]
        assert len(preserved_calls) == 0
|
||||
|
||||
|
||||
class TestRegenerationLoopPrevention:
    """
    End-to-end tests for the regeneration loop prevention (SKY-7815).

    The loop arises with conditional branches A and B:
    1. Run 1 caches branch A, so the script holds A's blocks.
    2. Run 2 executes branch B, triggering a regeneration for B.
    3. The transform output for that regeneration contains only B's blocks,
       so A's cached blocks are dropped from the new script.
    4. Run 3 executes branch A again; A now looks "missing" and triggers yet
       another regeneration, which in turn drops B's blocks — and so on.

    The fix has two cooperating parts:
    1. generate_script_if_needed: missing labels are NOT queued for workflows
       that contain conditionals.
    2. generate_workflow_script_python_code: cached blocks from unexecuted
       branches are carried over into the regenerated script.
    """

    def test_regeneration_loop_scenario_is_prevented(self) -> None:
        """
        Walk the full loop scenario and confirm both parts of the fix work
        together: missing labels alone never trigger regeneration for a
        conditional workflow, and a legitimately-triggered regeneration still
        keeps the cached blocks of branches that did not run.
        """
        # --- Part 1: generate_script_if_needed logic ---
        # The workflow definition declares nav, branch A, and branch B blocks.
        expected_cache_labels = {"nav_block", "branch_a_extract", "branch_b_extract", "__start_block__"}

        # After Run 1, only nav and branch A have been cached.
        labels_cached_so_far = {"nav_block", "branch_a_extract", "__start_block__"}

        still_missing = expected_cache_labels - labels_cached_so_far
        assert still_missing == {"branch_b_extract"}

        workflow_has_conditionals = True
        pending_updates: set[str] = set()

        # For conditional workflows, missing labels are deliberately NOT queued.
        if still_missing and not workflow_has_conditionals:
            pending_updates |= still_missing

        # Missing labels alone must not force a regeneration.
        assert len(pending_updates) == 0

        # --- Part 2: Even if regeneration IS triggered ---
        # Branch B actually executed this run, so it legitimately needs caching.
        pending_updates.add("branch_b_extract")

        # The transform output carries only branch B's block...
        transform_labels = {"nav_block", "branch_b_extract"}

        # ...while the old script's cache still knows about branch A.
        previously_cached = {"nav_block", "branch_a_extract"}

        # The preservation loop picks up everything the transform did not emit.
        carried_over = previously_cached - transform_labels

        assert carried_over == {"branch_a_extract"}, (
            "Branch A's block should be preserved even though it wasn't in the transform output"
        )

        # The regenerated script must contain every block.
        combined = transform_labels | carried_over
        assert combined == {"nav_block", "branch_a_extract", "branch_b_extract"}

    def test_no_regeneration_loop_across_three_runs(self) -> None:
        """
        Replay three consecutive runs and confirm the loop never starts:

        Run 1: branch A runs and is cached.
        Run 2: branch B runs; regeneration fires for B while A is preserved.
        Run 3: branch A runs again; nothing is missing, so no regeneration.
        """
        workflow_has_conditionals = True
        expected_cache_labels = {"nav_block", "branch_a_extract", "branch_b_extract", "__start_block__"}

        # --- Run 1: branch A executes ---
        cache_after_run1 = {"nav_block", "branch_a_extract", "__start_block__"}

        # --- Run 2: branch B executes ---
        missing_after_run1 = expected_cache_labels - cache_after_run1
        assert missing_after_run1 == {"branch_b_extract"}

        run2_updates: set[str] = set()
        # Missing labels are never queued when conditionals are present.
        if missing_after_run1 and not workflow_has_conditionals:
            run2_updates |= missing_after_run1

        # Branch B is queued only because it actually executed.
        run2_updates.add("branch_b_extract")

        # Regeneration runs, but the preservation loop keeps branch A.
        run2_transform_labels = {"nav_block", "branch_b_extract"}
        run2_carried_over = {"branch_a_extract"}  # From cache, not in transform

        cache_after_run2 = run2_transform_labels | run2_carried_over | {"__start_block__"}
        assert cache_after_run2 == {"nav_block", "branch_a_extract", "branch_b_extract", "__start_block__"}

        # --- Run 3: branch A executes again ---
        missing_after_run2 = expected_cache_labels - cache_after_run2
        assert len(missing_after_run2) == 0, "No missing blocks after Run 2 because branch A was preserved"

        run3_updates: set[str] = set()
        if missing_after_run2 and not workflow_has_conditionals:
            run3_updates |= missing_after_run2

        # Branch A already has cached code, so execution tracking never adds it
        # (only blocks WITHOUT cached code get queued).
        assert bool(run3_updates) is False, "No regeneration needed on Run 3 - the loop is broken"
|
||||
75
tests/unit/test_custom_credential_client.py
Normal file
75
tests/unit/test_custom_credential_client.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.api.custom_credential_client import CustomCredentialAPIClient
|
||||
from skyvern.forge.sdk.schemas.credentials import CredentialType, SecretCredential
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> CustomCredentialAPIClient:
    """Build a client pointed at a dummy custom-credential endpoint with a dummy token."""
    api_client = CustomCredentialAPIClient(
        api_base_url="https://custom.example.com",
        api_token="token-123",
    )
    return api_client
|
||||
|
||||
|
||||
def test_credential_to_api_payload_with_label(client: CustomCredentialAPIClient) -> None:
    """A labeled secret credential serializes to a payload that includes the label."""
    secret = SecretCredential(secret_value="sk-secret", secret_label="api-key")

    result = client._credential_to_api_payload(secret)

    expected = {
        "type": "secret",
        "secret_value": "sk-secret",
        "secret_label": "api-key",
    }
    assert result == expected
|
||||
|
||||
|
||||
def test_credential_to_api_payload_without_label(client: CustomCredentialAPIClient) -> None:
    """An unlabeled secret credential serializes with no secret_label key at all."""
    secret = SecretCredential(secret_value="sk-secret-no-label")

    result = client._credential_to_api_payload(secret)

    assert result == {"type": "secret", "secret_value": "sk-secret-no-label"}
|
||||
|
||||
|
||||
def test_api_response_to_credential_secret_with_label(client: CustomCredentialAPIClient) -> None:
    """A secret API response carrying a label converts into a fully populated credential item."""
    api_payload = {
        "type": "secret",
        "secret_value": "shhh",
        "secret_label": "prod-api",
    }

    item = client._api_response_to_credential(api_payload, name="Prod API", item_id="cred_123")

    # Item-level metadata.
    assert item.item_id == "cred_123"
    assert item.name == "Prod API"
    assert item.credential_type == CredentialType.SECRET

    # Nested credential payload.
    secret = item.credential
    assert isinstance(secret, SecretCredential)
    assert secret.secret_value == "shhh"
    assert secret.secret_label == "prod-api"
|
||||
|
||||
|
||||
def test_api_response_to_credential_secret_without_label(client: CustomCredentialAPIClient) -> None:
    """A secret API response with no label converts cleanly, leaving secret_label unset."""
    api_payload = {
        "type": "secret",
        "secret_value": "token-only",
    }

    item = client._api_response_to_credential(api_payload, name="Token", item_id="cred_456")

    # Item-level metadata.
    assert item.item_id == "cred_456"
    assert item.name == "Token"
    assert item.credential_type == CredentialType.SECRET

    # Nested credential payload: label must default to None.
    secret = item.credential
    assert isinstance(secret, SecretCredential)
    assert secret.secret_value == "token-only"
    assert secret.secret_label is None
|
||||
|
||||
|
||||
def test_api_response_to_credential_secret_missing_required_field(client: CustomCredentialAPIClient) -> None:
    """A secret response lacking secret_value is rejected with a ValueError."""
    malformed = {
        "type": "secret",
        "secret_label": "no-secret-value",
    }

    with pytest.raises(ValueError, match="Missing required secret fields from API"):
        client._api_response_to_credential(malformed, name="Broken Secret", item_id="cred_789")
|
||||
391
tests/unit/test_download_file_action_handler.py
Normal file
391
tests/unit/test_download_file_action_handler.py
Normal file
@@ -0,0 +1,391 @@
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.models import StepStatus
|
||||
from skyvern.webeye.actions.actions import DownloadFileAction
|
||||
from skyvern.webeye.actions.handler import handle_download_file_action
|
||||
from skyvern.webeye.actions.responses import ActionFailure, ActionSuccess
|
||||
from skyvern.webeye.scraper.scraped_page import ScrapedPage
|
||||
from tests.unit.helpers import make_organization, make_step, make_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_with_byte_data() -> None:
    """Test that when byte data is provided, the file should be saved directly"""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Create mock objects
    page = MagicMock()
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    # Create test byte data
    test_bytes = b"test file content"
    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=test_bytes,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    # Mock initialize_download_dir to return a temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=temp_dir):
            result = await handle_download_file_action(action, page, scraped_page, task, step)

        # NOTE(review): the file-existence assertions below must stay inside the
        # TemporaryDirectory context, otherwise temp_dir is already deleted.

        # Verify result (download_triggered is set by outer handle action flow when in context)
        assert len(result) == 1
        assert isinstance(result[0], ActionSuccess)

        # Verify file was created
        expected_file_path = os.path.join(temp_dir, "test_file.txt")
        assert os.path.exists(expected_file_path)

        # Verify file content
        with open(expected_file_path, "rb") as f:
            assert f.read() == test_bytes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_with_download_url() -> None:
    """When download_url is set, the handler navigates via page.goto and reports ActionSuccess."""
    timestamp = datetime.now(UTC)
    organization = make_organization(timestamp)
    task = make_task(timestamp, organization)
    step = make_step(timestamp, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Browser page double whose goto() we inspect after the call.
    page = MagicMock()
    page.goto = AsyncMock(return_value=None)
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    target_url = "https://example.com/file.pdf"
    action = DownloadFileAction(
        file_name="downloaded_file.pdf",
        download_url=target_url,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        results = await handle_download_file_action(action, page, scraped_page, task, step)

    # The handler downloads via browser navigation, so goto() must receive the URL.
    page.goto.assert_called_once()
    assert page.goto.call_args[0][0] == target_url

    assert len(results) == 1
    assert isinstance(results[0], ActionSuccess)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_with_download_url_same_filename() -> None:
    """
    When download_url is provided and file_name matches the URL's basename,
    page.goto is called with the URL and the handler returns ActionSuccess.

    Fix: the original test passed file_name="same_name.pdf", which does NOT
    match the URL basename ("file.pdf"), so the "same filename" scenario the
    test name promises was never actually exercised.
    """
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Create mock objects
    page = MagicMock()
    page.goto = AsyncMock(return_value=None)
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    # file_name deliberately equals the basename of download_url.
    action = DownloadFileAction(
        file_name="file.pdf",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        result = await handle_download_file_action(action, page, scraped_page, task, step)

    # The download happens via browser navigation to the URL.
    page.goto.assert_called_once()
    assert page.goto.call_args[0][0] == "https://example.com/file.pdf"

    assert len(result) == 1
    assert isinstance(result[0], ActionSuccess)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_without_byte_or_url() -> None:
    """With neither a byte payload nor a download_url, the handler still returns ActionSuccess."""
    timestamp = datetime.now(UTC)
    organization = make_organization(timestamp)
    task = make_task(timestamp, organization)
    step = make_step(timestamp, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Minimal browser doubles; nothing should be downloaded or navigated.
    page = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=None,
        download_url=None,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with tempfile.TemporaryDirectory() as download_dir:
        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=download_dir):
            results = await handle_download_file_action(action, page, scraped_page, task, step)

    # download_triggered is managed by the outer handle-action flow, so only
    # the result list is asserted here.
    assert len(results) == 1
    assert isinstance(results[0], ActionSuccess)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_with_byte_priority() -> None:
    """Test that when both byte and download_url are provided, byte data should take priority"""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Create mock objects
    page = MagicMock()
    browser_state = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=browser_state,
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    # Create test byte data
    test_bytes = b"byte data content"
    # Both byte and download_url are supplied; byte must win.
    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=test_bytes,
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    page.goto = AsyncMock(return_value=None)

    with tempfile.TemporaryDirectory() as temp_dir:
        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=temp_dir):
            result = await handle_download_file_action(action, page, scraped_page, task, step)

        # NOTE(review): the file assertions below must stay inside the
        # TemporaryDirectory context, otherwise temp_dir is already deleted.

        # Byte data takes priority: page.goto should not be called
        page.goto.assert_not_called()

        assert len(result) == 1
        assert isinstance(result[0], ActionSuccess)

        expected_file_path = os.path.join(temp_dir, "test_file.txt")
        assert os.path.exists(expected_file_path)
        with open(expected_file_path, "rb") as f:
            assert f.read() == test_bytes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_with_file_name_empty() -> None:
    """An empty file_name makes the handler fall back to a generated UUID filename."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Mock browser page and scraped-page context for the handler.
    page = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    payload = b"test content"
    action = DownloadFileAction(
        file_name="",  # Empty string, handler will use UUID
        byte=payload,
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with tempfile.TemporaryDirectory() as download_dir:
        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=download_dir):
            result = await handle_download_file_action(action, page, scraped_page, task, step)

        # Success even with an empty file name (download_triggered is set by the outer flow).
        assert len(result) == 1
        assert isinstance(result[0], ActionSuccess)

        # Exactly one file should exist, named with a generated UUID, holding the payload.
        entries = os.listdir(download_dir)
        assert len(entries) == 1
        with open(os.path.join(download_dir, entries[0]), "rb") as f:
            assert f.read() == payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_download_url_error() -> None:
    """A failing download_url navigation surfaces as ActionFailure carrying the original exception details."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # page.goto blows up to simulate a failed download navigation.
    page = MagicMock()
    page.goto = AsyncMock(side_effect=Exception("Download failed"))
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    action = DownloadFileAction(
        file_name="test_file.txt",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        result = await handle_download_file_action(action, page, scraped_page, task, step)

    # The navigation failure must propagate into the ActionFailure result.
    assert len(result) == 1
    assert isinstance(result[0], ActionFailure)
    assert result[0].exception_type == "Exception"
    assert result[0].exception_message == "Download failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_file_write_error() -> None:
    """Writing byte data into an unwritable directory yields ActionFailure."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # Mock browser page and scraped-page context for the handler.
    page = MagicMock()
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    action = DownloadFileAction(
        file_name="test_file.txt",
        byte=b"test content",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        # A read-only (r-x) directory makes the file write fail.
        # NOTE(review): assumes the test process is not running as root,
        # where the write would succeed despite the mode — confirm in CI.
        read_only_dir = os.path.join(temp_dir, "readonly")
        os.makedirs(read_only_dir, mode=0o555)

        with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value=read_only_dir):
            result = await handle_download_file_action(action, page, scraped_page, task, step)

        # The write error must be reported as an ActionFailure.
        assert len(result) == 1
        assert isinstance(result[0], ActionFailure)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_download_file_action_download_url_err_aborted_swallowed() -> None:
    """net::ERR_ABORTED from page.goto (the browser's download flow) is swallowed and still counts as success."""
    now = datetime.now(UTC)
    organization = make_organization(now)
    task = make_task(now, organization)
    step = make_step(now, task, step_id="step-1", status=StepStatus.created, order=0, output=None)

    # The browser raises ERR_ABORTED when navigation turns into a file download.
    page = MagicMock()
    page.goto = AsyncMock(side_effect=Exception("net::ERR_ABORTED at https://example.com/file.pdf"))
    scraped_page = ScrapedPage(
        elements=[],
        element_tree=[],
        element_tree_trimmed=[],
        _browser_state=MagicMock(),
        _clean_up_func=AsyncMock(return_value=[]),
        _scrape_exclude=None,
    )

    action = DownloadFileAction(
        file_name="test_file.txt",
        download_url="https://example.com/file.pdf",
        organization_id=task.organization_id,
        task_id=task.task_id,
        step_id=step.step_id,
    )

    with patch("skyvern.webeye.actions.handler.initialize_download_dir", return_value="/tmp"):
        result = await handle_download_file_action(action, page, scraped_page, task, step)

    # ERR_ABORTED is expected during browser-triggered downloads, so the
    # handler treats it as success rather than failure.
    assert len(result) == 1
    assert isinstance(result[0], ActionSuccess)
|
||||
183
tests/unit/test_finally_block_dag.py
Normal file
183
tests/unit/test_finally_block_dag.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for DAG validation when blocks reference the finally block.
|
||||
|
||||
The finally block is excluded from the DAG before validation. Any block whose
|
||||
next_block_label points to the finally block must have that edge nullified so
|
||||
_build_workflow_graph does not raise InvalidWorkflowDefinition for a missing label.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.workflow.exceptions import InvalidWorkflowDefinition
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
BranchCondition,
|
||||
ConditionalBlock,
|
||||
HttpRequestBlock,
|
||||
TaskBlock,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||
|
||||
|
||||
def _make_output_parameter(key: str) -> OutputParameter:
    """Build a minimal OutputParameter fixture whose ids derive from *key*."""
    created = datetime.now(tz=timezone.utc)
    return OutputParameter(
        output_parameter_id=f"op_{key}",
        key=key,
        parameter_type="output",
        workflow_id="wf_test",
        created_at=created,
        modified_at=created,
    )
|
||||
|
||||
|
||||
def _make_task_block(label: str, *, next_block_label: str | None = None) -> TaskBlock:
    """Build a TaskBlock fixture with an optional edge to *next_block_label*."""
    return TaskBlock(
        label=label,
        next_block_label=next_block_label,
        url="https://example.com",
        output_parameter=_make_output_parameter(label),
    )
|
||||
|
||||
|
||||
def _make_http_block(label: str, *, next_block_label: str | None = None) -> HttpRequestBlock:
    """Build an HttpRequestBlock fixture (GET) with an optional edge to *next_block_label*."""
    return HttpRequestBlock(
        label=label,
        next_block_label=next_block_label,
        url="https://example.com",
        method="GET",
        output_parameter=_make_output_parameter(label),
    )
|
||||
|
||||
|
||||
class TestStripFinallyBlockReferences:
    """Tests for WorkflowService._strip_finally_block_references."""

    def test_removes_finally_block_and_nullifies_edge(self):
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2", next_block_label="finally_block")
        finally_block = _make_http_block("finally_block")

        stripped = WorkflowService._strip_finally_block_references(
            [first, second, finally_block],
            "finally_block",
        )

        # The finally block itself is dropped from the DAG.
        assert len(stripped) == 2
        assert "finally_block" not in [blk.label for blk in stripped]
        # block_2's edge to the finally block must have been nullified.
        assert stripped[1].label == "block_2"
        assert stripped[1].next_block_label is None

    def test_conditional_branch_pointing_to_finally_is_nullified(self):
        plain_block = _make_task_block("block_1")
        conditional = ConditionalBlock(
            label="cond_block",
            output_parameter=_make_output_parameter("cond_block"),
            branch_conditions=[
                BranchCondition(next_block_label="block_1", is_default=True),
                BranchCondition(
                    next_block_label="finally_block",
                    criteria={"criteria_type": "jinja2_template", "expression": "{{ true }}"},
                ),
            ],
        )
        finally_block = _make_http_block("finally_block")

        stripped = WorkflowService._strip_finally_block_references(
            [plain_block, conditional, finally_block],
            "finally_block",
        )

        assert len(stripped) == 2
        cond = next(blk for blk in stripped if blk.label == "cond_block")
        # No surviving branch may still point at the removed finally block.
        for branch in cond.branch_conditions:
            assert branch.next_block_label != "finally_block", (
                "Branch pointing to finally_block should have been nullified"
            )

    def test_noop_when_no_finally_block(self):
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2")

        stripped = WorkflowService._strip_finally_block_references(
            [first, second],
            "nonexistent_finally",
        )

        # Nothing matches the finally label, so the blocks pass through untouched.
        assert len(stripped) == 2
        assert stripped[0].next_block_label == "block_2"
|
||||
|
||||
|
||||
class TestBuildWorkflowGraphWithFinallyBlock:
    """Tests that _build_workflow_graph succeeds after stripping finally block references."""

    def test_dag_validation_with_block_pointing_to_finally_block(self):
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2", next_block_label="finally_block")
        finally_block = _make_http_block("finally_block")

        graph_blocks = WorkflowService._strip_finally_block_references(
            [first, second, finally_block],
            "finally_block",
        )

        service = WorkflowService()
        start_label, label_to_block, default_next_map = service._build_workflow_graph(graph_blocks)

        # Graph validates and terminates at block_2 once the finally edge is gone.
        assert start_label == "block_1"
        assert set(label_to_block.keys()) == {"block_1", "block_2"}
        assert default_next_map["block_1"] == "block_2"
        assert default_next_map["block_2"] is None

    def test_dag_validation_with_conditional_block_branch_pointing_to_finally(self):
        plain_block = _make_task_block("block_1")
        conditional = ConditionalBlock(
            label="cond_block",
            output_parameter=_make_output_parameter("cond_block"),
            branch_conditions=[
                BranchCondition(next_block_label="block_1", is_default=True),
                BranchCondition(
                    next_block_label="finally_block",
                    criteria={"criteria_type": "jinja2_template", "expression": "{{ true }}"},
                ),
            ],
        )
        finally_block = _make_http_block("finally_block")

        graph_blocks = WorkflowService._strip_finally_block_references(
            [conditional, plain_block, finally_block],
            "finally_block",
        )

        service = WorkflowService()
        start_label, label_to_block, default_next_map = service._build_workflow_graph(graph_blocks)

        assert start_label == "cond_block"
        assert set(label_to_block.keys()) == {"cond_block", "block_1"}

    def test_dag_validation_without_finally_block(self):
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2")

        service = WorkflowService()
        start_label, label_to_block, default_next_map = service._build_workflow_graph([first, second])

        assert start_label == "block_1"
        assert set(label_to_block.keys()) == {"block_1", "block_2"}
        assert default_next_map["block_1"] == "block_2"

    def test_dag_validation_fails_without_stripping_finally_block(self):
        """Without stripping, a block referencing the removed finally block causes an error."""
        first = _make_task_block("block_1", next_block_label="block_2")
        second = _make_task_block("block_2", next_block_label="finally_block")
        # Manually exclude the finally block but do NOT nullify the edge
        graph_blocks = [first, second]

        service = WorkflowService()
        with pytest.raises(InvalidWorkflowDefinition, match="unknown next_block_label"):
            service._build_workflow_graph(graph_blocks)
|
||||
412
tests/unit/test_forloop_script_generation.py
Normal file
412
tests/unit/test_forloop_script_generation.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Unit tests for ForLoop block support in cached scripts (SKY-7751).
|
||||
|
||||
These tests verify that ForLoop blocks are properly handled during:
|
||||
1. Workflow transformation (transform_workflow_run.py)
|
||||
2. Script generation (generate_script.py)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import libcst as cst
|
||||
import pytest
|
||||
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.generate_script import _build_for_loop_statement
|
||||
from skyvern.core.script_generations.transform_workflow_run import (
|
||||
CodeGenInput,
|
||||
transform_workflow_run_to_code_gen_input,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
|
||||
from skyvern.schemas.workflows import BlockType
|
||||
|
||||
|
||||
class TestForLoopInCacheableBlocks:
    """ForLoop must be one of the block types eligible for script caching."""

    def test_forloop_in_block_types_that_should_be_cached(self) -> None:
        """BLOCK_TYPES_THAT_SHOULD_BE_CACHED must include BlockType.FOR_LOOP."""
        assert BlockType.FOR_LOOP in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
|
||||
|
||||
|
||||
class TestForLoopTransformation:
    """Test the transformation of ForLoop blocks during script generation."""

    def test_forloop_child_blocks_identified_by_parent_id(self) -> None:
        """Child blocks inside a ForLoop are findable via parent_workflow_run_block_id."""
        # Parent ForLoop run block (no parent of its own, no task).
        loop_block = MagicMock()
        loop_block.workflow_run_block_id = "wfrb_forloop_123"
        loop_block.parent_workflow_run_block_id = None
        loop_block.block_type = BlockType.FOR_LOOP
        loop_block.label = "process_urls"
        loop_block.task_id = None

        # Child task block whose parent pointer references the ForLoop.
        child = MagicMock()
        child.workflow_run_block_id = "wfrb_child_456"
        child.parent_workflow_run_block_id = "wfrb_forloop_123"  # Points to ForLoop
        child.block_type = "task"
        child.label = "extract_data"
        child.task_id = "task_789"
        child.status = "completed"
        child.output = {"extracted": "data"}

        everything = [loop_block, child]

        # Filter children by matching parent_workflow_run_block_id.
        children = [blk for blk in everything if blk.parent_workflow_run_block_id == loop_block.workflow_run_block_id]

        assert len(children) == 1
        assert children[0].label == "extract_data"
        assert children[0].task_id == "task_789"

    def test_child_run_blocks_by_label_mapping(self) -> None:
        """Child run blocks can be indexed into a dict keyed by label."""
        extraction_child = MagicMock()
        extraction_child.label = "extract_data"
        extraction_child.block_type = "extraction"
        extraction_child.task_id = "task_1"

        navigation_child = MagicMock()
        navigation_child.label = "navigate_page"
        navigation_child.block_type = "navigation"
        navigation_child.task_id = "task_2"

        # Unlabeled blocks are skipped by the mapping.
        by_label = {blk.label: blk for blk in (extraction_child, navigation_child) if blk.label}

        assert "extract_data" in by_label
        assert "navigate_page" in by_label
        assert by_label["extract_data"].task_id == "task_1"

    def test_forloop_definition_block_has_loop_blocks(self) -> None:
        """A ForLoop definition dict carries its children under 'loop_blocks'."""
        definition = {
            "block_type": BlockType.FOR_LOOP,
            "label": "process_urls",
            "loop_variable_reference": "{{ urls }}",
            "loop_blocks": [
                {
                    "block_type": "extraction",
                    "label": "extract_data",
                    "data_extraction_goal": "Extract page content",
                },
                {
                    "block_type": "navigation",
                    "label": "navigate_next",
                    "navigation_goal": "Go to next page",
                },
            ],
        }

        children = definition.get("loop_blocks", [])

        assert len(children) == 2
        assert children[0]["label"] == "extract_data"
        assert children[1]["label"] == "navigate_next"
|
||||
|
||||
|
||||
class TestForLoopScriptGeneration:
    """Test script code generation for ForLoop blocks."""

    def test_build_for_loop_statement_signature(self) -> None:
        """_build_for_loop_statement accepts a ForLoop block dict and returns a CST For node."""
        loop_definition = {
            "block_type": "for_loop",
            "label": "process_items",
            "loop_variable_reference": "{{ items }}",
            "loop_blocks": [
                {
                    "block_type": "extraction",
                    "label": "extract_item",
                    "data_extraction_goal": "Extract item details",
                },
            ],
        }

        # Building the statement must not raise.
        statement = _build_for_loop_statement("process_items", loop_definition)

        # A libcst For node exposes target, iter, and body attributes.
        assert statement is not None
        assert hasattr(statement, "target")  # For loop has a target
        assert hasattr(statement, "iter")  # For loop has an iterator
        assert hasattr(statement, "body")  # For loop has a body
|
||||
|
||||
|
||||
class TestForLoopChildBlockActions:
    """Test that actions from child blocks inside ForLoop are collected."""

    def test_task_block_in_forloop_should_collect_actions(self) -> None:
        """Task-type children of a ForLoop qualify for action collection."""
        # Mirrors the logic added in transform_workflow_run.py.
        child = MagicMock()
        child.block_type = "task"
        child.task_id = "task_123"
        child.label = "search_item"

        # Membership in SCRIPT_TASK_BLOCKS gates action collection...
        assert child.block_type in SCRIPT_TASK_BLOCKS
        # ...and a task_id is required to look the actions up.
        assert child.task_id is not None

    def test_non_task_block_in_forloop_does_not_collect_actions(self) -> None:
        """Non-task children (e.g. goto_url) are excluded from action collection."""
        child = MagicMock()
        child.block_type = "goto_url"
        child.task_id = None
        child.label = "go_to_url"

        # goto_url is not a script task block, so no actions are collected for it.
        assert child.block_type not in SCRIPT_TASK_BLOCKS
|
||||
|
||||
|
||||
class TestForLoopActionsHydration:
    """Test that actions from ForLoop child blocks are properly hydrated."""

    def test_actions_by_task_includes_forloop_child_actions(self) -> None:
        """Actions recorded for a ForLoop child task are retrievable from actions_by_task."""
        actions_by_task: dict[str, list[dict[str, Any]]] = {}

        # Simulate registering the actions of a child block running inside a ForLoop.
        child_task_id = "task_in_forloop_123"
        actions_by_task[child_task_id] = [
            {
                "action_type": "input_text",
                "action_id": "action_1",
                "text": "search query",
                "xpath": "//input[@id='search']",
            },
            {
                "action_type": "click",
                "action_id": "action_2",
                "xpath": "//button[@type='submit']",
            },
        ]

        # The registered actions must come back intact, in order.
        assert child_task_id in actions_by_task
        assert len(actions_by_task[child_task_id]) == 2
        assert actions_by_task[child_task_id][0]["action_type"] == "input_text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transform_forloop_block_integration() -> None:
    """
    Integration test for ForLoop block transformation.

    NOTE(review): despite the name, this test does not call the transformation
    or mock the database — it only constructs a CodeGenInput containing a
    ForLoop block and asserts on its shape. Consider renaming or extending it.
    """
    # Create a mock CodeGenInput with ForLoop block
    mock_input = CodeGenInput(
        file_name="test_workflow.py",
        workflow_run={"workflow_id": "wpid_123"},
        workflow={"workflow_definition": {"blocks": []}},
        workflow_blocks=[
            # A single ForLoop block whose loop_blocks carry one completed
            # extraction child with its task_id and output attached.
            {
                "block_type": "for_loop",
                "label": "process_urls",
                "loop_variable_reference": "{{ urls }}",
                "workflow_run_id": "wr_123",
                "workflow_run_block_id": "wfrb_456",
                "loop_blocks": [
                    {
                        "block_type": "extraction",
                        "label": "extract_data",
                        "data_extraction_goal": "Get page content",
                        "task_id": "task_789",
                        "status": "completed",
                        "output": {"content": "extracted data"},
                    }
                ],
            }
        ],
        # Actions keyed by the child task's id, as the transformation would produce.
        actions_by_task={
            "task_789": [
                {
                    "action_type": "extract",
                    "action_id": "action_123",
                    "xpath": "//div[@id='content']",
                }
            ]
        },
        task_v2_child_blocks={},
    )

    # Verify the structure: one ForLoop block, one child with task data,
    # and the child's actions present under its task_id.
    assert len(mock_input.workflow_blocks) == 1
    assert mock_input.workflow_blocks[0]["block_type"] == "for_loop"
    assert len(mock_input.workflow_blocks[0]["loop_blocks"]) == 1
    assert mock_input.workflow_blocks[0]["loop_blocks"][0]["task_id"] == "task_789"
    assert "task_789" in mock_input.actions_by_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transform_forloop_block_with_mocked_db() -> None:
    """
    Full integration test for ForLoop block transformation with mocked database.

    This test verifies the actual transformation logic in transform_workflow_run.py
    correctly processes ForLoop blocks and their child blocks: the ForLoop
    definition ends up in workflow_blocks, its child gets the task data merged
    into loop_blocks, and the child's actions land in actions_by_task.
    """
    # Mock workflow run response
    mock_workflow_run_resp = MagicMock()
    mock_workflow_run_resp.run_request = MagicMock()
    mock_workflow_run_resp.run_request.workflow_id = "wpid_test_123"
    mock_workflow_run_resp.run_request.model_dump = MagicMock(
        return_value={"workflow_id": "wpid_test_123", "parameters": {}}
    )

    # Mock workflow with ForLoop block definition
    mock_forloop_definition = MagicMock()
    mock_forloop_definition.block_type = BlockType.FOR_LOOP
    mock_forloop_definition.label = "process_urls"
    mock_forloop_definition.loop_variable_reference = "{{ urls }}"
    mock_forloop_definition.model_dump = MagicMock(
        return_value={
            "block_type": "for_loop",
            "label": "process_urls",
            "loop_variable_reference": "{{ urls }}",
            "loop_blocks": [
                {
                    "block_type": "extraction",
                    "label": "extract_data",
                    "data_extraction_goal": "Get page content",
                }
            ],
        }
    )

    mock_workflow = MagicMock()
    mock_workflow.model_dump = MagicMock(return_value={"workflow_id": "wf_123", "workflow_definition": {"blocks": []}})
    mock_workflow.workflow_definition.blocks = [mock_forloop_definition]

    # Mock workflow run blocks - ForLoop parent and extraction child
    # (created_at orders them; the child points back via parent_workflow_run_block_id)
    mock_forloop_run_block = MagicMock()
    mock_forloop_run_block.workflow_run_block_id = "wfrb_forloop_123"
    mock_forloop_run_block.parent_workflow_run_block_id = None
    mock_forloop_run_block.block_type = BlockType.FOR_LOOP
    mock_forloop_run_block.label = "process_urls"
    mock_forloop_run_block.task_id = None
    mock_forloop_run_block.created_at = 1

    mock_child_run_block = MagicMock()
    mock_child_run_block.workflow_run_block_id = "wfrb_child_456"
    mock_child_run_block.parent_workflow_run_block_id = "wfrb_forloop_123"
    mock_child_run_block.block_type = "extraction"
    mock_child_run_block.label = "extract_data"
    mock_child_run_block.task_id = "task_extraction_789"
    mock_child_run_block.status = "completed"
    mock_child_run_block.output = {"extracted": "data"}
    mock_child_run_block.created_at = 2

    # Mock task for the child block
    mock_task = MagicMock()
    mock_task.model_dump = MagicMock(
        return_value={
            "task_id": "task_extraction_789",
            "navigation_goal": "Extract page content",
        }
    )

    # Mock action for the task
    mock_action = MagicMock()
    mock_action.action_type = "extract"
    mock_action.model_dump = MagicMock(
        return_value={
            "action_type": "extract",
            "action_id": "action_123",
        }
    )
    mock_action.get_xpath = MagicMock(return_value="//div[@id='content']")
    mock_action.has_mini_agent = False

    # Set up patches
    with (
        patch("skyvern.services.workflow_service.get_workflow_run_response", new_callable=AsyncMock) as mock_get_wfr,
        patch("skyvern.core.script_generations.transform_workflow_run.app") as mock_app,
    ):
        mock_get_wfr.return_value = mock_workflow_run_resp
        mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
        mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(
            return_value=[
                mock_forloop_run_block,
                mock_child_run_block,
            ]
        )
        # B1 optimization: Mock batch methods instead of individual queries
        mock_task.task_id = "task_extraction_789"
        mock_action.task_id = "task_extraction_789"
        mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
        mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[mock_action])

        # Call the transformation
        result = await transform_workflow_run_to_code_gen_input(
            workflow_run_id="wr_test_123",
            organization_id="org_test_123",
        )

    # Verify ForLoop block is included
    assert len(result.workflow_blocks) == 1
    forloop_block = result.workflow_blocks[0]
    assert forloop_block["block_type"] == "for_loop"
    assert forloop_block["label"] == "process_urls"

    # Verify loop_blocks contain child block with task data
    loop_blocks = forloop_block.get("loop_blocks", [])
    assert len(loop_blocks) == 1
    child_block = loop_blocks[0]
    assert child_block["label"] == "extract_data"
    assert child_block.get("task_id") == "task_extraction_789"

    # Verify actions were collected for the child task
    assert "task_extraction_789" in result.actions_by_task
    actions = result.actions_by_task["task_extraction_789"]
    assert len(actions) == 1
    assert actions[0]["action_type"] == "extract"
|
||||
|
||||
|
||||
class TestForLoopScriptExecution:
    """Test that generated ForLoop scripts can be executed."""

    def test_forloop_generates_async_for_statement(self) -> None:
        """The generated loop is an async for (cst.For with the asynchronous keyword)."""
        loop_definition = {
            "block_type": "for_loop",
            "label": "iterate_items",
            "loop_variable_reference": "{{ items_list }}",
            "complete_if_empty": True,
            "loop_blocks": [],
        }

        statement = _build_for_loop_statement("iterate_items", loop_definition)

        # Must be an async for loop.
        assert isinstance(statement, cst.For)
        assert statement.asynchronous is not None  # Has asynchronous keyword

    def test_forloop_generates_skyvern_loop_call(self) -> None:
        """The generated loop iterates over a skyvern.loop() call."""
        loop_definition = {
            "block_type": "for_loop",
            "label": "iterate_items",
            "loop_variable_reference": "{{ items_list }}",
            "loop_blocks": [],
        }

        statement = _build_for_loop_statement("iterate_items", loop_definition)

        # The iterator expression is a call whose callee attribute is "loop".
        assert isinstance(statement.iter, cst.Call)
        callee = statement.iter.func
        assert isinstance(callee, cst.Attribute)
        assert callee.attr.value == "loop"
|
||||
478
tests/unit/test_http_block_raw_filter.py
Normal file
478
tests/unit/test_http_block_raw_filter.py
Normal file
@@ -0,0 +1,478 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.exceptions import FailedToFormatJinjaStyleParameter
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
_JSON_TYPE_MARKER,
|
||||
HttpRequestBlock,
|
||||
_json_type_filter,
|
||||
jinja_sandbox_env,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
|
||||
|
||||
|
||||
class TestJsonTypeFilter:
|
||||
@pytest.mark.parametrize(
|
||||
"value",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
42,
|
||||
19.99,
|
||||
None,
|
||||
[1, 2, 3],
|
||||
{"a": 1, "b": "hello"},
|
||||
"hello",
|
||||
[],
|
||||
{},
|
||||
],
|
||||
)
|
||||
def test_filter_wraps_with_marker(self, value: object) -> None:
|
||||
result = _json_type_filter(value)
|
||||
assert result.startswith(_JSON_TYPE_MARKER)
|
||||
assert result.endswith(_JSON_TYPE_MARKER)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
42,
|
||||
19.99,
|
||||
None,
|
||||
[1, 2, 3],
|
||||
{"a": 1, "b": "hello"},
|
||||
"hello",
|
||||
],
|
||||
)
|
||||
def test_filter_json_is_parseable(self, value: object) -> None:
|
||||
result = _json_type_filter(value)
|
||||
json_part = result[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == value
|
||||
|
||||
def test_filter_handles_datetime(self) -> None:
|
||||
now = datetime(2024, 1, 15, 12, 30, 45)
|
||||
result = _json_type_filter(now)
|
||||
json_part = result[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == "2024-01-15 12:30:45"
|
||||
|
||||
def test_filter_handles_enum(self) -> None:
|
||||
class Status(StrEnum):
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
|
||||
result = _json_type_filter(Status.completed)
|
||||
json_part = result[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == "completed"
|
||||
|
||||
def test_filter_handles_nested_datetime_in_dict(self) -> None:
|
||||
data = {
|
||||
"status": "completed",
|
||||
"downloaded_files": [
|
||||
{"url": "https://example.com/file.pdf", "modified_at": datetime(2024, 1, 15, 12, 30, 45)}
|
||||
],
|
||||
}
|
||||
result = _json_type_filter(data)
|
||||
json_part = result[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed["downloaded_files"][0]["modified_at"] == "2024-01-15 12:30:45"
|
||||
|
||||
|
||||
class TestJinjaJsonFilterRegistration:
|
||||
def test_json_filter_is_registered(self) -> None:
|
||||
assert "json" in jinja_sandbox_env.filters
|
||||
assert jinja_sandbox_env.filters["json"] == _json_type_filter
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"template,context,expected_json",
|
||||
[
|
||||
("{{ flag | json }}", {"flag": True}, True),
|
||||
("{{ flag | json }}", {"flag": False}, False),
|
||||
("{{ count | json }}", {"count": 42}, 42),
|
||||
("{{ price | json }}", {"price": 19.99}, 19.99),
|
||||
("{{ null_val | json }}", {"null_val": None}, None),
|
||||
("{{ items | json }}", {"items": [1, 2, 3]}, [1, 2, 3]),
|
||||
("{{ data | json }}", {"data": {"a": 1}}, {"a": 1}),
|
||||
("{{ str_val | json }}", {"str_val": "hello"}, "hello"),
|
||||
],
|
||||
)
|
||||
def test_jinja_renders_json_filter(self, template: str, context: dict, expected_json: object) -> None:
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
# The output should have markers
|
||||
assert rendered.startswith(_JSON_TYPE_MARKER)
|
||||
assert rendered.endswith(_JSON_TYPE_MARKER)
|
||||
# Extract and parse JSON
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == expected_json
|
||||
|
||||
def test_json_filter_with_nested_access(self) -> None:
|
||||
template = "{{ data.nested.value | json }}"
|
||||
context = {"data": {"nested": {"value": True}}}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) is True
|
||||
|
||||
def test_json_filter_with_list_index(self) -> None:
|
||||
template = "{{ items[1] | json }}"
|
||||
context = {"items": [10, 20, 30]}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) == 20
|
||||
|
||||
def test_json_filter_chains_with_default(self) -> None:
|
||||
template = "{{ missing_val | default(false) | json }}"
|
||||
context = {} # missing_val not defined
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) is False
|
||||
|
||||
def test_json_filter_chains_with_default_list(self) -> None:
|
||||
template = "{{ items | default([]) | json }}"
|
||||
context = {}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) == []
|
||||
|
||||
|
||||
class TestMarkerDetection:
|
||||
def test_marker_detection_simple(self) -> None:
|
||||
wrapped = _json_type_filter(True)
|
||||
assert wrapped.startswith(_JSON_TYPE_MARKER)
|
||||
assert wrapped.endswith(_JSON_TYPE_MARKER)
|
||||
# Simulate the detection logic from _render_templates_in_json
|
||||
json_str = wrapped[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_str) is True
|
||||
|
||||
def test_marker_detection_complex_object(self) -> None:
|
||||
complex_obj = {"users": [{"name": "Alice", "active": True}], "count": 1}
|
||||
wrapped = _json_type_filter(complex_obj)
|
||||
json_str = wrapped[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_str) == complex_obj
|
||||
|
||||
def test_plain_string_not_detected_as_marker(self) -> None:
|
||||
plain_string = "hello world"
|
||||
assert not plain_string.startswith(_JSON_TYPE_MARKER)
|
||||
assert not plain_string.endswith(_JSON_TYPE_MARKER)
|
||||
|
||||
def test_partial_marker_not_detected(self) -> None:
|
||||
start_only = f"{_JSON_TYPE_MARKER}true"
|
||||
end_only = f"true{_JSON_TYPE_MARKER}"
|
||||
assert not (start_only.startswith(_JSON_TYPE_MARKER) and start_only.endswith(_JSON_TYPE_MARKER))
|
||||
assert not (end_only.startswith(_JSON_TYPE_MARKER) and end_only.endswith(_JSON_TYPE_MARKER))
|
||||
|
||||
|
||||
class TestWithoutJsonFilter:
|
||||
def test_standard_template_renders_string(self) -> None:
|
||||
template = "{{ flag }}"
|
||||
context = {"flag": True}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
assert rendered == "True" # Python str(True)
|
||||
assert not rendered.startswith(_JSON_TYPE_MARKER)
|
||||
|
||||
def test_standard_template_integer(self) -> None:
|
||||
template = "{{ count }}"
|
||||
context = {"count": 42}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
assert rendered == "42"
|
||||
assert not rendered.startswith(_JSON_TYPE_MARKER)
|
||||
|
||||
|
||||
class TestEdgeCasesAndLimitations:
|
||||
def test_mixed_template_jinja_output_contains_marker(self) -> None:
|
||||
"""Jinja renders mixed templates with markers embedded in output.
|
||||
|
||||
This verifies what Jinja produces. The actual error handling happens
|
||||
in _render_templates_in_json (tested separately below).
|
||||
"""
|
||||
template = "prefix_{{ flag | json }}_suffix"
|
||||
context = {"flag": True}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
# The output contains the marker because it's mixed with prefix/suffix
|
||||
assert _JSON_TYPE_MARKER in rendered
|
||||
assert rendered.startswith("prefix_")
|
||||
assert rendered.endswith("_suffix")
|
||||
# The marker detection logic (startswith AND endswith) will NOT match
|
||||
assert not (rendered.startswith(_JSON_TYPE_MARKER) and rendered.endswith(_JSON_TYPE_MARKER))
|
||||
|
||||
def test_deeply_nested_structure(self) -> None:
|
||||
template = "{{ data | json }}"
|
||||
context = {
|
||||
"data": {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"items": [1, 2, 3],
|
||||
"active": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == context["data"]
|
||||
assert parsed["level1"]["level2"]["level3"]["active"] is True
|
||||
|
||||
def test_special_characters_in_string_value(self) -> None:
|
||||
template = "{{ text | json }}"
|
||||
context = {"text": 'Hello "World"\nNew line\ttab'}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == context["text"]
|
||||
|
||||
def test_unicode_characters(self) -> None:
|
||||
template = "{{ text | json }}"
|
||||
context = {"text": "Hello \u4e16\u754c \U0001f600"} # Chinese + emoji
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
assert parsed == context["text"]
|
||||
|
||||
def test_empty_values(self) -> None:
|
||||
# Empty string
|
||||
template = "{{ text | json }}"
|
||||
context: dict[str, object] = {"text": ""}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) == ""
|
||||
|
||||
# Empty list
|
||||
context = {"text": []}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) == []
|
||||
|
||||
# Empty dict
|
||||
context = {"text": {}}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
assert json.loads(json_part) == {}
|
||||
|
||||
|
||||
class TestEmbeddedMarkerErrorHandling:
|
||||
"""Tests that embedded markers (| json mixed with other text) raise clear errors."""
|
||||
|
||||
def test_embedded_marker_raises_error(self) -> None:
|
||||
"""Using | json with prefix/suffix text should raise FailedToFormatJinjaStyleParameter."""
|
||||
now = datetime.now(timezone.utc)
|
||||
output_param = OutputParameter(
|
||||
parameter_type=ParameterType.OUTPUT,
|
||||
key="http_output",
|
||||
description=None,
|
||||
output_parameter_id="output-1",
|
||||
workflow_id="workflow-1",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
block = HttpRequestBlock(
|
||||
label="test-block",
|
||||
url="https://example.com",
|
||||
method="POST",
|
||||
body={"id": "prefix-{{ val | json }}"},
|
||||
output_parameter=output_param,
|
||||
)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.values = {"val": 123}
|
||||
mock_context.secrets = {}
|
||||
mock_context.include_secrets_in_templates = False
|
||||
mock_context.get_block_metadata = MagicMock(return_value={})
|
||||
|
||||
with pytest.raises(FailedToFormatJinjaStyleParameter) as exc_info:
|
||||
block.format_potential_template_parameters(mock_context)
|
||||
|
||||
assert "can only be used for complete value replacement" in str(exc_info.value)
|
||||
|
||||
def test_valid_json_filter_does_not_raise(self) -> None:
|
||||
"""Using | json for complete value replacement should work without error."""
|
||||
now = datetime.now(timezone.utc)
|
||||
output_param = OutputParameter(
|
||||
parameter_type=ParameterType.OUTPUT,
|
||||
key="http_output",
|
||||
description=None,
|
||||
output_parameter_id="output-1",
|
||||
workflow_id="workflow-1",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
block = HttpRequestBlock(
|
||||
label="test-block",
|
||||
url="https://example.com",
|
||||
method="POST",
|
||||
body={"enabled": "{{ flag | json }}", "count": "{{ num | json }}"},
|
||||
output_parameter=output_param,
|
||||
)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.values = {"flag": True, "num": 42}
|
||||
mock_context.secrets = {}
|
||||
mock_context.include_secrets_in_templates = False
|
||||
mock_context.get_block_metadata = MagicMock(return_value={})
|
||||
|
||||
# Should not raise
|
||||
block.format_potential_template_parameters(mock_context)
|
||||
|
||||
# Verify the values were correctly converted to native types
|
||||
assert block.body == {"enabled": True, "count": 42}
|
||||
|
||||
|
||||
class TestWorkflowRunSummary:
|
||||
"""Tests for the workflow_run_summary template variable."""
|
||||
|
||||
def test_build_workflow_run_summary_empty_outputs(self) -> None:
|
||||
"""Test summary with no block outputs."""
|
||||
context = MagicMock()
|
||||
context.workflow_run_id = "wr_123"
|
||||
context.workflow_run_outputs = {}
|
||||
|
||||
# Create a real context to test the method
|
||||
summary = WorkflowRunContext.build_workflow_run_summary(context)
|
||||
|
||||
assert summary["workflow_run_id"] == "wr_123"
|
||||
assert summary["status"] is None
|
||||
assert summary["output"] == {"extracted_information": {}}
|
||||
assert summary["downloaded_files"] == []
|
||||
assert summary["errors"] == []
|
||||
assert summary["failure_reason"] is None
|
||||
|
||||
def test_build_workflow_run_summary_merges_extracted_information(self) -> None:
|
||||
"""Test that output.extracted_information is merged from all blocks."""
|
||||
context = MagicMock()
|
||||
context.workflow_run_id = "wr_456"
|
||||
context.workflow_run_outputs = {
|
||||
"NavigationBlock": {
|
||||
"status": "completed",
|
||||
"extracted_information": {"title": "Example Page"},
|
||||
"errors": [],
|
||||
"downloaded_files": [],
|
||||
},
|
||||
"ExtractionBlock": {
|
||||
"status": "completed",
|
||||
"extracted_information": {"documents": [{"name": "doc1.pdf"}]},
|
||||
"errors": [],
|
||||
"downloaded_files": [],
|
||||
},
|
||||
}
|
||||
|
||||
summary = WorkflowRunContext.build_workflow_run_summary(context)
|
||||
|
||||
# extracted_information is merged from all blocks (flattened, not keyed by block label)
|
||||
assert summary["output"]["extracted_information"] == {
|
||||
"title": "Example Page",
|
||||
"documents": [{"name": "doc1.pdf"}],
|
||||
}
|
||||
|
||||
def test_build_workflow_run_summary_aggregates_downloaded_files(self) -> None:
|
||||
"""Test that downloaded_files are aggregated from all blocks."""
|
||||
context = MagicMock()
|
||||
context.workflow_run_id = "wr_789"
|
||||
context.workflow_run_outputs = {
|
||||
"Block1": {
|
||||
"status": "completed",
|
||||
"downloaded_files": [{"url": "file1.pdf"}],
|
||||
"errors": [],
|
||||
},
|
||||
"Block2": {
|
||||
"status": "completed",
|
||||
"downloaded_files": [{"url": "file2.pdf"}, {"url": "file3.pdf"}],
|
||||
"errors": [],
|
||||
},
|
||||
}
|
||||
|
||||
summary = WorkflowRunContext.build_workflow_run_summary(context)
|
||||
|
||||
assert len(summary["downloaded_files"]) == 3
|
||||
assert {"url": "file1.pdf"} in summary["downloaded_files"]
|
||||
assert {"url": "file2.pdf"} in summary["downloaded_files"]
|
||||
assert {"url": "file3.pdf"} in summary["downloaded_files"]
|
||||
|
||||
def test_build_workflow_run_summary_aggregates_errors(self) -> None:
|
||||
"""Test that errors are aggregated from all blocks."""
|
||||
context = MagicMock()
|
||||
context.workflow_run_id = "wr_errors"
|
||||
context.workflow_run_outputs = {
|
||||
"Block1": {
|
||||
"status": "failed",
|
||||
"errors": [{"message": "Error 1"}],
|
||||
"failure_reason": "Block 1 failed",
|
||||
},
|
||||
"Block2": {
|
||||
"status": "completed",
|
||||
"errors": [{"message": "Warning"}],
|
||||
},
|
||||
}
|
||||
|
||||
summary = WorkflowRunContext.build_workflow_run_summary(context)
|
||||
|
||||
assert len(summary["errors"]) == 2
|
||||
assert {"message": "Error 1"} in summary["errors"]
|
||||
assert {"message": "Warning"} in summary["errors"]
|
||||
assert summary["failure_reason"] == "Block 1 failed"
|
||||
|
||||
def test_build_workflow_run_summary_uses_last_status(self) -> None:
|
||||
"""Test that the last block's status is used."""
|
||||
context = MagicMock()
|
||||
context.workflow_run_id = "wr_status"
|
||||
context.workflow_run_outputs = {
|
||||
"Block1": {"status": "completed", "errors": []},
|
||||
"Block2": {"status": "failed", "errors": []},
|
||||
"Block3": {"status": "completed", "errors": []},
|
||||
}
|
||||
|
||||
summary = WorkflowRunContext.build_workflow_run_summary(context)
|
||||
# Last block's status is used
|
||||
assert summary["status"] == "completed"
|
||||
|
||||
def test_status_converted_to_string_in_summary(self) -> None:
|
||||
"""Test that TaskStatus enum is converted to string in summary."""
|
||||
context = MagicMock()
|
||||
context.workflow_run_id = "wr_enum"
|
||||
context.workflow_run_outputs = {
|
||||
"Block1": {"status": TaskStatus.completed, "errors": []},
|
||||
"Block2": {"status": TaskStatus.timed_out, "errors": []},
|
||||
}
|
||||
|
||||
summary = WorkflowRunContext.build_workflow_run_summary(context)
|
||||
|
||||
assert summary["status"] == "timed_out"
|
||||
assert type(summary["status"]) is str # Not TaskStatus enum
|
||||
|
||||
def test_workflow_run_summary_with_json_filter(self) -> None:
|
||||
"""Test that workflow_run_summary works with | json filter in templates."""
|
||||
template = "{{ workflow_run_summary | json }}"
|
||||
context = {
|
||||
"workflow_run_summary": {
|
||||
"workflow_run_id": "wr_template",
|
||||
"status": "completed",
|
||||
"output": {"extracted_information": {"documents": [{"name": "doc1.pdf"}]}},
|
||||
"downloaded_files": [{"url": "file.pdf"}],
|
||||
"errors": [],
|
||||
"failure_reason": None,
|
||||
}
|
||||
}
|
||||
rendered = jinja_sandbox_env.from_string(template).render(context)
|
||||
json_part = rendered[len(_JSON_TYPE_MARKER) : -len(_JSON_TYPE_MARKER)]
|
||||
parsed = json.loads(json_part)
|
||||
|
||||
assert parsed["workflow_run_id"] == "wr_template"
|
||||
assert parsed["status"] == "completed"
|
||||
assert parsed["output"]["extracted_information"]["documents"] == [{"name": "doc1.pdf"}]
|
||||
assert parsed["downloaded_files"] == [{"url": "file.pdf"}]
|
||||
assert parsed["errors"] == []
|
||||
assert parsed["failure_reason"] is None
|
||||
9
tests/unit/test_id_generation.py
Normal file
9
tests/unit/test_id_generation.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.db import id as id_module
|
||||
|
||||
|
||||
def test_generate_id_uniqueness_with_overflow(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
total_ids = 10000
|
||||
generated_ids = [id_module.generate_id() for _ in range(total_ids)]
|
||||
assert len(set(generated_ids)) == total_ids
|
||||
22
tests/unit/test_llm_response_parsing.py
Normal file
22
tests/unit/test_llm_response_parsing.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.api.llm.utils import _coerce_response_to_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("response", "expected"),
|
||||
[
|
||||
({"page_info": "Select country"}, ({"page_info": "Select country"}, False)),
|
||||
([{"page_info": "First"}, {"page_info": "Second"}], ({"page_info": "First"}, False)),
|
||||
(["text", {"page_info": "First dict"}], ({"page_info": "First dict"}, False)),
|
||||
([1, 2, 3], ({}, True)),
|
||||
("not-a-dict", ({}, True)),
|
||||
([], ({}, True)),
|
||||
],
|
||||
)
|
||||
def test_coerce_response_to_dict_variants(response, expected):
|
||||
try:
|
||||
parsed = _coerce_response_to_dict(response)
|
||||
assert parsed == expected[0]
|
||||
except Exception:
|
||||
assert expected[1]
|
||||
62
tests/unit/test_mcp_block_tools.py
Normal file
62
tests/unit/test_mcp_block_tools.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Tests for MCP block tools (skyvern_block_schema, skyvern_block_validate)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.cli.mcp_tools.blocks import skyvern_block_schema, skyvern_block_validate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_schema_task_redirects_to_navigation() -> None:
|
||||
"""Requesting schema for 'task' should return navigation info with a deprecation warning."""
|
||||
result = await skyvern_block_schema(block_type="task")
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["data"]["block_type"] == "navigation"
|
||||
assert "navigation_goal" in result["data"]["schema"].get("properties", {})
|
||||
assert len(result["warnings"]) > 0
|
||||
assert any("deprecated" in w.lower() for w in result["warnings"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_schema_unknown_type_returns_error() -> None:
|
||||
"""Requesting schema for a nonexistent type should return an error with available types."""
|
||||
result = await skyvern_block_schema(block_type="invalid_xyz")
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["error"] is not None
|
||||
assert "invalid_xyz" in result["error"]["message"]
|
||||
assert "navigation" in result["error"]["hint"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_validate_task_type_warns_deprecated() -> None:
|
||||
"""Validating a 'task' block should succeed with a deprecation warning."""
|
||||
block = {
|
||||
"block_type": "task",
|
||||
"label": "test",
|
||||
"url": "https://example.com",
|
||||
"navigation_goal": "do something",
|
||||
}
|
||||
result = await skyvern_block_validate(block_json=json.dumps(block))
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["data"]["valid"] is True
|
||||
assert len(result["warnings"]) > 0
|
||||
assert any("deprecated" in w.lower() for w in result["warnings"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_schema_no_type_lists_all() -> None:
|
||||
"""Calling without a block_type should list all available types."""
|
||||
result = await skyvern_block_schema(block_type=None)
|
||||
|
||||
assert result["ok"] is True
|
||||
block_types = result["data"]["block_types"]
|
||||
assert "navigation" in block_types
|
||||
assert "extraction" in block_types
|
||||
assert "task" not in block_types
|
||||
assert result["data"]["count"] > 0
|
||||
270
tests/unit/test_multi_field_totp.py
Normal file
270
tests/unit/test_multi_field_totp.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Tests for multi-field TOTP support in script generation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.core.script_generations.generate_script import _annotate_multi_field_totp_sequence
|
||||
from skyvern.core.script_generations.script_skyvern_page import ScriptSkyvernPage
|
||||
from skyvern.webeye.actions.action_types import ActionType
|
||||
|
||||
|
||||
class TestAnnotateMultiFieldTotpSequence:
|
||||
"""Tests for _annotate_multi_field_totp_sequence function."""
|
||||
|
||||
def test_empty_actions(self) -> None:
|
||||
"""Empty action list returns unchanged."""
|
||||
result = _annotate_multi_field_totp_sequence([])
|
||||
assert result == []
|
||||
|
||||
def test_less_than_4_actions_returns_unchanged(self) -> None:
|
||||
"""Actions with fewer than 4 items return unchanged (minimum for TOTP)."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "3"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
# No totp_timing_info should be added
|
||||
for action in result:
|
||||
assert "totp_timing_info" not in action
|
||||
|
||||
def test_4_digit_sequence_gets_annotated(self) -> None:
|
||||
"""4 consecutive single-digit inputs with same field_name get annotated."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp_code", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp_code", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp_code", "text": "3"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp_code", "text": "4"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
for idx, action in enumerate(result):
|
||||
assert "totp_timing_info" in action
|
||||
assert action["totp_timing_info"]["is_totp_sequence"] is True
|
||||
assert action["totp_timing_info"]["action_index"] == idx
|
||||
assert action["totp_timing_info"]["total_digits"] == 4
|
||||
assert action["totp_timing_info"]["field_name"] == "totp_code"
|
||||
|
||||
def test_6_digit_sequence_gets_annotated(self) -> None:
|
||||
"""Standard 6-digit TOTP sequence gets properly annotated."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "otp", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "otp", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "otp", "text": "3"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "otp", "text": "4"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "otp", "text": "5"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "otp", "text": "6"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
for idx, action in enumerate(result):
|
||||
assert action["totp_timing_info"]["action_index"] == idx
|
||||
assert action["totp_timing_info"]["total_digits"] == 6
|
||||
|
||||
def test_8_digit_sequence_gets_annotated(self) -> None:
|
||||
"""8-digit sequence (some TOTP implementations) gets annotated."""
|
||||
actions = [{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": str(i)} for i in range(8)]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
assert all("totp_timing_info" in a for a in result)
|
||||
assert result[0]["totp_timing_info"]["total_digits"] == 8
|
||||
assert result[7]["totp_timing_info"]["action_index"] == 7
|
||||
|
||||
def test_3_digits_not_annotated(self) -> None:
|
||||
"""3 consecutive digits should NOT be annotated (minimum is 4)."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "3"},
|
||||
{"action_type": ActionType.CLICK, "element_id": "submit"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
for action in result:
|
||||
assert "totp_timing_info" not in action
|
||||
|
||||
def test_different_field_names_not_grouped(self) -> None:
|
||||
"""Actions with different field_names should not be grouped together."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp1", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp1", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp2", "text": "3"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp2", "text": "4"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
# Neither sequence has 4+ with same field_name
|
||||
for action in result:
|
||||
assert "totp_timing_info" not in action
|
||||
|
||||
def test_mixed_actions_with_totp_sequence(self) -> None:
|
||||
"""TOTP sequence surrounded by non-TOTP actions still gets annotated."""
|
||||
actions = [
|
||||
{"action_type": ActionType.CLICK, "element_id": "show_totp"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "3"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "4"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "5"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "totp", "text": "6"},
|
||||
{"action_type": ActionType.CLICK, "element_id": "submit"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
# First and last actions should not have totp_timing_info
|
||||
assert "totp_timing_info" not in result[0]
|
||||
assert "totp_timing_info" not in result[7]
|
||||
|
||||
# Middle 6 actions should be annotated
|
||||
for idx in range(1, 7):
|
||||
assert "totp_timing_info" in result[idx]
|
||||
assert result[idx]["totp_timing_info"]["action_index"] == idx - 1
|
||||
assert result[idx]["totp_timing_info"]["total_digits"] == 6
|
||||
|
||||
def test_multiple_sequences_in_action_list(self) -> None:
|
||||
"""Multiple separate TOTP sequences in same action list get annotated separately."""
|
||||
actions = [
|
||||
# First sequence - 4 digits
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code1", "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code1", "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code1", "text": "3"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code1", "text": "4"},
|
||||
# Non-TOTP action breaks the sequence
|
||||
{"action_type": ActionType.CLICK, "element_id": "next"},
|
||||
# Second sequence - 6 digits
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code2", "text": "5"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code2", "text": "6"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code2", "text": "7"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code2", "text": "8"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code2", "text": "9"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code2", "text": "0"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
# First sequence (indices 0-3)
|
||||
for idx in range(4):
|
||||
assert result[idx]["totp_timing_info"]["total_digits"] == 4
|
||||
assert result[idx]["totp_timing_info"]["field_name"] == "code1"
|
||||
|
||||
# Click action (index 4)
|
||||
assert "totp_timing_info" not in result[4]
|
||||
|
||||
# Second sequence (indices 5-10)
|
||||
for idx in range(5, 11):
|
||||
assert result[idx]["totp_timing_info"]["total_digits"] == 6
|
||||
assert result[idx]["totp_timing_info"]["field_name"] == "code2"
|
||||
assert result[idx]["totp_timing_info"]["action_index"] == idx - 5
|
||||
|
||||
def test_non_digit_text_not_annotated(self) -> None:
|
||||
"""Actions with non-digit text should not be considered TOTP."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "field", "text": "a"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "field", "text": "b"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "field", "text": "c"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "field", "text": "d"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
for action in result:
|
||||
assert "totp_timing_info" not in action
|
||||
|
||||
def test_multi_digit_text_not_annotated(self) -> None:
|
||||
"""Actions with multi-digit text should not be considered multi-field TOTP."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "12"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "34"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "56"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "field_name": "code", "text": "78"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
for action in result:
|
||||
assert "totp_timing_info" not in action
|
||||
|
||||
def test_missing_field_name_not_annotated(self) -> None:
|
||||
"""Actions without field_name should not be considered TOTP."""
|
||||
actions = [
|
||||
{"action_type": ActionType.INPUT_TEXT, "text": "1"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "text": "2"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "text": "3"},
|
||||
{"action_type": ActionType.INPUT_TEXT, "text": "4"},
|
||||
]
|
||||
result = _annotate_multi_field_totp_sequence(actions)
|
||||
|
||||
for action in result:
|
||||
assert "totp_timing_info" not in action
|
||||
|
||||
|
||||
class TestGetTotpDigitBasic:
|
||||
"""Basic tests for get_totp_digit in ScriptSkyvernPage."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_skyvern_context(self) -> MagicMock:
|
||||
"""Create a mock skyvern context."""
|
||||
ctx = MagicMock()
|
||||
ctx.workflow_run_id = "wfr_test123"
|
||||
return ctx
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_single_digit(
|
||||
self,
|
||||
mock_skyvern_context: MagicMock,
|
||||
) -> None:
|
||||
"""get_totp_digit should return a single digit string."""
|
||||
# Empty credentials - will fall back to get_actual_value
|
||||
mock_workflow_context = MagicMock()
|
||||
mock_workflow_context.values = {}
|
||||
|
||||
with patch("skyvern.core.script_generations.script_skyvern_page.skyvern_context") as mock_ctx_module:
|
||||
with patch("skyvern.core.script_generations.script_skyvern_page.app") as mock_app:
|
||||
mock_ctx_module.ensure_context.return_value = mock_skyvern_context
|
||||
mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context = AsyncMock(
|
||||
return_value=mock_workflow_context
|
||||
)
|
||||
|
||||
page = MagicMock(spec=ScriptSkyvernPage)
|
||||
page._totp_sequence_cache = {}
|
||||
page.get_actual_value = AsyncMock(return_value="123456")
|
||||
|
||||
result = await ScriptSkyvernPage.get_totp_digit(
|
||||
page,
|
||||
context=MagicMock(),
|
||||
field_name="totp_code",
|
||||
digit_index=0,
|
||||
)
|
||||
|
||||
# Should return a single digit
|
||||
assert len(result) == 1
|
||||
assert result.isdigit()
|
||||
assert result == "1" # First digit of "123456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_correct_digit_index(
|
||||
self,
|
||||
mock_skyvern_context: MagicMock,
|
||||
) -> None:
|
||||
"""get_totp_digit should return the correct digit for the given index."""
|
||||
mock_workflow_context = MagicMock()
|
||||
mock_workflow_context.values = {}
|
||||
|
||||
with patch("skyvern.core.script_generations.script_skyvern_page.skyvern_context") as mock_ctx_module:
|
||||
with patch("skyvern.core.script_generations.script_skyvern_page.app") as mock_app:
|
||||
mock_ctx_module.ensure_context.return_value = mock_skyvern_context
|
||||
mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context = AsyncMock(
|
||||
return_value=mock_workflow_context
|
||||
)
|
||||
|
||||
page = MagicMock(spec=ScriptSkyvernPage)
|
||||
page._totp_sequence_cache = {}
|
||||
page.get_actual_value = AsyncMock(return_value="987654")
|
||||
|
||||
# Test each digit index
|
||||
for idx, expected in enumerate("987654"):
|
||||
result = await ScriptSkyvernPage.get_totp_digit(
|
||||
page,
|
||||
context=MagicMock(),
|
||||
field_name="totp_code",
|
||||
digit_index=idx,
|
||||
)
|
||||
assert result == expected, f"Expected digit {expected} at index {idx}, got {result}"
|
||||
471
tests/unit/test_parallel_verification.py
Normal file
471
tests/unit/test_parallel_verification.py
Normal file
@@ -0,0 +1,471 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.agent import ForgeAgent, SpeculativePlan
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
from skyvern.schemas.steps import AgentStepOutput
|
||||
from skyvern.webeye.actions.actions import ClickAction, CompleteAction, ExtractAction
|
||||
from skyvern.webeye.actions.responses import ActionSuccess
|
||||
from skyvern.webeye.scraper.scraped_page import ScrapedPage
|
||||
from tests.unit.helpers import (
|
||||
make_browser_state,
|
||||
make_organization,
|
||||
make_step,
|
||||
make_task,
|
||||
setup_parallel_verification_mocks,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_verification_triggers_data_extraction(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent = ForgeAgent()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization)
|
||||
|
||||
step_output = AgentStepOutput(action_results=[], actions_and_results=[])
|
||||
step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-123",
|
||||
status=StepStatus.completed,
|
||||
order=0,
|
||||
output=step_output,
|
||||
)
|
||||
next_step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-next",
|
||||
status=StepStatus.created,
|
||||
order=1,
|
||||
output=None,
|
||||
)
|
||||
|
||||
complete_action = CompleteAction(reasoning="done", verified=True)
|
||||
extract_action = ExtractAction(
|
||||
reasoning="extract final data",
|
||||
data_extraction_goal=task.data_extraction_goal,
|
||||
data_extraction_schema=task.extracted_information_schema,
|
||||
)
|
||||
extract_action.organization_id = task.organization_id
|
||||
extract_action.workflow_run_id = task.workflow_run_id
|
||||
extract_action.task_id = task.task_id
|
||||
extract_action.step_id = step.step_id
|
||||
extract_action.step_order = step.order
|
||||
extract_action.action_order = 1
|
||||
monkeypatch.setattr(agent, "create_extract_action", AsyncMock(return_value=extract_action))
|
||||
|
||||
extraction_payload = {"quote": "42%"}
|
||||
mocks = setup_parallel_verification_mocks(
|
||||
agent,
|
||||
step=step,
|
||||
task=task,
|
||||
monkeypatch=monkeypatch,
|
||||
next_step=next_step,
|
||||
complete_action=complete_action,
|
||||
handle_action_responses=[
|
||||
[ActionSuccess()],
|
||||
[ActionSuccess(data=extraction_payload)],
|
||||
],
|
||||
extract_action=extract_action,
|
||||
)
|
||||
|
||||
browser_state, scraped_page, page = make_browser_state()
|
||||
|
||||
completed, last_step, next_created_step = await agent._handle_completed_step_with_parallel_verification(
|
||||
organization=organization,
|
||||
task=task,
|
||||
step=step,
|
||||
page=page,
|
||||
browser_state=browser_state,
|
||||
scraped_page=scraped_page,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
)
|
||||
|
||||
assert completed is True
|
||||
assert last_step == step
|
||||
assert next_created_step is None
|
||||
|
||||
assert mocks.handle_action.await_count == 2
|
||||
|
||||
extracted_information = mocks.update_task.await_args.kwargs["extracted_information"]
|
||||
assert extracted_information == extraction_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_verification_skips_extraction_without_navigation_goal(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
agent = ForgeAgent()
|
||||
now = datetime.now(UTC)
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization, navigation_goal=None)
|
||||
|
||||
step_output = AgentStepOutput(action_results=[], actions_and_results=[])
|
||||
step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-123",
|
||||
status=StepStatus.completed,
|
||||
order=0,
|
||||
output=step_output,
|
||||
)
|
||||
|
||||
setup_parallel_verification_mocks(
|
||||
agent,
|
||||
step=step,
|
||||
task=task,
|
||||
monkeypatch=monkeypatch,
|
||||
next_step=step,
|
||||
complete_action=CompleteAction(reasoning="done", verified=True),
|
||||
handle_action_responses=[[ActionSuccess()]],
|
||||
)
|
||||
|
||||
run_data_extraction_mock = AsyncMock()
|
||||
monkeypatch.setattr(agent, "_run_data_extraction_after_complete_action", run_data_extraction_mock)
|
||||
|
||||
browser_state, scraped_page, page = make_browser_state()
|
||||
|
||||
await agent._handle_completed_step_with_parallel_verification(
|
||||
organization=organization,
|
||||
task=task,
|
||||
step=step,
|
||||
page=page,
|
||||
browser_state=browser_state,
|
||||
scraped_page=scraped_page,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
)
|
||||
|
||||
run_data_extraction_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_task_validate_update_requires_extracted_information() -> None:
|
||||
now = datetime.now(UTC)
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
data_extraction_goal="Need data",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
task.validate_update(TaskStatus.completed, extracted_information=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_step_skips_user_goal_check_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent = ForgeAgent()
|
||||
now = datetime.now(UTC)
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization, navigation_goal="Reach confirmation page", workflow_run_id="workflow-1")
|
||||
step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-disable",
|
||||
status=StepStatus.created,
|
||||
order=0,
|
||||
output=None,
|
||||
)
|
||||
|
||||
browser_state, _, page = make_browser_state()
|
||||
browser_state.must_get_working_page = AsyncMock(return_value=page)
|
||||
browser_state.get_working_page = AsyncMock(return_value=page)
|
||||
|
||||
async def _dummy_cleanup(*_args, **_kwargs) -> list[dict]:
|
||||
return []
|
||||
|
||||
scraped_page = ScrapedPage(
|
||||
elements=[],
|
||||
element_tree=[],
|
||||
element_tree_trimmed=[],
|
||||
_browser_state=browser_state,
|
||||
_clean_up_func=_dummy_cleanup,
|
||||
_scrape_exclude=None,
|
||||
)
|
||||
scraped_page.screenshots = [b"image"]
|
||||
|
||||
agent.build_and_record_step_prompt = AsyncMock(return_value=(scraped_page, "prompt", False, "extract-actions"))
|
||||
json_response: dict[str, object] = {"actions": [{"action_type": "CLICK", "element_id": "node-1"}]}
|
||||
agent.handle_potential_OTP_actions = AsyncMock(return_value=(json_response, []))
|
||||
|
||||
click_action = ClickAction(
|
||||
element_id="node-1",
|
||||
organization_id=task.organization_id,
|
||||
workflow_run_id=task.workflow_run_id,
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
step_order=step.order,
|
||||
action_order=0,
|
||||
)
|
||||
monkeypatch.setattr("skyvern.forge.agent.parse_actions", lambda *_, **__: [click_action])
|
||||
|
||||
action_handler_mock = AsyncMock(return_value=[ActionSuccess()])
|
||||
monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", action_handler_mock)
|
||||
agent.record_artifacts_after_action = AsyncMock()
|
||||
agent._is_multi_field_totp_sequence = MagicMock(return_value=False)
|
||||
agent.check_user_goal_complete = AsyncMock()
|
||||
|
||||
llm_handler_mock = AsyncMock(return_value=json_response)
|
||||
monkeypatch.setattr(
|
||||
"skyvern.forge.agent.LLMAPIHandlerFactory.get_override_llm_api_handler",
|
||||
lambda *_args, **_kwargs: llm_handler_mock,
|
||||
)
|
||||
monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.prepare_step_execution", AsyncMock())
|
||||
monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.post_action_execution", AsyncMock())
|
||||
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", AsyncMock(return_value=None))
|
||||
monkeypatch.setattr("skyvern.forge.agent.random.uniform", lambda *_args, **_kwargs: 0)
|
||||
|
||||
async def fake_update_step(
|
||||
step: Step,
|
||||
status: StepStatus | None = None,
|
||||
output=None,
|
||||
is_last: bool | None = None,
|
||||
retry_index: int | None = None,
|
||||
**_kwargs,
|
||||
) -> Step:
|
||||
if status is not None:
|
||||
step.status = status
|
||||
if output is not None:
|
||||
step.output = output
|
||||
if is_last is not None:
|
||||
step.is_last = is_last
|
||||
if retry_index is not None:
|
||||
step.retry_index = retry_index
|
||||
return step
|
||||
|
||||
agent.update_step = AsyncMock(side_effect=fake_update_step)
|
||||
|
||||
async def feature_flag_side_effect(flag_name: str, *_args, **_kwargs) -> bool:
|
||||
if flag_name == "DISABLE_USER_GOAL_CHECK":
|
||||
return True
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"skyvern.forge.agent.app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached",
|
||||
AsyncMock(side_effect=feature_flag_side_effect),
|
||||
)
|
||||
|
||||
context = SkyvernContext(
|
||||
task_id=task.task_id,
|
||||
step_id=None,
|
||||
organization_id=task.organization_id,
|
||||
workflow_run_id=task.workflow_run_id,
|
||||
tz_info=ZoneInfo("UTC"),
|
||||
)
|
||||
skyvern_context.set(context)
|
||||
try:
|
||||
completed_step, detailed_output = await agent.agent_step(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
organization=organization,
|
||||
)
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
|
||||
assert completed_step.status == StepStatus.completed
|
||||
assert detailed_output.actions_and_results is not None
|
||||
assert action_handler_mock.await_count == 1
|
||||
agent.record_artifacts_after_action.assert_awaited()
|
||||
agent.check_user_goal_complete.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_step_persists_artifacts_when_using_speculative_plan(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
agent = ForgeAgent()
|
||||
now = datetime.now(UTC)
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization, navigation_goal=None)
|
||||
step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-speculative",
|
||||
status=StepStatus.created,
|
||||
order=0,
|
||||
output=None,
|
||||
)
|
||||
|
||||
browser_state, _, page = make_browser_state()
|
||||
browser_state.must_get_working_page = AsyncMock(return_value=page)
|
||||
browser_state.get_working_page = AsyncMock(return_value=page)
|
||||
|
||||
async def _dummy_cleanup(*_args, **_kwargs) -> list[dict]:
|
||||
return []
|
||||
|
||||
scraped_page = ScrapedPage(
|
||||
elements=[],
|
||||
element_tree=[{"tagName": "div", "children": []}],
|
||||
element_tree_trimmed=[{"tagName": "div", "children": []}],
|
||||
_browser_state=browser_state,
|
||||
_clean_up_func=_dummy_cleanup,
|
||||
_scrape_exclude=None,
|
||||
)
|
||||
scraped_page.html = "<html></html>"
|
||||
scraped_page.id_to_css_dict = {"node-1": "#node"}
|
||||
scraped_page.id_to_frame_dict = {"node-1": "frame-1"}
|
||||
scraped_page.screenshots = [b"image"]
|
||||
|
||||
speculative_plan = SpeculativePlan(
|
||||
scraped_page=scraped_page,
|
||||
extract_action_prompt="unused",
|
||||
use_caching=False,
|
||||
llm_json_response=None,
|
||||
llm_metadata=None,
|
||||
prompt_name="extract-actions",
|
||||
)
|
||||
|
||||
extract_action = ExtractAction(
|
||||
reasoning="collect data",
|
||||
data_extraction_goal=task.data_extraction_goal,
|
||||
data_extraction_schema=task.extracted_information_schema,
|
||||
)
|
||||
extract_action.organization_id = task.organization_id
|
||||
extract_action.workflow_run_id = task.workflow_run_id
|
||||
extract_action.task_id = task.task_id
|
||||
extract_action.step_id = step.step_id
|
||||
extract_action.step_order = step.order
|
||||
extract_action.action_order = 0
|
||||
|
||||
agent.create_extract_action = AsyncMock(return_value=extract_action)
|
||||
agent.record_artifacts_after_action = AsyncMock()
|
||||
agent._persist_scrape_artifacts = AsyncMock()
|
||||
agent._is_multi_field_totp_sequence = MagicMock(return_value=False)
|
||||
|
||||
action_handler_mock = AsyncMock(return_value=[ActionSuccess()])
|
||||
monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", action_handler_mock)
|
||||
monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.prepare_step_execution", AsyncMock())
|
||||
monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.post_action_execution", AsyncMock())
|
||||
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", AsyncMock(return_value=None))
|
||||
monkeypatch.setattr("skyvern.forge.agent.random.uniform", lambda *_args, **_kwargs: 0)
|
||||
monkeypatch.setattr("skyvern.forge.agent.app.DATABASE.create_action", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
"skyvern.forge.agent.app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached",
|
||||
AsyncMock(return_value=False),
|
||||
)
|
||||
|
||||
async def fake_update_step(
|
||||
step: Step,
|
||||
status: StepStatus | None = None,
|
||||
output=None,
|
||||
is_last: bool | None = None,
|
||||
retry_index: int | None = None,
|
||||
**_kwargs,
|
||||
) -> Step:
|
||||
if status is not None:
|
||||
step.status = status
|
||||
if output is not None:
|
||||
step.output = output
|
||||
if is_last is not None:
|
||||
step.is_last = is_last
|
||||
if retry_index is not None:
|
||||
step.retry_index = retry_index
|
||||
return step
|
||||
|
||||
agent.update_step = AsyncMock(side_effect=fake_update_step)
|
||||
|
||||
context = SkyvernContext(
|
||||
task_id=task.task_id,
|
||||
step_id=None,
|
||||
organization_id=task.organization_id,
|
||||
workflow_run_id=task.workflow_run_id,
|
||||
tz_info=ZoneInfo("UTC"),
|
||||
)
|
||||
context.speculative_plans[step.step_id] = speculative_plan
|
||||
skyvern_context.set(context)
|
||||
try:
|
||||
completed_step, detailed_output = await agent.agent_step(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
organization=organization,
|
||||
)
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
|
||||
assert completed_step.status == StepStatus.completed
|
||||
assert detailed_output.actions is not None
|
||||
agent._persist_scrape_artifacts.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_scrape_artifacts_records_all_files(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent = ForgeAgent()
|
||||
now = datetime.now(UTC)
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization)
|
||||
step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-artifacts",
|
||||
status=StepStatus.created,
|
||||
order=0,
|
||||
output=None,
|
||||
)
|
||||
|
||||
browser_state, _, _ = make_browser_state()
|
||||
|
||||
async def _dummy_cleanup(*_args, **_kwargs) -> list[dict]:
|
||||
return []
|
||||
|
||||
scraped_page = ScrapedPage(
|
||||
elements=[],
|
||||
element_tree=[{"tagName": "div"}],
|
||||
element_tree_trimmed=[{"tagName": "div"}],
|
||||
_browser_state=browser_state,
|
||||
_clean_up_func=_dummy_cleanup,
|
||||
_scrape_exclude=None,
|
||||
)
|
||||
scraped_page.html = "<html></html>"
|
||||
scraped_page.id_to_css_dict = {"node-1": "#node"}
|
||||
scraped_page.id_to_frame_dict = {"node-1": "frame-1"}
|
||||
scraped_page.element_tree = [{"tagName": "div"}]
|
||||
scraped_page.element_tree_trimmed = [{"tagName": "div"}]
|
||||
|
||||
economy_tree_mock = MagicMock(return_value="<economy>")
|
||||
full_tree_mock = MagicMock(return_value="<full>")
|
||||
|
||||
def economy_wrapper(self, *args, **kwargs):
|
||||
return economy_tree_mock(self, *args, **kwargs)
|
||||
|
||||
def full_wrapper(self, *args, **kwargs):
|
||||
return full_tree_mock(self, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(ScrapedPage, "build_economy_elements_tree", economy_wrapper)
|
||||
monkeypatch.setattr(ScrapedPage, "build_element_tree", full_wrapper)
|
||||
|
||||
artifact_mock = AsyncMock()
|
||||
monkeypatch.setattr("skyvern.forge.agent.app.ARTIFACT_MANAGER.create_artifact", artifact_mock)
|
||||
|
||||
context = SkyvernContext(
|
||||
task_id=task.task_id,
|
||||
step_id=None,
|
||||
organization_id=task.organization_id,
|
||||
workflow_run_id=task.workflow_run_id,
|
||||
tz_info=ZoneInfo("UTC"),
|
||||
)
|
||||
context.enable_speed_optimizations = True
|
||||
|
||||
await agent._persist_scrape_artifacts(
|
||||
task=task,
|
||||
step=step,
|
||||
scraped_page=scraped_page,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert artifact_mock.await_count == 6
|
||||
economy_tree_mock.assert_called_once()
|
||||
full_tree_mock.assert_not_called()
|
||||
last_call = artifact_mock.await_args_list[-1]
|
||||
assert last_call.kwargs["artifact_type"] == ArtifactType.VISIBLE_ELEMENTS_TREE_IN_PROMPT
|
||||
assert last_call.kwargs["data"] == b"<economy>"
|
||||
58
tests/unit/test_prompt_caching_settings.py
Normal file
58
tests/unit/test_prompt_caching_settings.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.agent import (
|
||||
EXTRACT_ACTION_PROMPT_NAME,
|
||||
EXTRACT_ACTION_TEMPLATE,
|
||||
ForgeAgent,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_caching_settings_respect_experiment(monkeypatch):
|
||||
agent = ForgeAgent()
|
||||
context = SkyvernContext(run_id="wr_123", organization_id="org_456")
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.is_feature_enabled_cached.return_value = True
|
||||
monkeypatch.setattr(app, "EXPERIMENTATION_PROVIDER", mock_provider)
|
||||
try:
|
||||
LLMAPIHandlerFactory.set_prompt_caching_settings(None)
|
||||
|
||||
settings = await agent._get_prompt_caching_settings(context)
|
||||
|
||||
assert settings == {
|
||||
EXTRACT_ACTION_PROMPT_NAME: True,
|
||||
EXTRACT_ACTION_TEMPLATE: True,
|
||||
}
|
||||
mock_provider.is_feature_enabled_cached.assert_awaited_once_with(
|
||||
"PROMPT_CACHING_OPTIMIZATION",
|
||||
"wr_123",
|
||||
properties={"organization_id": "org_456"},
|
||||
)
|
||||
|
||||
# Cached on context; no second provider call
|
||||
await agent._get_prompt_caching_settings(context)
|
||||
assert mock_provider.is_feature_enabled_cached.await_count == 1
|
||||
finally:
|
||||
LLMAPIHandlerFactory.set_prompt_caching_settings(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_caching_settings_use_override(monkeypatch):
|
||||
agent = ForgeAgent()
|
||||
context = SkyvernContext(run_id="wr_789", organization_id="org_987")
|
||||
mock_provider = AsyncMock()
|
||||
monkeypatch.setattr(app, "EXPERIMENTATION_PROVIDER", mock_provider)
|
||||
try:
|
||||
LLMAPIHandlerFactory.set_prompt_caching_settings({"extract-actions": True})
|
||||
|
||||
settings = await agent._get_prompt_caching_settings(context)
|
||||
|
||||
assert settings == {"extract-actions": True}
|
||||
mock_provider.is_feature_enabled_cached.assert_not_called()
|
||||
finally:
|
||||
LLMAPIHandlerFactory.set_prompt_caching_settings(None)
|
||||
97
tests/unit/test_sanitization.py
Normal file
97
tests/unit/test_sanitization.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__normal_text() -> None:
|
||||
"""Test that normal text passes through unchanged."""
|
||||
normal_text = "Hello, World! This is a normal PDF text with numbers 123 and symbols @#$%."
|
||||
result = sanitize_postgres_text(normal_text)
|
||||
assert result == normal_text
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__with_nul_bytes() -> None:
|
||||
"""Test that NUL bytes (0x00) are removed."""
|
||||
text_with_nul = "Hello\x00World\x00Test"
|
||||
expected = "HelloWorldTest"
|
||||
result = sanitize_postgres_text(text_with_nul)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__with_control_characters() -> None:
|
||||
"""Test that problematic control characters are removed."""
|
||||
# Test various control characters that should be removed
|
||||
text_with_controls = "Hello\x01\x02\x03World\x08\x0b\x0c\x0e\x1fTest"
|
||||
expected = "HelloWorldTest"
|
||||
result = sanitize_postgres_text(text_with_controls)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__preserve_whitespace() -> None:
|
||||
"""Test that common whitespace characters are preserved."""
|
||||
text_with_whitespace = "Hello\tWorld\nNew Line\rCarriage Return"
|
||||
result = sanitize_postgres_text(text_with_whitespace)
|
||||
assert result == text_with_whitespace
|
||||
assert "\t" in result
|
||||
assert "\n" in result
|
||||
assert "\r" in result
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__empty_string() -> None:
|
||||
"""Test that empty string is handled correctly."""
|
||||
result = sanitize_postgres_text("")
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__mixed_case() -> None:
|
||||
"""Test text with mix of normal text, NUL bytes, and control characters."""
|
||||
mixed_text = "PDF Text\x00with NUL\tbytes\nand\x01control\x08chars\rand normal text."
|
||||
# \r should be preserved as it's a valid whitespace character
|
||||
expected = "PDF Textwith NUL\tbytes\nandcontrolchars\rand normal text."
|
||||
result = sanitize_postgres_text(mixed_text)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__multiple_nul_bytes() -> None:
|
||||
"""Test that multiple consecutive NUL bytes are all removed."""
|
||||
text_with_multiple_nuls = "Start\x00\x00\x00Middle\x00\x00End"
|
||||
expected = "StartMiddleEnd"
|
||||
result = sanitize_postgres_text(text_with_multiple_nuls)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__unicode_text() -> None:
|
||||
"""Test that Unicode characters are preserved."""
|
||||
unicode_text = "中文测试 Unicode: café, naïve, Ω, emoji 😀"
|
||||
result = sanitize_postgres_text(unicode_text)
|
||||
assert result == unicode_text
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__real_world_pdf_scenario() -> None:
|
||||
"""Test a realistic scenario with PDF extraction artifacts."""
|
||||
# Simulate what might come from a PDF extraction
|
||||
pdf_text = "Invoice\x00Number:\t12345\nDate:\t2024-01-01\x00\nTotal:\t$100.00\x01\x02"
|
||||
expected = "InvoiceNumber:\t12345\nDate:\t2024-01-01\nTotal:\t$100.00"
|
||||
result = sanitize_postgres_text(pdf_text)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__only_control_characters() -> None:
|
||||
"""Test string with only problematic characters."""
|
||||
only_controls = "\x00\x01\x02\x03\x08"
|
||||
expected = ""
|
||||
result = sanitize_postgres_text(only_controls)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__preserves_spaces_and_punctuation() -> None:
|
||||
"""Test that normal spaces and punctuation are preserved."""
|
||||
text = "Hello, World! How are you? I'm fine. Test@example.com"
|
||||
result = sanitize_postgres_text(text)
|
||||
assert result == text
|
||||
|
||||
|
||||
def test_sanitize_postgres_text__newlines_and_paragraphs() -> None:
|
||||
"""Test multi-paragraph text with newlines."""
|
||||
multiline_text = "Paragraph 1\n\nParagraph 2\n\nParagraph 3"
|
||||
result = sanitize_postgres_text(multiline_text)
|
||||
assert result == multiline_text
|
||||
assert result.count("\n") == 4
|
||||
756
tests/unit/test_script_generation_race_condition.py
Normal file
756
tests/unit/test_script_generation_race_condition.py
Normal file
@@ -0,0 +1,756 @@
|
||||
"""
|
||||
Tests for script generation race condition (SKY-7653).
|
||||
|
||||
The race condition occurs when script generation runs during workflow execution
|
||||
before all actions have been saved to the database. This results in:
|
||||
1. `generate_workflow_parameters_schema` not finding INPUT_TEXT actions
|
||||
2. No field_name mappings being generated
|
||||
3. Generated script having hardcoded values instead of context.parameters[field_name]
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.core.script_generations import generate_workflow_parameters as gwp
|
||||
from skyvern.core.script_generations.generate_workflow_parameters import (
|
||||
CUSTOM_FIELD_ACTIONS,
|
||||
GeneratedFieldMapping,
|
||||
generate_workflow_parameters_schema,
|
||||
hydrate_input_text_actions_with_field_names,
|
||||
)
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
|
||||
|
||||
def make_input_text_action(
|
||||
task_id: str,
|
||||
action_id: str,
|
||||
text: str,
|
||||
intention: str = "",
|
||||
field_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a mock INPUT_TEXT action dictionary."""
|
||||
action = {
|
||||
"action_type": ActionType.INPUT_TEXT,
|
||||
"action_id": action_id,
|
||||
"task_id": task_id,
|
||||
"text": text,
|
||||
"intention": intention,
|
||||
"element_id": "element_1",
|
||||
"xpath": "//input[@id='test']",
|
||||
}
|
||||
if field_name:
|
||||
action["field_name"] = field_name
|
||||
return action
|
||||
|
||||
|
||||
def make_click_action(task_id: str, action_id: str) -> dict[str, Any]:
|
||||
"""Create a mock CLICK action dictionary."""
|
||||
return {
|
||||
"action_type": ActionType.CLICK,
|
||||
"action_id": action_id,
|
||||
"task_id": task_id,
|
||||
"element_id": "element_2",
|
||||
"xpath": "//button[@id='submit']",
|
||||
}
|
||||
|
||||
|
||||
class TestRaceConditionScenarios:
    """Scenarios demonstrating the script-generation race condition."""

    def test_hydrate_adds_field_name_to_actions(self) -> None:
        """hydrate_input_text_actions_with_field_names must attach field_name."""
        task_id, action_id = "task-123", "action-456"
        actions_by_task = {
            task_id: [make_input_text_action(task_id, action_id, "Urdaneta", "Enter facility name")]
        }
        field_mappings = {f"{task_id}:{action_id}": "facility_name"}

        hydrated = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings)

        # After hydration the action carries the mapped parameter name.
        assert hydrated[task_id][0].get("field_name") == "facility_name"

    def test_hydrate_without_mappings_no_field_name(self) -> None:
        """
        With no field mappings, hydration must leave actions untouched.

        Mirrors what happens when script generation runs before actions are saved.
        """
        task_id, action_id = "task-123", "action-456"
        actions_by_task = {
            task_id: [make_input_text_action(task_id, action_id, "Urdaneta", "Enter facility name")]
        }
        # Empty mapping: the LLM was never invoked because no INPUT_TEXT
        # actions had been found yet.
        empty_mappings: dict[str, str] = {}

        hydrated = hydrate_input_text_actions_with_field_names(actions_by_task, empty_mappings)

        assert "field_name" not in hydrated[task_id][0]

    def test_race_condition_empty_actions_produces_empty_schema(self) -> None:
        """
        Zero actions means generate_workflow_parameters_schema yields an empty
        schema — the case where script generation beats action execution.
        """
        actions_by_task: dict[str, list[dict[str, Any]]] = {}

        # Replicate only the synchronous action-collection step; the async LLM
        # call never happens when nothing matches.
        matched = [
            action
            for actions in actions_by_task.values()
            for action in actions
            if action.get("action_type", "") in CUSTOM_FIELD_ACTIONS
        ]

        assert len(matched) == 0

    def test_race_condition_only_click_actions_no_schema(self) -> None:
        """
        Only CLICK actions present (INPUT_TEXT not yet saved) must produce no
        field mappings.
        """
        task_id = "task-123"
        actions_by_task = {
            task_id: [
                make_click_action(task_id, "action-1"),
                make_click_action(task_id, "action-2"),
            ]
        }

        matched = [
            action
            for actions in actions_by_task.values()
            for action in actions
            if action.get("action_type", "") in CUSTOM_FIELD_ACTIONS
        ]

        # No INPUT_TEXT actions -> nothing to build a schema from.
        assert len(matched) == 0
|
||||
|
||||
|
||||
class TestCodeGenerationWithoutFieldName:
    """
    Show that a missing field_name leads to hardcoded values in generated
    code — the user-visible impact of the race condition.
    """

    def test_action_without_field_name_produces_hardcoded_value(self) -> None:
        """Without field_name the generator falls back to the literal text."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",  # this literal ends up hardcoded
            intention="Enter facility name",
            field_name=None,  # absent because of the race condition
        )

        # action_handler_body branches on act.get("field_name"): None means
        # the hardcoded-value path is taken instead of context.parameters.
        assert action.get("field_name") is None
        assert action.get("text") == "Urdaneta"  # Will be hardcoded

    def test_action_with_field_name_produces_parameter_reference(self) -> None:
        """With field_name present the generator emits context.parameters[...]."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",  # original value, unused by the generated code
            intention="Enter facility name",
            field_name="facility_name",
        )

        assert action.get("field_name") == "facility_name"
|
||||
|
||||
|
||||
class TestFieldMappingGeneration:
    """Checks for how field mappings are produced and keyed."""

    def test_field_mapping_structure(self) -> None:
        """GeneratedFieldMapping exposes field_mappings and schema_fields."""
        mapping = GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

        assert mapping.field_mappings["action_index_1"] == "facility_name"
        assert mapping.schema_fields["facility_name"]["type"] == "str"

    def test_action_index_to_field_mapping_key_format(self) -> None:
        """Mapping keys follow the task_id:action_id convention."""
        task_id, action_id = "task-123", "action-456"

        # Same key construction as generate_workflow_parameters_schema.
        key = f"{task_id}:{action_id}"
        assert key == "task-123:action-456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_empty_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Integration: with no actions at all, the generator returns an empty schema
    and no mappings — the race-condition behavior in its purest form.

    No mocking of the prompt engine or LLM is needed: the function returns
    early when no custom-field actions are found.
    """
    no_actions: dict[str, list[dict[str, Any]]] = {}

    code, mappings = await generate_workflow_parameters_schema(no_actions)

    assert "pass" in code  # an empty generated schema body is just `pass`
    assert mappings == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_with_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Integration: when INPUT_TEXT actions exist, the (mocked) LLM produces
    field mappings and the schema includes the generated field — i.e. running
    script generation AFTER actions are saved works as intended.
    """

    # Stand-in for the LLM call, returning a fixed mapping.
    async def fake_llm(custom_field_actions):
        return GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

    monkeypatch.setattr(gwp, "_generate_field_names_with_llm", fake_llm)

    task_id, action_id = "task-123", "action-456"
    actions_by_task = {
        task_id: [make_input_text_action(task_id, action_id, "Urdaneta", "Enter facility name")]
    }

    code, mappings = await generate_workflow_parameters_schema(actions_by_task)

    # Generated schema contains the field and the wrapper class.
    assert "facility_name" in code
    assert "GeneratedWorkflowParameters" in code

    # The mapping for our action is keyed task_id:action_id.
    key = f"{task_id}:{action_id}"
    assert key in mappings
    assert mappings[key] == "facility_name"
|
||||
|
||||
|
||||
class TestRaceConditionTimingScenario:
    """
    Documents the timing that triggers the race condition.

    Timeline:
    1. T+0s: CLICK action executes, post_action_execution triggered
    2. T+0.1s: Script generation starts (asyncio.create_task)
    3. T+0.2s: Script generation queries database for actions - finds only CLICK
    4. T+0.3s: Script generation completes with no field mappings
    5. T+6s: INPUT_TEXT action executes, saved to database
    6. T+6.1s: Another script generation triggered, but first (wrong) script already saved

    Result: a script with `value = 'Urdaneta'` hardcoded instead of
    `value = context.parameters['facility_name']`.
    """

    def test_timing_scenario_documentation(self) -> None:
        """Walk both phases of the timeline and assert what each one sees."""

        def input_text_only(actions_by_task):
            # Helper: collect just the INPUT_TEXT actions across all tasks.
            return [
                a
                for task_actions in actions_by_task.values()
                for a in task_actions
                if a["action_type"] == ActionType.INPUT_TEXT
            ]

        # Phase 1: only the CLICK has happened, so script generation would
        # find no INPUT_TEXT actions.
        phase_one = {"task-123": [make_click_action("task-123", "action-1")]}
        assert len(input_text_only(phase_one)) == 0

        # Phase 2: INPUT_TEXT saved — found now, but the first (wrong) script
        # has already been cached.
        phase_two = {
            "task-123": [
                make_click_action("task-123", "action-1"),
                make_input_text_action("task-123", "action-2", "Urdaneta", "Enter facility name"),
            ]
        }
        found = input_text_only(phase_two)
        assert len(found) == 1
        assert found[0]["text"] == "Urdaneta"
|
||||
|
||||
|
||||
class TestFinalizeParameter:
    """
    Tests for the `finalize` parameter in generate_script_if_needed.

    The SKY-7653 fix uses a smart finalize approach: regeneration happens only
    when script_gen_had_incomplete_actions is set, so a script generated from
    complete data never pays for a pointless second generation.
    """

    def test_finalize_with_incomplete_actions_triggers_regeneration(self) -> None:
        """
        finalize=True plus the incomplete-actions flag must force cacheable
        task blocks into the regeneration set.

        Mirrors the logic in generate_script_if_needed when finalize=True and
        the context has script_gen_had_incomplete_actions=True.
        """

        class FakeBlock:
            def __init__(self, label: str, block_type: str):
                self.label = label
                self.block_type = block_type

        workflow_blocks = [
            FakeBlock("login_step", "task"),  # in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
            FakeBlock("search_step", "task"),
            FakeBlock("wait_block", "wait"),  # should NOT be cached
        ]

        pending: set[str] = set()
        finalize = True
        context = SkyvernContext(script_gen_had_incomplete_actions=True)

        if finalize and context.script_gen_had_incomplete_actions:
            pending.update(
                block.label
                for block in workflow_blocks
                if block.label and block.block_type in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
            )

        # Task blocks are included; the wait block is not.
        assert "login_step" in pending
        assert "search_step" in pending
        assert "wait_block" not in pending

    def test_finalize_without_incomplete_actions_skips_regeneration(self) -> None:
        """
        finalize=True without the flag must skip regeneration — the cost
        optimization for scripts that were generated from complete data.
        """
        pending: set[str] = set()
        finalize = True
        context = SkyvernContext(script_gen_had_incomplete_actions=False)

        if finalize and context.script_gen_had_incomplete_actions:
            pending.add("some_block")  # unreachable: flag is False

        # Nothing queued — the script is already complete.
        assert len(pending) == 0

    def test_without_finalize_no_forced_regeneration(self) -> None:
        """finalize=False never force-adds blocks."""
        pending: set[str] = set()
        finalize = False

        if finalize:
            pending.add("some_block")  # unreachable: finalize is False

        assert len(pending) == 0
|
||||
|
||||
|
||||
class TestCodeGenerationLogic:
    """
    Exercise the branch from generate_script.py:401-429 that picks between
    context.parameters[field_name] and a hardcoded literal based on
    act.get("field_name").
    """

    @staticmethod
    def _resolve_code_path(action: dict) -> str:
        """Mirror the generator's branch: parameter reference vs literal."""
        if action.get("field_name"):
            # Generator would emit: context.parameters["<field_name>"]
            return "context.parameters"
        # Generator would emit the raw text literal via _value(act["text"]).
        return "hardcoded"

    def test_code_generation_path_without_field_name(self) -> None:
        """Missing field_name routes code generation to the hardcoded path."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",
            intention="Enter facility name",
            field_name=None,
        )

        assert self._resolve_code_path(action) == "hardcoded"
        assert action.get("text") == "Urdaneta"

    def test_code_generation_path_with_field_name(self) -> None:
        """Present field_name routes code generation to context.parameters."""
        action = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",
            intention="Enter facility name",
            field_name="facility_name",
        )

        assert self._resolve_code_path(action) == "context.parameters"
        assert action.get("field_name") == "facility_name"

    def test_demonstrates_race_condition_consequence(self) -> None:
        """
        End-to-end consequence of the race:

        1. generate_workflow_parameters_schema finds no INPUT_TEXT actions
        2. No field mappings are generated
        3. Actions don't get field_name hydrated
        4. Generated script uses hardcoded values

        So the cached script CANNOT be reused with different parameters.
        """
        # First run: generation raced ahead, field_name missing — the emitted
        # code would be `value = "Urdaneta"` (hardcoded).
        early = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",  # This gets hardcoded
            field_name=None,  # Missing due to race condition
        )
        assert early.get("field_name") is None

        # A later run with parameter "Pok Pok" would still type "Urdaneta" —
        # the wrong value — because the cached script hardcodes it.

        # Correct ordering: script generation ran after the action was saved,
        # so the emitted code would be value = context.parameters["facility_name"].
        proper = make_input_text_action(
            task_id="task-123",
            action_id="action-456",
            text="Urdaneta",
            field_name="facility_name",  # Present: generation ran after save
        )
        assert proper.get("field_name") is not None

        # Now a run with "Pok Pok" resolves via context.parameters and the
        # correct value is used.
|
||||
|
||||
|
||||
class TestSkipActionsWithoutData:
    """
    Tests for the smart finalize approach that skips actions without data.

    SKY-7653 mitigation without unnecessary cost:
    1. Skip data-less actions during mid-run generation (no bad field mappings)
    2. Set a context flag when anything was skipped (script_gen_had_incomplete_actions)
    3. At finalize, regenerate only when that flag is set

    Net effect: a first run hit by the race gets fixed at the end; runs with a
    complete script never regenerate.
    """

    @staticmethod
    def _collect_actions_with_data(candidates: list[dict]) -> list[dict]:
        """Replicate generate_workflow_parameters_schema's skip-empty filter."""
        kept: list[dict] = []
        for action in candidates:
            action_type = action.get("action_type", "")
            if action_type not in CUSTOM_FIELD_ACTIONS:
                continue

            # Each custom-field action type stores its payload under a
            # different key.
            value = ""
            if action_type == ActionType.INPUT_TEXT:
                value = action.get("text", "")
            elif action_type == ActionType.SELECT_OPTION:
                value = action.get("option", "")
            elif action_type == ActionType.UPLOAD_FILE:
                value = action.get("file_url", "")

            # Skip actions whose data has not been saved yet.
            if not value:
                continue

            kept.append(action)
        return kept

    def test_input_text_without_text_is_skipped(self) -> None:
        """INPUT_TEXT with empty text is excluded from field mapping."""
        incomplete = {
            "action_type": ActionType.INPUT_TEXT,
            "action_id": "action-456",
            "task_id": "task-123",
            "text": "",  # Empty - not yet saved
            "intention": "Enter facility name",
        }
        assert len(self._collect_actions_with_data([incomplete])) == 0

    def test_input_text_with_text_is_included(self) -> None:
        """INPUT_TEXT carrying text survives the filter."""
        complete = {
            "action_type": ActionType.INPUT_TEXT,
            "action_id": "action-456",
            "task_id": "task-123",
            "text": "Urdaneta",  # Has value
            "intention": "Enter facility name",
        }
        assert len(self._collect_actions_with_data([complete])) == 1

    def test_select_option_without_option_is_skipped(self) -> None:
        """SELECT_OPTION with an empty option is excluded."""
        incomplete = {
            "action_type": ActionType.SELECT_OPTION,
            "action_id": "action-789",
            "task_id": "task-123",
            "option": "",  # Empty - not yet saved
        }
        assert len(self._collect_actions_with_data([incomplete])) == 0

    def test_upload_file_without_file_url_is_skipped(self) -> None:
        """UPLOAD_FILE with an empty file_url is excluded."""
        incomplete = {
            "action_type": ActionType.UPLOAD_FILE,
            "action_id": "action-101",
            "task_id": "task-123",
            "file_url": "",  # Empty - not yet saved
        }
        assert len(self._collect_actions_with_data([incomplete])) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_skips_empty_actions_and_sets_flag(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Integration: data-less actions are skipped AND the context flag is set.

    Confirms both halves of the smart finalize approach:
    1. Incomplete actions never reach the LLM (prevents bad field mappings)
    2. script_gen_had_incomplete_actions is raised so finalize regenerates
    """
    # Context we can inspect for the flag afterwards.
    context = SkyvernContext()
    skyvern_context.set(context)

    # Track whether the (mocked) LLM is ever invoked.
    llm_called = False

    async def fake_llm(custom_field_actions):
        nonlocal llm_called
        llm_called = True
        return GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

    monkeypatch.setattr(gwp, "_generate_field_names_with_llm", fake_llm)

    task_id = "task-123"

    # An action with an empty value — simulates the race condition.
    actions_by_task = {
        task_id: [
            {
                "action_type": ActionType.INPUT_TEXT,
                "action_id": "action-456",
                "task_id": task_id,
                "text": "",  # Empty - not yet saved
                "intention": "Enter facility name",
            },
        ]
    }

    try:
        code, mappings = await generate_workflow_parameters_schema(actions_by_task)

        # The only action was skipped, so the LLM must not have run.
        assert not llm_called

        # Empty schema and no mappings.
        assert "pass" in code
        assert mappings == {}

        # The flag is what triggers regeneration at finalize.
        assert context.script_gen_had_incomplete_actions is True
    finally:
        skyvern_context.reset()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_workflow_parameters_schema_with_complete_actions_no_flag(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Integration: complete actions leave the context flag unset, so finalize
    will not regenerate — the cost-saving half of the smart finalize approach.
    """
    # Context we can inspect for the flag afterwards.
    context = SkyvernContext()
    skyvern_context.set(context)

    async def fake_llm(custom_field_actions):
        return GeneratedFieldMapping(
            field_mappings={"action_index_1": "facility_name"},
            schema_fields={"facility_name": {"type": "str", "description": "The facility name"}},
        )

    monkeypatch.setattr(gwp, "_generate_field_names_with_llm", fake_llm)

    task_id = "task-123"

    # An action with complete data — no race condition in play.
    actions_by_task = {
        task_id: [
            {
                "action_type": ActionType.INPUT_TEXT,
                "action_id": "action-456",
                "task_id": task_id,
                "text": "Urdaneta",  # Has value - complete
                "intention": "Enter facility name",
            },
        ]
    }

    try:
        code, _mappings = await generate_workflow_parameters_schema(actions_by_task)

        # Schema was generated from the complete action.
        assert "facility_name" in code

        # No skipped actions -> no regeneration needed at finalize.
        assert context.script_gen_had_incomplete_actions is False
    finally:
        skyvern_context.reset()
|
||||
224
tests/unit/test_script_skyvern_page.py
Normal file
224
tests/unit/test_script_skyvern_page.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Unit tests for ScriptSkyvernPage, specifically testing _wait_for_page_ready_before_action.
|
||||
|
||||
This test file exists to prevent regressions like the AttributeError bug where
|
||||
self._page was used instead of self.page (see PR #8425, SKY-7676).
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.core.script_generations.script_skyvern_page import ScriptSkyvernPage
|
||||
|
||||
|
||||
def create_mock_page():
    """Build a MagicMock standing in for a Playwright Page."""
    fake_page = MagicMock()
    fake_page.url = "https://example.com"
    # The Playwright Page base class expects these two attributes to exist.
    fake_page._loop = MagicMock()
    fake_page._impl_obj = fake_page
    return fake_page
|
||||
|
||||
|
||||
@pytest.fixture
def mock_scraped_page():
    """Fixture: MagicMock standing in for a ScrapedPage."""
    fake = MagicMock()
    # Mirror the real ScrapedPage's _browser_state attribute.
    fake._browser_state = MagicMock()
    return fake
|
||||
|
||||
|
||||
@pytest.fixture
def mock_ai():
    """Fixture: MagicMock standing in for a SkyvernPageAi."""
    return MagicMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_calls_skyvern_frame(mock_scraped_page, mock_ai):
    """
    _wait_for_page_ready_before_action must delegate to SkyvernFrame.

    Regression guard for the PR #8273 bug where self._page was used instead of
    self.page, raising AttributeError because SkyvernPage stores the
    Playwright page in self.page.
    """
    fake_page = create_mock_page()

    # Skip Playwright's real Page.__init__ machinery.
    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=fake_page,
            ai=mock_ai,
        )

        # Stub frame so we can observe how it is used.
        frame_stub = MagicMock()
        frame_stub.wait_for_page_ready = AsyncMock()

        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
            return_value=frame_stub,
        ) as create_instance_mock:
            await script_page._wait_for_page_ready_before_action()

            # Exactly one SkyvernFrame was created, with a frame argument.
            create_instance_mock.assert_called_once()
            kwargs = create_instance_mock.call_args.kwargs
            assert "frame" in kwargs, "create_instance should be called with frame argument"
            assert kwargs["frame"] is not None, "frame should not be None"

            # Readiness was awaited with the configured timeouts.
            frame_stub.wait_for_page_ready.assert_called_once_with(
                network_idle_timeout_ms=settings.PAGE_READY_NETWORK_IDLE_TIMEOUT_MS,
                loading_indicator_timeout_ms=settings.PAGE_READY_LOADING_INDICATOR_TIMEOUT_MS,
                dom_stable_ms=settings.PAGE_READY_DOM_STABLE_MS,
                dom_stability_timeout_ms=settings.PAGE_READY_DOM_STABILITY_TIMEOUT_MS,
            )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_handles_no_page(mock_scraped_page, mock_ai):
    """A None self.page must short-circuit before touching SkyvernFrame."""
    # Skip Playwright's real Page.__init__ machinery.
    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        # Construct normally, then simulate the page having been closed.
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=create_mock_page(),
            ai=mock_ai,
        )
        script_page.page = None

        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
        ) as create_instance_mock:
            # Early return: no exception raised.
            await script_page._wait_for_page_ready_before_action()

            # And no SkyvernFrame was ever created.
            create_instance_mock.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_catches_exceptions(mock_scraped_page, mock_ai):
    """
    Failures inside the readiness check are swallowed so they never block the
    action itself — the defensive behavior under test.
    """
    fake_page = create_mock_page()

    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=fake_page,
            ai=mock_ai,
        )

        # Make SkyvernFrame.create_instance blow up.
        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
            side_effect=Exception("Simulated page readiness error"),
        ):
            # Must not propagate the simulated failure.
            await script_page._wait_for_page_ready_before_action()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_page_ready_before_action_catches_wait_for_page_ready_exceptions(mock_scraped_page, mock_ai):
    """Exceptions raised by wait_for_page_ready itself are caught and logged."""
    fake_page = create_mock_page()

    with patch(
        "skyvern.core.script_generations.skyvern_page.Page.__init__",
        return_value=None,
    ):
        script_page = ScriptSkyvernPage(
            scraped_page=mock_scraped_page,
            page=fake_page,
            ai=mock_ai,
        )

        # Frame whose readiness wait times out.
        failing_frame = MagicMock()
        failing_frame.wait_for_page_ready = AsyncMock(side_effect=TimeoutError("Page never became idle"))

        with patch(
            "skyvern.core.script_generations.script_skyvern_page.SkyvernFrame.create_instance",
            new_callable=AsyncMock,
            return_value=failing_frame,
        ):
            # Must not propagate the timeout.
            await script_page._wait_for_page_ready_before_action()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_for_page_ready_attribute_access_regression():
    """
    Regression: the method must access self.page, never self._page.

    The original bug (fixed in PR #8425) used self._page, which raised:
    AttributeError: 'ScriptSkyvernPage' object has no attribute '_page'

    The method's source is inspected directly so the bad spelling cannot
    sneak back in.
    """
    source = inspect.getsource(ScriptSkyvernPage._wait_for_page_ready_before_action)

    # The fixed code uses self.page.
    assert "self.page" in source, "Method should access self.page"

    # Strip docstrings, then comments, so only executable text remains;
    # self._page is tolerated in comments explaining the fix.
    stripped = re.sub(r'""".*?"""', "", source, flags=re.DOTALL)
    stripped = re.sub(r"'''.*?'''", "", stripped, flags=re.DOTALL)
    stripped = re.sub(r"#.*$", "", stripped, flags=re.MULTILINE)

    code_only = "\n".join(
        line for line in stripped.split("\n") if line.strip() and not line.strip().startswith("#")
    )

    # The bug pattern must not appear in the remaining code.
    if "self._page" in code_only:
        # Locate the offending line for a clearer failure message.
        for lineno, line in enumerate(source.split("\n"), 1):
            if "self._page" in line and not line.strip().startswith("#"):
                pytest.fail(
                    f"Found 'self._page' in code at line {lineno}: {line.strip()}\n"
                    "This is a regression! SkyvernPage uses self.page, not self._page."
                )
|
||||
31
tests/unit/test_secret_credentials.py
Normal file
31
tests/unit/test_secret_credentials.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.schemas.credentials import (
|
||||
CreateCredentialRequest,
|
||||
CredentialType,
|
||||
SecretCredential,
|
||||
)
|
||||
|
||||
|
||||
class TestSecretCredentialModels:
    """Model-level checks for SecretCredential and its create-request wrapper."""

    def test_secret_credential_creation(self) -> None:
        secret = SecretCredential(secret_value="sk-abc123", secret_label="API Key")
        assert secret.secret_value == "sk-abc123"
        assert secret.secret_label == "API Key"

    def test_secret_credential_optional_type(self) -> None:
        # The label is optional and defaults to None.
        assert SecretCredential(secret_value="token123").secret_label is None

    def test_non_empty_validation(self) -> None:
        # An empty secret value must be rejected by validation.
        with pytest.raises(ValueError):
            SecretCredential(secret_value="")

    def test_create_request_with_secret(self) -> None:
        request = CreateCredentialRequest(
            name="My API Key",
            credential_type=CredentialType.SECRET,
            credential=SecretCredential(secret_value="sk-12345"),
        )
        assert request.credential_type == CredentialType.SECRET
        assert request.credential.secret_value == "sk-12345"
|
||||
13
tests/unit/test_security.py
Normal file
13
tests/unit/test_security.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from skyvern.forge.sdk.core.security import create_access_token, generate_skyvern_webhook_signature
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping test_generate_skyvern_signature")
@freeze_time("2023-11-30 00:00:00")
def test_generate_skyvern_signature() -> None:
    """Pin the webhook HMAC signature for a fixed payload, API key, and clock."""
    token = create_access_token("o_12345")
    body = {"task_id": "t_12345", "float": 1.0}
    signed = generate_skyvern_webhook_signature(body, token)
    assert signed.signature == "1fac4204e1abc7cb0bdf1a42eb17d27f6f1feba065d5726777d5eb77581298c1"
|
||||
210
tests/unit/test_text_prompt_block.py
Normal file
210
tests/unit/test_text_prompt_block.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.workflow.models.block import TextPromptBlock
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
|
||||
|
||||
block_module = sys.modules["skyvern.forge.sdk.workflow.models.block"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("model_name", "expected_llm_key"),
    [
        ("gemini-2.5-flash", "VERTEX_GEMINI_2.5_FLASH"),
        ("gemini-3-pro-preview", "VERTEX_GEMINI_3.0_PRO"),
    ],
)
async def test_text_prompt_block_uses_selected_model(monkeypatch, model_name, expected_llm_key):
    """A block-level model override must resolve to the matching LLM key and be
    routed through the override handler (not the default handler)."""
    now = datetime.now(timezone.utc)
    output_parameter = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output",
        description=None,
        output_parameter_id="output-1",
        workflow_id="workflow-1",
        created_at=now,
        modified_at=now,
        deleted_at=None,
    )

    block = TextPromptBlock(
        label="text-block",
        llm_key="AZURE_OPENAI_GPT4_1",
        prompt="Explain the status.",
        parameters=[],
        json_schema=None,
        output_parameter=output_parameter,
        model={"model_name": model_name},  # per-block model override under test
    )

    captured: dict[str, str] = {}
    fake_default_handler = AsyncMock()

    async def fake_resolve_default_llm_handler(*args, **kwargs):
        return fake_default_handler

    async def fake_handler(*, prompt: str, prompt_name: str, **kwargs):
        # Record what the block actually sends to the LLM handler.
        captured["prompt"] = prompt
        captured["prompt_name"] = prompt_name
        return {"llm_response": "ok"}

    def fake_get_override_handler(llm_key: str | None, *, default):
        captured["llm_key"] = llm_key if llm_key else "default"
        return fake_handler if llm_key else default

    # FIX: use monkeypatch instead of direct assignment so the global app
    # handler is restored after the test and cannot leak into other tests.
    monkeypatch.setattr(block_module.app, "LLM_API_HANDLER", fake_default_handler, raising=False)
    LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        LLMAPIHandlerFactory,
        "get_override_llm_api_handler",
        fake_get_override_handler,
        raising=False,
    )
    monkeypatch.setattr(
        TextPromptBlock,
        "_resolve_default_llm_handler",
        fake_resolve_default_llm_handler,
        raising=False,
    )
    # Render prompt templates verbatim so assertions stay deterministic.
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )

    response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")

    assert captured["llm_key"] == expected_llm_key
    assert captured["prompt_name"] == "text-prompt"
    assert response == {"llm_response": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_prompt_block_uses_workflow_handler_when_no_override(monkeypatch):
    """Without an llm_key/model override, the block should fall back to the
    secondary (workflow) LLM handler."""
    now = datetime.now(timezone.utc)
    output_parameter = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output",
        description=None,
        output_parameter_id="output-2",
        workflow_id="workflow-1",
        created_at=now,
        modified_at=now,
        deleted_at=None,
    )

    block = TextPromptBlock(
        label="text-block",
        llm_key=None,
        prompt="Summarize status.",
        parameters=[],
        json_schema=None,
        output_parameter=output_parameter,
        model=None,
    )

    captured: dict[str, str] = {}
    fake_secondary_handler = AsyncMock(return_value={"llm_response": "secondary"})

    async def fake_prompt_type_handler(*args, **kwargs):
        # No prompt-type-specific handler configured.
        return None

    def fake_get_override_handler(llm_key: str | None, *, default):
        captured["llm_key"] = llm_key if llm_key else "default"
        captured["default_handler"] = default
        return default

    # FIX: use monkeypatch instead of direct assignment so the global app
    # handlers are restored after the test and cannot leak into other tests.
    monkeypatch.setattr(block_module.app, "SECONDARY_LLM_API_HANDLER", fake_secondary_handler, raising=False)
    monkeypatch.setattr(block_module.app, "LLM_API_HANDLER", AsyncMock(), raising=False)
    LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        LLMAPIHandlerFactory,
        "get_override_llm_api_handler",
        fake_get_override_handler,
        raising=False,
    )
    monkeypatch.setattr(
        block_module,
        "get_llm_handler_for_prompt_type",
        fake_prompt_type_handler,
        raising=False,
    )
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )

    response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")

    assert captured["llm_key"] == "default"
    assert captured["default_handler"] == fake_secondary_handler
    fake_secondary_handler.assert_awaited_once()
    assert response == {"llm_response": "secondary"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_prompt_block_prefers_prompt_type_config_over_secondary(monkeypatch):
    """A prompt-type-specific handler must win over the secondary workflow handler."""
    now = datetime.now(timezone.utc)
    output_parameter = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output",
        description=None,
        output_parameter_id="output-3",
        workflow_id="workflow-1",
        created_at=now,
        modified_at=now,
        deleted_at=None,
    )

    block = TextPromptBlock(
        label="text-block",
        llm_key=None,
        prompt="Provide summary.",
        parameters=[],
        json_schema=None,
        output_parameter=output_parameter,
        model=None,
    )

    captured: dict[str, str] = {}
    prompt_config_handler = AsyncMock(return_value={"llm_response": "config"})

    async def fake_prompt_type_handler(*args, **kwargs):
        # Prompt-type config resolves to a dedicated handler.
        return prompt_config_handler

    def fake_get_override_handler(llm_key: str | None, *, default):
        captured["default_handler"] = default
        return default

    # FIX: use monkeypatch instead of direct assignment so the global app
    # handlers are restored after the test and cannot leak into other tests.
    monkeypatch.setattr(block_module.app, "SECONDARY_LLM_API_HANDLER", AsyncMock(), raising=False)
    monkeypatch.setattr(block_module.app, "LLM_API_HANDLER", AsyncMock(), raising=False)
    LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        LLMAPIHandlerFactory,
        "get_override_llm_api_handler",
        fake_get_override_handler,
        raising=False,
    )
    monkeypatch.setattr(
        block_module,
        "get_llm_handler_for_prompt_type",
        fake_prompt_type_handler,
        raising=False,
    )
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )

    response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")

    assert captured["default_handler"] == prompt_config_handler
    prompt_config_handler.assert_awaited_once()
    assert response == {"llm_response": "config"}
|
||||
71
tests/unit/test_totp_identifier_fallback.py
Normal file
71
tests/unit/test_totp_identifier_fallback.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.schemas.credentials import CredentialVaultType
|
||||
from skyvern.forge.sdk.workflow import context_manager as cm
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.block import TaskV2Block
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_register_credential_parameter_uses_db_totp_identifier(monkeypatch: pytest.MonkeyPatch) -> None:
    """Registering a credential parameter should fall back to the totp_identifier
    stored on the DB credential row when the vault credential carries none."""
    # DB-side credential row; its totp_identifier is the value expected to win.
    db_credential = SimpleNamespace(
        credential_id="cred-1",
        organization_id="org-1",
        vault_type=CredentialVaultType.BITWARDEN,
        totp_identifier="user@example.com",
    )

    class FakeCredential:
        # Vault-side credential with no TOTP data, forcing the DB fallback path.
        def __init__(self) -> None:
            self.totp_identifier = None
            self.totp = None

        def model_dump(self) -> dict:
            return {}

    class FakeCredentialItem:
        def __init__(self) -> None:
            self.credential = FakeCredential()

    class FakeCredentialService:
        async def get_credential_item(self, _db_credential: object) -> FakeCredentialItem:
            return FakeCredentialItem()

    class FakeDatabase:
        async def get_credential(self, credential_id: str, organization_id: str) -> object:
            # Verify the context looks up exactly the registered credential.
            assert credential_id == "cred-1"
            assert organization_id == "org-1"
            return db_credential

    # Swap the global app singleton for a minimal fake covering DB + vault lookup.
    fake_app = SimpleNamespace(
        DATABASE=FakeDatabase(),
        CREDENTIAL_VAULT_SERVICES={CredentialVaultType.BITWARDEN: FakeCredentialService()},
    )
    monkeypatch.setattr(cm, "app", fake_app)

    context = WorkflowRunContext(
        workflow_title="title",
        workflow_id="wf-1",
        workflow_permanent_id="wfp-1",
        workflow_run_id="wr-1",
        aws_client=SimpleNamespace(),
    )

    parameter = SimpleNamespace(key="credential_param")
    organization = SimpleNamespace(organization_id="org-1")

    await context._register_credential_parameter_value("cred-1", parameter, organization)

    # The DB row's totp_identifier must be registered under the parameter key.
    assert context.get_credential_totp_identifier("credential_param") == "user@example.com"
|
||||
|
||||
|
||||
def test_task_v2_block_resolves_totp_identifier_from_context() -> None:
    """An explicit block-level totp_identifier wins; otherwise the context mapping is used."""
    ctx = SimpleNamespace(credential_totp_identifiers={"credential_param": "user@example.com"})

    # No identifier on the block -> fall back to the context mapping.
    implicit = TaskV2Block.model_construct(totp_identifier=None)
    assert implicit._resolve_totp_identifier(ctx) == "user@example.com"

    # Identifier set on the block -> takes precedence over the context.
    explicit = TaskV2Block.model_construct(totp_identifier="provided@example.com")
    assert explicit._resolve_totp_identifier(ctx) == "provided@example.com"
|
||||
29
tests/unit/test_url_validators.py
Normal file
29
tests/unit/test_url_validators.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from skyvern.utils.url_validators import encode_url
|
||||
|
||||
|
||||
def test_encode_url_basic():
    """Spaces in a simple path are percent-encoded."""
    assert encode_url("https://example.com/path with spaces") == "https://example.com/path%20with%20spaces"
|
||||
|
||||
|
||||
def test_encode_url_with_query_params():
    """Spaces inside query-parameter values are percent-encoded."""
    assert (
        encode_url("https://example.com/search?q=hello world&type=test")
        == "https://example.com/search?q=hello%20world&type=test"
    )
|
||||
|
||||
|
||||
def test_encode_url_with_special_chars():
    """URL-significant characters (#, ?, &, @) pass through unchanged."""
    url = "https://example.com/path/with/special#chars?param=value&other=test@123"
    assert encode_url(url) == url
|
||||
|
||||
|
||||
def test_encode_url_with_pre_encoded_chars():
    """Already percent-encoded sequences are preserved, not double-encoded."""
    assert (
        encode_url("https://example.com/search?q=hello world&type=test%20test")
        == "https://example.com/search?q=hello%20world&type=test%20test"
    )
|
||||
48
tests/unit/test_utils_templating.py
Normal file
48
tests/unit/test_utils_templating.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.utils.templating import Constants, get_missing_variables
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("template", "data", "expected"),
    [
        # no variables / flat lookups
        ("", {}, set()),
        ("Hello {{ name }}", {"name": "World"}, set()),
        ("Hello {{ name }}", {"age": 30}, {"name"}),
        ("{{ one }}", {"one": 1, "two": 2}, set()),  # extra vars allowed
        # nested (dotted) variables
        ("{{ user.name }}", {"user": {"name": "Alice"}}, set()),
        ("{{ user.name }}", {"user": {"age": 30}}, {"user.name"}),
        # list access
        ("{{ items[0] }}", {}, {"items"}),
        ("{{ items[0] }}", {"items": [1, 2, 3]}, set()),
        ("{{ items[0] }}", {"items": []}, {"items[0]"}),
        # deeply nested lists and dicts
        ("{{ data.users[0].name }}", {"data": {"users": [{"name": "Bob"}]}}, set()),
        ("{{ data.users[0].name }}", {"data": {"users": [{}]}}, {"data.users[0].name"}),
        ("{{ data.users[0].name }}", {"data": {}}, {"data.users[0].name"}),
    ],
)
def test_get_missing_variables(template, data, expected):
    """get_missing_variables reports exactly the template references absent from data."""
    assert get_missing_variables(template, data) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("template", "expected"),
    [
        ("{{ var }}", {"var"}),
        ("{{ var.attr }}", {"var.attr"}),
        ("{{ var[0] }}", {"var[0]"}),
        ("{{ var['key'] }}", {"var['key']"}),
        ('{{ var["key"] }}', {'var["key"]'}),
        ("{{ var.attr[0] }}", {"var.attr[0]"}),
        ("No variables here", set()),
        ("{{ var1 }} and {{ var2.attr }}", {"var1", "var2.attr"}),
    ],
)
def test_regex_missing_variable_pattern(template, expected):
    """MissingVariablePattern extracts full dotted/indexed variable references."""
    assert set(re.findall(Constants.MissingVariablePattern, template)) == expected
|
||||
176
tests/unit/test_vertex_cache_model_extraction.py
Normal file
176
tests/unit/test_vertex_cache_model_extraction.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Tests for Vertex AI cache model name extraction from LLMRouterConfig.
|
||||
|
||||
This tests the fix for the issue where GEMINI_3_0_FLASH_WITH_FALLBACK was
|
||||
incorrectly using 'gemini-3.0-flash' instead of 'gemini-3-flash-preview'.
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class MockLLMRouterModelConfig:
    """Minimal stand-in for LLMRouterModelConfig: one entry in a router's model_list."""

    # Group identifier; matched against MockLLMRouterConfig.main_model_group.
    model_name: str
    # litellm parameters; the "model" key holds the actual provider model id
    # (e.g. "vertex_ai/gemini-3-flash-preview").
    litellm_params: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockLLMRouterConfig:
|
||||
model_name: str
|
||||
model_list: list
|
||||
main_model_group: str
|
||||
required_env_vars: list = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.required_env_vars is None:
|
||||
self.required_env_vars = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockLLMConfig:
|
||||
model_name: str
|
||||
required_env_vars: list = None
|
||||
litellm_params: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.required_env_vars is None:
|
||||
self.required_env_vars = []
|
||||
|
||||
|
||||
class TestVertexCacheModelExtraction:
|
||||
"""Test that model names are correctly extracted for Vertex AI caching."""
|
||||
|
||||
def _extract_model_name(self, llm_config, resolved_llm_key: str) -> str:
|
||||
"""
|
||||
Mimics the model name extraction logic from _create_vertex_cache_for_task.
|
||||
"""
|
||||
model_name = "gemini-2.5-flash" # Default
|
||||
extracted_name = None
|
||||
|
||||
# For router configs (LLMRouterConfig), extract from model_list primary model FIRST
|
||||
# This must be checked before model_name since router model_name is just an identifier
|
||||
# (e.g., "gemini-3.0-flash-gpt-5-mini-fallback-router"), not an actual Vertex model
|
||||
if hasattr(llm_config, "model_list") and hasattr(llm_config, "main_model_group"):
|
||||
# Find the primary model in model_list by matching main_model_group
|
||||
for model_entry in llm_config.model_list:
|
||||
if model_entry.model_name == llm_config.main_model_group:
|
||||
# Extract actual model name from litellm_params
|
||||
model_param = model_entry.litellm_params.get("model", "")
|
||||
if "vertex_ai/" in model_param:
|
||||
extracted_name = model_param.split("/")[-1]
|
||||
elif model_param.startswith("gemini-"):
|
||||
extracted_name = model_param
|
||||
break
|
||||
|
||||
# Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-"
|
||||
if not extracted_name and hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str):
|
||||
if "vertex_ai/" in llm_config.model_name:
|
||||
# Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
|
||||
extracted_name = llm_config.model_name.split("/")[-1]
|
||||
elif llm_config.model_name.startswith("gemini-"):
|
||||
# Already in correct format
|
||||
extracted_name = llm_config.model_name
|
||||
|
||||
# For router/fallback configs, extract from api_base or infer from key name
|
||||
if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params:
|
||||
params = llm_config.litellm_params
|
||||
api_base = params.get("api_base") if isinstance(params, dict) else getattr(params, "api_base", None)
|
||||
if api_base and isinstance(api_base, str) and "/models/" in api_base:
|
||||
# Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash"
|
||||
extracted_name = api_base.split("/models/")[-1]
|
||||
|
||||
# For router configs without api_base, infer from the llm_key itself
|
||||
if not extracted_name:
|
||||
# Extract version from llm_key
|
||||
version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
|
||||
version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
|
||||
|
||||
# Determine flavor
|
||||
if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
|
||||
extracted_name = f"gemini-{version}-pro"
|
||||
elif "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
|
||||
extracted_name = f"gemini-{version}-flash-lite"
|
||||
else:
|
||||
# Default to flash flavor
|
||||
extracted_name = f"gemini-{version}-flash"
|
||||
|
||||
if extracted_name:
|
||||
model_name = extracted_name
|
||||
|
||||
# Normalize model name to the canonical Vertex identifier
|
||||
# Preserve preview suffixes so we don't strip required identifiers (e.g., gemini-3-flash-preview).
|
||||
match = re.search(r"(gemini-\d+(?:\.\d+)?-(?:flash-lite|flash|pro)(?:-preview)?)", model_name, re.IGNORECASE)
|
||||
if match:
|
||||
model_name = match.group(1).lower()
|
||||
|
||||
return model_name
|
||||
|
||||
def test_router_config_extracts_gemini_3_flash_preview(self):
|
||||
"""
|
||||
GEMINI_3_0_FLASH_WITH_FALLBACK should extract 'gemini-3-flash-preview',
|
||||
NOT 'gemini-3.0-flash'.
|
||||
"""
|
||||
# Create a mock router config that matches the real GEMINI_3_0_FLASH_WITH_FALLBACK
|
||||
router_config = MockLLMRouterConfig(
|
||||
model_name="gemini-3.0-flash-gpt-5-mini-fallback-router",
|
||||
model_list=[
|
||||
MockLLMRouterModelConfig(
|
||||
model_name="vertex-gemini-3.0-flash",
|
||||
litellm_params={"model": "vertex_ai/gemini-3-flash-preview"},
|
||||
),
|
||||
MockLLMRouterModelConfig(
|
||||
model_name="gpt-5-mini-fallback",
|
||||
litellm_params={"model": "gpt-5-mini-2025-08-07"},
|
||||
),
|
||||
],
|
||||
main_model_group="vertex-gemini-3.0-flash",
|
||||
)
|
||||
|
||||
model_name = self._extract_model_name(router_config, "GEMINI_3_0_FLASH_WITH_FALLBACK")
|
||||
|
||||
# Should extract the correct model name with -preview suffix
|
||||
assert model_name == "gemini-3-flash-preview", (
|
||||
f"Expected 'gemini-3-flash-preview' but got '{model_name}'. "
|
||||
"The router config should extract from model_list, not infer from llm_key."
|
||||
)
|
||||
|
||||
def test_direct_vertex_config_extracts_correctly(self):
|
||||
"""Direct VERTEX_GEMINI_3.0_FLASH should extract correctly."""
|
||||
direct_config = MockLLMConfig(
|
||||
model_name="vertex_ai/gemini-3-flash-preview",
|
||||
)
|
||||
|
||||
model_name = self._extract_model_name(direct_config, "VERTEX_GEMINI_3.0_FLASH")
|
||||
assert model_name == "gemini-3-flash-preview"
|
||||
|
||||
def test_router_config_extracts_gemini_2_5_flash(self):
|
||||
"""GEMINI_2_5_FLASH_WITH_FALLBACK should extract 'gemini-2.5-flash'."""
|
||||
router_config = MockLLMRouterConfig(
|
||||
model_name="gemini-2.5-flash-gpt-5-mini-fallback-router",
|
||||
model_list=[
|
||||
MockLLMRouterModelConfig(
|
||||
model_name="vertex-gemini-2.5-flash",
|
||||
litellm_params={"model": "vertex_ai/gemini-2.5-flash"},
|
||||
),
|
||||
MockLLMRouterModelConfig(
|
||||
model_name="gpt-5-mini-fallback",
|
||||
litellm_params={"model": "gpt-5-mini-2025-08-07"},
|
||||
),
|
||||
],
|
||||
main_model_group="vertex-gemini-2.5-flash",
|
||||
)
|
||||
|
||||
model_name = self._extract_model_name(router_config, "GEMINI_2_5_FLASH_WITH_FALLBACK")
|
||||
assert model_name == "gemini-2.5-flash"
|
||||
|
||||
def test_fallback_to_llm_key_inference_when_no_model_list(self):
|
||||
"""When there's no model_list, should fall back to llm_key inference."""
|
||||
# A config that doesn't have model_list (not a router config)
|
||||
simple_config = MockLLMConfig(
|
||||
model_name="some-unrelated-name",
|
||||
)
|
||||
|
||||
model_name = self._extract_model_name(simple_config, "GEMINI_2_5_FLASH")
|
||||
# Should fall back to inference from llm_key
|
||||
assert model_name == "gemini-2.5-flash"
|
||||
627
tests/unit/test_workflow_parameter_validation.py
Normal file
627
tests/unit/test_workflow_parameter_validation.py
Normal file
@@ -0,0 +1,627 @@
|
||||
"""Tests for workflow parameter key and block label validation.
|
||||
|
||||
These tests ensure that parameter keys and block labels are valid Python/Jinja2 identifiers,
|
||||
preventing runtime errors like "'State_' is undefined" when using keys like "State_/_Province".
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType
|
||||
from skyvern.schemas.workflows import (
|
||||
TaskBlockYAML,
|
||||
WorkflowParameterYAML,
|
||||
sanitize_block_label,
|
||||
sanitize_parameter_key,
|
||||
sanitize_workflow_yaml_with_references,
|
||||
)
|
||||
from skyvern.utils.templating import replace_jinja_reference
|
||||
|
||||
|
||||
class TestParameterKeyValidation:
    """Tests for parameter key validation."""

    @staticmethod
    def _build(key: str) -> WorkflowParameterYAML:
        """Construct a STRING-typed workflow parameter with the given key."""
        return WorkflowParameterYAML(key=key, workflow_parameter_type=WorkflowParameterType.STRING)

    def _assert_rejected(self, key: str, message: str) -> None:
        """Assert that building a parameter with ``key`` fails, mentioning ``message``."""
        with pytest.raises(ValidationError) as exc_info:
            self._build(key)
        assert message in str(exc_info.value)

    def test_valid_parameter_key_simple(self) -> None:
        """Simple snake_case keys are accepted."""
        assert self._build("my_parameter").key == "my_parameter"

    def test_valid_parameter_key_with_numbers(self) -> None:
        """Digits are allowed anywhere but the first character."""
        assert self._build("param123").key == "param123"

    def test_valid_parameter_key_underscore_prefix(self) -> None:
        """A leading underscore is a valid identifier start."""
        assert self._build("_private_param").key == "_private_param"

    def test_valid_parameter_key_single_letter(self) -> None:
        """A single letter is a valid key."""
        assert self._build("x").key == "x"

    def test_invalid_parameter_key_with_slash(self) -> None:
        """'/' in a key is rejected (the main bug case from SKY-7356)."""
        self._assert_rejected("State_/_Province", "not a valid parameter name")

    def test_invalid_parameter_key_with_hyphen(self) -> None:
        """'-' in a key is rejected."""
        self._assert_rejected("state-or-province", "not a valid parameter name")

    def test_invalid_parameter_key_with_dot(self) -> None:
        """'.' in a key is rejected."""
        self._assert_rejected("some.property", "not a valid parameter name")

    def test_invalid_parameter_key_starts_with_digit(self) -> None:
        """A leading digit is rejected."""
        self._assert_rejected("123param", "not a valid parameter name")

    def test_invalid_parameter_key_with_space(self) -> None:
        """Spaces are rejected with a whitespace-specific message."""
        self._assert_rejected("my parameter", "whitespace")

    def test_invalid_parameter_key_with_tab(self) -> None:
        """Tabs are rejected with a whitespace-specific message."""
        self._assert_rejected("my\tparameter", "whitespace")

    def test_invalid_parameter_key_with_asterisk(self) -> None:
        """'*' in a key is rejected."""
        self._assert_rejected("param*value", "not a valid parameter name")
|
||||
|
||||
|
||||
class TestBlockLabelValidation:
    """Tests for block label validation."""

    @staticmethod
    def _build(label: str) -> TaskBlockYAML:
        """Construct a task block with the given label and a fixed URL."""
        return TaskBlockYAML(label=label, url="https://example.com")

    def _assert_rejected(self, label: str, message: str) -> None:
        """Assert that building a block with ``label`` fails, mentioning ``message``."""
        with pytest.raises(ValidationError) as exc_info:
            self._build(label)
        assert message in str(exc_info.value)

    def test_valid_block_label_simple(self) -> None:
        """Simple snake_case labels are accepted."""
        assert self._build("my_task").label == "my_task"

    def test_valid_block_label_with_numbers(self) -> None:
        """Digits are allowed anywhere but the first character."""
        assert self._build("task123").label == "task123"

    def test_valid_block_label_underscore_prefix(self) -> None:
        """A leading underscore is a valid identifier start."""
        assert self._build("_private_task").label == "_private_task"

    def test_invalid_block_label_with_slash(self) -> None:
        """'/' in a label is rejected."""
        self._assert_rejected("task/block", "not a valid label")

    def test_invalid_block_label_with_hyphen(self) -> None:
        """'-' in a label is rejected."""
        self._assert_rejected("task-block", "not a valid label")

    def test_invalid_block_label_starts_with_digit(self) -> None:
        """A leading digit is rejected."""
        self._assert_rejected("123task", "not a valid label")

    def test_invalid_block_label_empty(self) -> None:
        """An empty label is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            self._build("")
        assert "empty" in str(exc_info.value).lower()

    def test_invalid_block_label_whitespace_only(self) -> None:
        """A whitespace-only label is rejected as empty."""
        with pytest.raises(ValidationError) as exc_info:
            self._build("   ")
        assert "empty" in str(exc_info.value).lower()

    def test_invalid_block_label_with_space(self) -> None:
        """Interior spaces are rejected."""
        self._assert_rejected("my task", "not a valid label")
|
||||
|
||||
class TestSanitizeBlockLabel:
    """Behaviour of sanitize_block_label for each class of raw input."""

    def test_sanitize_slash(self) -> None:
        """'/' is mapped to '_'."""
        sanitized = sanitize_block_label("State/Province")
        assert sanitized == "State_Province"

    def test_sanitize_hyphen(self) -> None:
        """'-' is mapped to '_'."""
        sanitized = sanitize_block_label("my-block")
        assert sanitized == "my_block"

    def test_sanitize_dot(self) -> None:
        """'.' is mapped to '_'."""
        sanitized = sanitize_block_label("block.name")
        assert sanitized == "block_name"

    def test_sanitize_multiple_special_chars(self) -> None:
        """Mixed runs of separators collapse to a single underscore."""
        sanitized = sanitize_block_label("State_/_Province")
        assert sanitized == "State_Province"

    def test_sanitize_consecutive_underscores(self) -> None:
        """Repeated underscores collapse to one."""
        sanitized = sanitize_block_label("a__b___c")
        assert sanitized == "a_b_c"

    def test_sanitize_leading_trailing_underscores(self) -> None:
        """Leading and trailing underscores are stripped."""
        sanitized = sanitize_block_label("_my_block_")
        assert sanitized == "my_block"

    def test_sanitize_digit_prefix(self) -> None:
        """A digit-initial result gains an underscore prefix."""
        sanitized = sanitize_block_label("123abc")
        assert sanitized == "_123abc"

    def test_sanitize_digit_prefix_after_strip(self) -> None:
        """The prefix is applied after underscore stripping, so output is stable."""
        sanitized = sanitize_block_label("_123abc")
        assert sanitized == "_123abc"

    def test_sanitize_all_invalid_chars(self) -> None:
        """Inputs with no salvageable characters fall back to 'block'."""
        sanitized = sanitize_block_label("///")
        assert sanitized == "block"

    def test_sanitize_empty_string(self) -> None:
        """The empty string falls back to 'block'."""
        sanitized = sanitize_block_label("")
        assert sanitized == "block"

    def test_sanitize_valid_label_unchanged(self) -> None:
        """Already-valid labels pass through untouched."""
        sanitized = sanitize_block_label("my_valid_label")
        assert sanitized == "my_valid_label"

    def test_sanitize_spaces(self) -> None:
        """Spaces are mapped to underscores."""
        sanitized = sanitize_block_label("my block name")
        assert sanitized == "my_block_name"
class TestSanitizeParameterKey:
    """Behaviour of sanitize_parameter_key for each class of raw input."""

    def test_sanitize_slash(self) -> None:
        """'/' is mapped to '_'."""
        sanitized = sanitize_parameter_key("State/Province")
        assert sanitized == "State_Province"

    def test_sanitize_hyphen(self) -> None:
        """'-' is mapped to '_'."""
        sanitized = sanitize_parameter_key("my-param")
        assert sanitized == "my_param"

    def test_sanitize_dot(self) -> None:
        """'.' is mapped to '_'."""
        sanitized = sanitize_parameter_key("param.name")
        assert sanitized == "param_name"

    def test_sanitize_all_invalid_chars(self) -> None:
        """Inputs with no salvageable characters fall back to 'parameter'."""
        sanitized = sanitize_parameter_key("///")
        assert sanitized == "parameter"

    def test_sanitize_empty_string(self) -> None:
        """The empty string falls back to 'parameter'."""
        sanitized = sanitize_parameter_key("")
        assert sanitized == "parameter"

    def test_sanitize_valid_key_unchanged(self) -> None:
        """Already-valid keys pass through untouched."""
        sanitized = sanitize_parameter_key("my_valid_key")
        assert sanitized == "my_valid_key"
class TestReplaceJinjaReference:
    """Behaviour of replace_jinja_reference when renaming 'old_key' to 'new_key'."""

    @staticmethod
    def _rename(text: str) -> str:
        # Every test renames the same key pair; centralise the call.
        return replace_jinja_reference(text, "old_key", "new_key")

    def test_replace_simple_reference(self) -> None:
        """A spaced {{ old_key }} reference is renamed."""
        assert self._rename("Value is {{ old_key }}") == "Value is {{ new_key }}"

    def test_replace_reference_no_spaces(self) -> None:
        """A {{old_key}} reference without spaces is renamed too."""
        assert self._rename("Value is {{old_key}}") == "Value is {{new_key}}"

    def test_replace_reference_with_attribute(self) -> None:
        """Attribute access after the key is preserved."""
        assert self._rename("Value is {{ old_key.field }}") == "Value is {{ new_key.field }}"

    def test_replace_reference_with_filter(self) -> None:
        """Jinja filters after the key are preserved."""
        assert self._rename("Value is {{ old_key | default('') }}") == "Value is {{ new_key | default('') }}"

    def test_replace_reference_with_index(self) -> None:
        """Index access after the key is preserved."""
        assert self._rename("Value is {{ old_key[0] }}") == "Value is {{ new_key[0] }}"

    def test_replace_multiple_references(self) -> None:
        """All occurrences in the text are renamed."""
        assert self._rename("{{ old_key }} and {{ old_key.field }}") == "{{ new_key }} and {{ new_key.field }}"

    def test_no_replace_partial_match(self) -> None:
        """Keys that merely start with the old name are left alone."""
        assert self._rename("{{ old_key_extended }}") == "{{ old_key_extended }}"

    def test_no_replace_different_key(self) -> None:
        """Unrelated keys are left alone."""
        assert self._rename("{{ other_key }}") == "{{ other_key }}"
class TestSanitizeWorkflowYamlWithReferences:
    """End-to-end behaviour of sanitize_workflow_yaml_with_references."""

    @staticmethod
    def _workflow(blocks: list | None = None, parameters: list | None = None, **definition_extras) -> dict:
        # Build a minimal workflow YAML dict around the given blocks/parameters.
        definition: dict = {
            "parameters": parameters if parameters is not None else [],
            "blocks": blocks if blocks is not None else [],
        }
        definition.update(definition_extras)
        return {"title": "Test Workflow", "workflow_definition": definition}

    def test_sanitize_simple_block_label(self) -> None:
        """A single invalid label is sanitized in place."""
        yaml_dict = self._workflow(blocks=[{"label": "State/Province", "block_type": "task"}])
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        assert result["workflow_definition"]["blocks"][0]["label"] == "State_Province"

    def test_sanitize_updates_output_references(self) -> None:
        """Renaming a label also rewrites {{ label_output }} references in later blocks."""
        yaml_dict = self._workflow(
            blocks=[
                {"label": "my-block", "block_type": "task"},
                {
                    "label": "second_block",
                    "block_type": "task",
                    "navigation_goal": "Use {{ my-block_output }} value",
                },
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        blocks = result["workflow_definition"]["blocks"]
        assert blocks[0]["label"] == "my_block"
        assert "{{ my_block_output }}" in blocks[1]["navigation_goal"]

    def test_sanitize_updates_next_block_label(self) -> None:
        """next_block_label pointers follow the renamed label."""
        yaml_dict = self._workflow(
            blocks=[
                {"label": "block-1", "block_type": "task", "next_block_label": "block-2"},
                {"label": "block-2", "block_type": "task"},
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        blocks = result["workflow_definition"]["blocks"]
        assert blocks[0]["label"] == "block_1"
        assert blocks[0]["next_block_label"] == "block_2"
        assert blocks[1]["label"] == "block_2"

    def test_sanitize_updates_finally_block_label(self) -> None:
        """finally_block_label pointers follow the renamed label."""
        yaml_dict = self._workflow(
            blocks=[{"label": "cleanup-block", "block_type": "task"}],
            finally_block_label="cleanup-block",
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        definition = result["workflow_definition"]
        assert definition["blocks"][0]["label"] == "cleanup_block"
        assert definition["finally_block_label"] == "cleanup_block"

    def test_sanitize_nested_loop_blocks(self) -> None:
        """Labels inside for_loop loop_blocks are sanitized recursively."""
        yaml_dict = self._workflow(
            blocks=[
                {
                    "label": "my_loop",
                    "block_type": "for_loop",
                    "loop_blocks": [{"label": "inner-block", "block_type": "task"}],
                }
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        assert result["workflow_definition"]["blocks"][0]["loop_blocks"][0]["label"] == "inner_block"

    def test_sanitize_no_changes_needed(self) -> None:
        """Already-valid labels are left untouched."""
        yaml_dict = self._workflow(blocks=[{"label": "valid_label", "block_type": "task"}])
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        assert result["workflow_definition"]["blocks"][0]["label"] == "valid_label"

    def test_sanitize_empty_workflow_definition(self) -> None:
        """A YAML dict without workflow_definition passes through unchanged."""
        yaml_dict = {"title": "Test Workflow"}
        assert sanitize_workflow_yaml_with_references(yaml_dict) == yaml_dict

    def test_sanitize_updates_parameter_references(self) -> None:
        """source_parameter_key references follow the renamed block label."""
        yaml_dict = self._workflow(
            parameters=[{"key": "my_param", "parameter_type": "context", "source_parameter_key": "block-1_output"}],
            blocks=[{"label": "block-1", "block_type": "task"}],
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        definition = result["workflow_definition"]
        assert definition["blocks"][0]["label"] == "block_1"
        assert definition["parameters"][0]["source_parameter_key"] == "block_1_output"

    def test_sanitize_parameter_key(self) -> None:
        """Parameter keys with invalid characters are sanitized."""
        yaml_dict = self._workflow(parameters=[{"key": "State/Province", "parameter_type": "workflow"}])
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        assert result["workflow_definition"]["parameters"][0]["key"] == "State_Province"

    def test_sanitize_parameter_key_updates_jinja_references(self) -> None:
        """Jinja references to a renamed parameter key are rewritten."""
        yaml_dict = self._workflow(
            parameters=[{"key": "user-input", "parameter_type": "workflow"}],
            blocks=[
                {"label": "my_task", "block_type": "task", "navigation_goal": "Enter {{ user-input }} in the form"}
            ],
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        definition = result["workflow_definition"]
        assert definition["parameters"][0]["key"] == "user_input"
        assert "{{ user_input }}" in definition["blocks"][0]["navigation_goal"]

    def test_sanitize_parameter_key_updates_parameter_keys_array(self) -> None:
        """parameter_keys lists inside blocks follow the renamed parameter key."""
        yaml_dict = self._workflow(
            parameters=[{"key": "my-param", "parameter_type": "workflow"}],
            blocks=[{"label": "my_task", "block_type": "task", "parameter_keys": ["my-param", "other_param"]}],
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        definition = result["workflow_definition"]
        assert definition["parameters"][0]["key"] == "my_param"
        assert definition["blocks"][0]["parameter_keys"] == ["my_param", "other_param"]

    def test_sanitize_both_labels_and_parameter_keys(self) -> None:
        """Label and parameter-key renames are applied together to Jinja references."""
        yaml_dict = self._workflow(
            parameters=[{"key": "user/input", "parameter_type": "workflow"}],
            blocks=[
                {
                    "label": "task-1",
                    "block_type": "task",
                    "navigation_goal": "Use {{ user/input }} and {{ task-1_output }}",
                }
            ],
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        definition = result["workflow_definition"]
        assert definition["parameters"][0]["key"] == "user_input"
        assert definition["blocks"][0]["label"] == "task_1"
        nav_goal = definition["blocks"][0]["navigation_goal"]
        assert "{{ user_input }}" in nav_goal
        assert "{{ task_1_output }}" in nav_goal

    def test_sanitize_block_label_collision(self) -> None:
        """Labels that sanitize to the same value are disambiguated with suffixes."""
        yaml_dict = self._workflow(
            blocks=[
                {"label": "state/province", "block_type": "task"},
                {"label": "state-province", "block_type": "task"},
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        labels = [b["label"] for b in result["workflow_definition"]["blocks"]]
        assert labels == ["state_province", "state_province_2"]
        # Disambiguation must keep labels unique.
        assert len(set(labels)) == len(labels)

    def test_sanitize_parameter_key_collision(self) -> None:
        """Keys that sanitize to the same value are disambiguated, references included."""
        yaml_dict = self._workflow(
            parameters=[
                {"key": "user/input", "parameter_type": "workflow"},
                {"key": "user-input", "parameter_type": "workflow"},
            ],
            blocks=[
                {
                    "label": "my_task",
                    "block_type": "task",
                    "navigation_goal": "Use {{ user/input }} and {{ user-input }}",
                }
            ],
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        keys = [p["key"] for p in result["workflow_definition"]["parameters"]]
        assert keys == ["user_input", "user_input_2"]
        # Each original reference must track its own renamed key.
        nav_goal = result["workflow_definition"]["blocks"][0]["navigation_goal"]
        assert "{{ user_input }}" in nav_goal
        assert "{{ user_input_2 }}" in nav_goal

    def test_sanitize_collision_with_existing_valid_label(self) -> None:
        """Sanitized labels never collide with labels that were already valid."""
        yaml_dict = self._workflow(
            blocks=[
                {"label": "my_block", "block_type": "task"},
                {"label": "my-block", "block_type": "task"},
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        labels = [b["label"] for b in result["workflow_definition"]["blocks"]]
        assert labels == ["my_block", "my_block_2"]

    def test_sanitize_shorthand_block_label_references(self) -> None:
        """Both {{ label }} shorthand and {{ label_output }} references are rewritten."""
        yaml_dict = self._workflow(
            blocks=[
                {
                    "label": "extract/block",
                    "block_type": "extraction",
                },
                {
                    "label": "send_block",
                    "block_type": "send_email",
                    # Exercises both the shorthand and the full _output reference forms.
                    "body": "Data: {{ extract/block.extracted_information }} and {{ extract/block_output.status }}",
                },
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        assert result["workflow_definition"]["blocks"][0]["label"] == "extract_block"
        body = result["workflow_definition"]["blocks"][1]["body"]
        assert "{{ extract_block.extracted_information }}" in body
        assert "{{ extract_block_output.status }}" in body

    def test_sanitize_label_shorthand_does_not_corrupt_output_ref(self) -> None:
        """Shorthand replacement must not mangle the corresponding _output reference.

        When 'block-1' becomes 'block_1', both {{ block-1 }} and
        {{ block-1_output }} must be rewritten independently, without the
        shorthand pass corrupting the _output form.
        """
        yaml_dict = self._workflow(
            blocks=[
                {
                    "label": "block-1",
                    "block_type": "task",
                    "navigation_goal": "{{ block-1 }} and {{ block-1_output }}",
                }
            ]
        )
        result = sanitize_workflow_yaml_with_references(yaml_dict)
        goal = result["workflow_definition"]["blocks"][0]["navigation_goal"]
        assert "{{ block_1 }}" in goal
        assert "{{ block_1_output }}" in goal
600
tests/unit/test_workflow_schema_field_preservation.py
Normal file
600
tests/unit/test_workflow_schema_field_preservation.py
Normal file
@@ -0,0 +1,600 @@
|
||||
"""
|
||||
Tests for workflow schema field name preservation (SKY-7434).
|
||||
|
||||
When a workflow is regenerated (e.g., after adding a new block), the LLM should
|
||||
preserve field names for unchanged blocks to prevent schema mismatches with
|
||||
cached block code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.core.script_generations.generate_script import (
|
||||
ScriptBlockSource,
|
||||
_build_existing_field_assignments,
|
||||
)
|
||||
from skyvern.core.script_generations.generate_workflow_parameters import (
|
||||
generate_workflow_parameters_schema,
|
||||
)
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from tests.unit.force_stub_app import start_forge_stub_app
|
||||
|
||||
# Load environment variables for real LLM tests
|
||||
load_dotenv()
|
||||
|
||||
# Check if real LLM tests should run (set RUN_LLM_TESTS=1 to enable)
|
||||
SKIP_LLM_TESTS = os.environ.get("RUN_LLM_TESTS", "0") != "1"
|
||||
|
||||
|
||||
class TestBuildExistingFieldAssignments:
    """Behaviour of _build_existing_field_assignments over cached block code."""

    @staticmethod
    def _cached(label: str, input_fields: list[str], code: str = "...") -> ScriptBlockSource:
        # Minimal ScriptBlockSource representing previously generated block code.
        return ScriptBlockSource(
            label=label,
            code=code,
            run_signature=None,
            workflow_run_id=None,
            workflow_run_block_id=None,
            input_fields=input_fields,
        )

    def test_returns_empty_dict_when_no_cached_blocks(self):
        """With no cached code at all there is nothing to preserve."""
        blocks = [{"block_type": "login", "label": "login_block", "task_id": "task_1"}]
        actions_by_task = {
            "task_1": [{"action_type": "input_text", "text": "john", "intention": "Enter username"}],
        }
        cached_blocks: dict[str, ScriptBlockSource] = {}
        updated_block_labels: set[str] = set()

        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, updated_block_labels)

        assert result == {}

    def test_returns_empty_dict_when_all_blocks_updated(self):
        """Blocks flagged as updated must not have their old field names preserved."""
        blocks = [{"block_type": "login", "label": "login_block", "task_id": "task_1"}]
        actions_by_task = {
            "task_1": [{"action_type": "input_text", "text": "john", "intention": "Enter username"}],
        }
        cached_blocks = {
            "login_block": self._cached("login_block", ["username"], code="async def login_block(): ..."),
        }

        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, {"login_block"})

        assert result == {}

    def test_preserves_field_names_for_unchanged_blocks(self):
        """Unchanged cached blocks map their input_fields onto actions in order."""
        blocks = [{"block_type": "login", "label": "login_block", "task_id": "task_1"}]
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
                {"action_type": "input_text", "text": "pass123", "intention": "Enter password"},
            ],
        }
        cached_blocks = {
            "login_block": self._cached(
                "login_block",
                ["user_full_name", "user_password"],
                code="async def login_block(): ...",
            ),
        }

        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, set())

        # Action indices are 1-based: action 1 -> first field, action 2 -> second.
        assert result == {1: "user_full_name", 2: "user_password"}

    def test_preserves_fields_for_multiple_unchanged_blocks(self):
        """Each unchanged block contributes its own fields at its own action indices."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
            {"block_type": "task", "label": "form_block", "task_id": "task_2"},
        ]
        actions_by_task = {
            "task_1": [{"action_type": "input_text", "text": "john", "intention": "Enter username"}],
            "task_2": [{"action_type": "input_text", "text": "Acme Inc", "intention": "Enter company"}],
        }
        cached_blocks = {
            "login_block": self._cached("login_block", ["username"]),
            "form_block": self._cached("form_block", ["company_name"]),
        }

        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, set())

        # Action 1 (task_1) -> username, action 2 (task_2) -> company_name.
        assert result == {1: "username", 2: "company_name"}

    def test_mixed_updated_and_unchanged_blocks(self):
        """New blocks (absent from the cache) get no preserved fields."""
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
            {"block_type": "task", "label": "new_block", "task_id": "task_2"},
        ]
        actions_by_task = {
            "task_1": [{"action_type": "input_text", "text": "john", "intention": "Enter username"}],
            "task_2": [{"action_type": "input_text", "text": "new value", "intention": "Enter something"}],
        }
        # new_block is intentionally absent from the cache.
        cached_blocks = {"login_block": self._cached("login_block", ["username"])}

        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, set())

        # Only action 1 is preserved; action 2 belongs to a new block.
        assert result == {1: "username"}

    def test_skips_non_custom_field_actions(self):
        """Actions other than INPUT_TEXT / UPLOAD_FILE / SELECT_OPTION are skipped."""
        blocks = [{"block_type": "login", "label": "login_block", "task_id": "task_1"}]
        actions_by_task = {
            "task_1": [
                {"action_type": "click", "intention": "Click button"},  # not a custom field action
                {"action_type": "input_text", "text": "john", "intention": "Enter username"},
            ],
        }
        cached_blocks = {"login_block": self._cached("login_block", ["username"])}

        result = _build_existing_field_assignments(blocks, actions_by_task, cached_blocks, set())

        # The click is skipped, so the input_text action is index 1.
        assert result == {1: "username"}
class TestGenerateWorkflowParametersSchemaWithExistingFields:
    """Schema generation must feed existing field names to the LLM and honour its mappings."""

    @pytest.fixture(autouse=True)
    def setup_stub_app(self):
        """Install the forge stub app before every test in this class."""
        self.stub_app = start_forge_stub_app()

    def _install_llm(self, response: dict, seen: dict | None = None) -> None:
        # Route the schema-generation LLM call to a canned response,
        # optionally recording the prompt for later inspection.
        async def handler(prompt, prompt_name):
            if seen is not None:
                seen["prompt"] = prompt
                seen["prompt_name"] = prompt_name
            return response

        self.stub_app.SCRIPT_GENERATION_LLM_API_HANDLER = AsyncMock(side_effect=handler)

    @pytest.mark.asyncio
    async def test_llm_receives_existing_field_names_in_prompt(self):
        """Existing field names appear in the prompt and come back in the mappings."""
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username", "action_id": "act_1"},
                {"action_type": "input_text", "text": "pass", "intention": "Enter password", "action_id": "act_2"},
            ],
            "task_2": [
                {"action_type": "input_text", "text": "new", "intention": "Enter new field", "action_id": "act_3"},
            ],
        }
        # Action 3 has no existing field and needs a freshly generated name.
        existing_field_assignments = {1: "preserved_username", 2: "preserved_password"}

        canned_response = {
            "field_mappings": {
                "action_index_1": "preserved_username",
                "action_index_2": "preserved_password",
                "action_index_3": "new_field_name",
            },
            "schema_fields": {
                "preserved_username": {"type": "str", "description": "Username"},
                "preserved_password": {"type": "str", "description": "Password"},
                "new_field_name": {"type": "str", "description": "New field"},
            },
        }
        seen: dict = {}
        self._install_llm(canned_response, seen)

        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )

        # The prompt must name the fields to preserve.
        prompt = seen["prompt"]
        assert "preserved_username" in prompt
        assert "preserved_password" in prompt
        assert "MUST PRESERVE" in prompt or "EXISTING FIELD NAME" in prompt

        # The returned mappings must carry the preserved names through.
        assert field_mappings["task_1:act_1"] == "preserved_username"
        assert field_mappings["task_1:act_2"] == "preserved_password"
        assert field_mappings["task_2:act_3"] == "new_field_name"

    @pytest.mark.asyncio
    async def test_no_existing_fields_works_normally(self):
        """Without existing fields the prompt omits preservation instructions."""
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username", "action_id": "act_1"},
            ],
        }
        existing_field_assignments: dict[int, str] = {}

        canned_response = {
            "field_mappings": {
                "action_index_1": "username",
            },
            "schema_fields": {
                "username": {"type": "str", "description": "Username field"},
            },
        }
        seen: dict = {}
        self._install_llm(canned_response, seen)

        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )

        # The CRITICAL preservation rule only renders when existing fields are supplied.
        assert "CRITICAL" not in seen["prompt"]

        # Mappings are still produced normally.
        assert field_mappings["task_1:act_1"] == "username"

    @pytest.mark.asyncio
    async def test_schema_code_includes_preserved_field_names(self):
        """Preserved field names end up in the generated schema source."""
        actions_by_task = {
            "task_1": [
                {"action_type": "input_text", "text": "john", "intention": "Enter username", "action_id": "act_1"},
            ],
        }
        existing_field_assignments = {1: "user_full_name"}

        canned_response = {
            "field_mappings": {
                "action_index_1": "user_full_name",
            },
            "schema_fields": {
                "user_full_name": {"type": "str", "description": "The user's full name"},
            },
        }
        self._install_llm(canned_response)

        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )

        assert "user_full_name" in schema_code
        assert "str" in schema_code
class TestEndToEndFieldPreservation:
    """
    End-to-end test simulating the real scenario:
    1. Workflow has a login block with cached code using field names
    2. User adds a new block
    3. Schema is regenerated
    4. Login block's field names should be preserved
    """

    @pytest.fixture(autouse=True)
    def setup_stub_app(self):
        """Set up stub app for all tests in this class."""
        # NOTE(review): no yield/teardown here — the stub app presumably
        # carries no state that needs resetting between tests; confirm.
        self.stub_app = start_forge_stub_app()

    @pytest.mark.asyncio
    async def test_adding_new_block_preserves_existing_block_field_names(self):
        """
        Simulates: User has workflow with login block, adds a new block.
        The login block's field names should be preserved in the regenerated schema.
        """
        # Existing blocks (login was already there)
        blocks = [
            {"block_type": "login", "label": "login_block", "task_id": "task_1"},
            {"block_type": "task", "label": "new_block", "task_id": "task_2"},  # Newly added
        ]

        # Actions from both blocks
        actions_by_task = {
            "task_1": [
                {
                    "action_type": "input_text",
                    "text": "john@example.com",
                    "intention": "Enter email",
                    "action_id": "act_1",
                },
                {"action_type": "input_text", "text": "secret123", "intention": "Enter password", "action_id": "act_2"},
            ],
            "task_2": [
                {
                    "action_type": "input_text",
                    "text": "Acme Inc",
                    "intention": "Enter company name",
                    "action_id": "act_3",
                },
            ],
        }

        # Cached blocks - login_block has existing field names
        cached_blocks = {
            "login_block": ScriptBlockSource(
                label="login_block",
                code="""
@skyvern.cached(cache_key='login_block')
async def login_block(page: SkyvernPage, context: RunContext):
    await page.fill(
        selector='xpath=//input[@id="email"]',
        value=context.parameters['user_email'],
    )
    await page.fill(
        selector='xpath=//input[@id="password"]',
        value=context.parameters['user_password'],
    )
""",
                run_signature="await skyvern.login(...)",
                workflow_run_id="wr_123",
                workflow_run_block_id="wrb_123",
                input_fields=["user_email", "user_password"],  # These must be preserved!
            ),
            # new_block is not in cached_blocks - it's brand new
        }

        # Only the new block is "updated" (actually new)
        updated_block_labels: set[str] = set()  # login_block is NOT updated

        # Step 1: Build existing field assignments
        existing_field_assignments = _build_existing_field_assignments(
            blocks, actions_by_task, cached_blocks, updated_block_labels
        )

        # Verify login block fields are identified for preservation
        # Keys are 1-based action indices: 1 and 2 are the login block's
        # actions, 3 is the new block's action (see mock response below).
        assert existing_field_assignments == {
            1: "user_email",
            2: "user_password",
            # Action 3 has no existing field (new block)
        }

        # Step 2: Mock LLM that respects the preservation instructions
        mock_llm_response = {
            "field_mappings": {
                "action_index_1": "user_email",  # Preserved
                "action_index_2": "user_password",  # Preserved
                "action_index_3": "company_name",  # New field for new block
            },
            "schema_fields": {
                "user_email": {"type": "str", "description": "User's email address"},
                "user_password": {"type": "str", "description": "User's password"},
                "company_name": {"type": "str", "description": "Company name"},
            },
        }

        # Capture the prompt via a dict so the nested coroutine can write to it.
        captured_prompt = {}

        async def mock_llm_handler(prompt, prompt_name):
            captured_prompt["prompt"] = prompt
            return mock_llm_response

        self.stub_app.SCRIPT_GENERATION_LLM_API_HANDLER = AsyncMock(side_effect=mock_llm_handler)

        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )

        # Verify the prompt contains preservation instructions
        prompt = captured_prompt["prompt"]
        assert "user_email" in prompt, "Prompt should contain existing field name 'user_email'"
        assert "user_password" in prompt, "Prompt should contain existing field name 'user_password'"
        assert "MUST PRESERVE" in prompt or "EXISTING FIELD NAME" in prompt

        # Verify field mappings preserve the original names
        assert field_mappings["task_1:act_1"] == "user_email", "Login block email field should be preserved"
        assert field_mappings["task_1:act_2"] == "user_password", "Login block password field should be preserved"
        assert field_mappings["task_2:act_3"] == "company_name", "New block should get new field name"

        # Verify schema code contains preserved field names
        assert "user_email" in schema_code
        assert "user_password" in schema_code
        assert "company_name" in schema_code

        # The cached login block code references context.parameters['user_email']
        # and context.parameters['user_password'], which now match the schema!
        cached_code = cached_blocks["login_block"].code
        assert "user_email" in cached_code
        assert "user_password" in cached_code
|
||||
|
||||
|
||||
@pytest.mark.skipif(SKIP_LLM_TESTS, reason="Real LLM test - set RUN_LLM_TESTS=1 to enable")
class TestRealLLMFieldPreservation:
    """
    Integration tests that make actual LLM calls to verify field preservation.

    These tests require environment variables to be set (via .env file):
    - SCRIPT_GENERATION_LLM_KEY or SECONDARY_LLM_KEY
    - Appropriate API keys for the LLM provider

    Run these tests with:
        RUN_LLM_TESTS=1 pytest tests/unit/test_workflow_schema_field_preservation.py::TestRealLLMFieldPreservation -v -s

    Note: Skipped by default since they make real LLM calls (costs money).
    """

    @pytest.fixture(scope="class", autouse=True)
    def setup_real_app(self):
        """Set up the real Forge app for LLM calls."""
        # class scope: the app is initialized once for the whole class.
        start_forge_app()
        yield

    @pytest.mark.asyncio
    async def test_llm_preserves_existing_field_names(self):
        """
        Test that a real LLM preserves field names when instructed to.

        This test sends a prompt with existing field names marked as "MUST PRESERVE"
        and verifies the LLM returns those exact names in the response.
        """
        actions_by_task = {
            "task_1": [
                {
                    "action_type": "input_text",
                    "text": "john.doe@example.com",
                    "intention": "Enter the user's email address for login",
                    "action_id": "act_1",
                },
                {
                    "action_type": "input_text",
                    "text": "secretpassword123",
                    "intention": "Enter the user's password",
                    "action_id": "act_2",
                },
            ],
            "task_2": [
                {
                    "action_type": "input_text",
                    "text": "Acme Corporation",
                    "intention": "Enter the company name",
                    "action_id": "act_3",
                },
            ],
        }

        # These are the existing field names that MUST be preserved
        # Using unique names to ensure the LLM doesn't accidentally match them
        existing_field_assignments = {
            1: "preserved_login_email_xyz",
            2: "preserved_login_password_abc",
            # Action 3 has no existing field - LLM should generate a new name
        }

        # Real (non-mocked) LLM call happens inside this function.
        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )

        # Verify the LLM preserved the exact field names we specified
        assert field_mappings["task_1:act_1"] == "preserved_login_email_xyz", (
            f"LLM should have preserved 'preserved_login_email_xyz' but got '{field_mappings.get('task_1:act_1')}'"
        )
        assert field_mappings["task_1:act_2"] == "preserved_login_password_abc", (
            f"LLM should have preserved 'preserved_login_password_abc' but got '{field_mappings.get('task_1:act_2')}'"
        )

        # Verify action 3 got a new field name (not one of the preserved ones)
        action_3_field = field_mappings.get("task_2:act_3")
        assert action_3_field is not None, "LLM should have generated a field name for action 3"
        assert action_3_field not in ["preserved_login_email_xyz", "preserved_login_password_abc"], (
            f"Action 3 should have a new field name, not a preserved one. Got: {action_3_field}"
        )

        # Verify the schema code contains the preserved field names
        assert "preserved_login_email_xyz" in schema_code, "Schema should contain preserved email field"
        assert "preserved_login_password_abc" in schema_code, "Schema should contain preserved password field"
        assert action_3_field in schema_code, f"Schema should contain new field '{action_3_field}'"

        # Prints are intentional: these tests are run with `-s` for visibility.
        print("\n✅ LLM preserved field names correctly!")
        print("  - Action 1: preserved_login_email_xyz ✓")
        print("  - Action 2: preserved_login_password_abc ✓")
        print(f"  - Action 3: {action_3_field} (newly generated) ✓")

    @pytest.mark.asyncio
    async def test_llm_generates_all_new_names_when_no_existing_fields(self):
        """
        Test that when there are no existing fields, the LLM generates appropriate new names.
        This is a baseline test to ensure the LLM call works correctly.
        """
        actions_by_task = {
            "task_1": [
                {
                    "action_type": "input_text",
                    "text": "test@example.com",
                    "intention": "Enter email address",
                    "action_id": "act_1",
                },
            ],
        }

        # No existing field assignments
        existing_field_assignments: dict[int, str] = {}

        schema_code, field_mappings = await generate_workflow_parameters_schema(
            actions_by_task, existing_field_assignments
        )

        # Verify we got a field mapping
        assert "task_1:act_1" in field_mappings, "Should have a field mapping for the action"
        field_name = field_mappings["task_1:act_1"]
        assert field_name, "Field name should not be empty"
        assert field_name in schema_code, f"Schema should contain the generated field name '{field_name}'"

        print(f"\n✅ LLM generated new field name: {field_name}")
|
||||
152
tests/unit/workflow/test_cache_invalidation.py
Normal file
152
tests/unit/workflow/test_cache_invalidation.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Tests for workflow cache invalidation logic (SKY-7016).
|
||||
|
||||
Verifies that changes to the model field (both at workflow settings level and block level)
|
||||
do not trigger cache invalidation.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockType, TaskBlock
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition
|
||||
from skyvern.forge.sdk.workflow.service import _get_workflow_definition_core_data
|
||||
|
||||
|
||||
def make_output_parameter(key: str) -> OutputParameter:
    """Return an OutputParameter fixture with fixed placeholder identifiers.

    Only ``key`` varies between calls; every other attribute is a constant
    test value, and the timestamps are taken at call time.
    """
    timestamp_kwargs = {
        "created_at": datetime.now(timezone.utc),
        "modified_at": datetime.now(timezone.utc),
    }
    return OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key=key,
        description="Test output parameter",
        output_parameter_id="test-output-id",
        workflow_id="test-workflow-id",
        **timestamp_kwargs,
    )
|
||||
|
||||
|
||||
def make_task_block(label: str, model: dict | None = None) -> TaskBlock:
    """Return a TaskBlock fixture named ``label``.

    ``model`` optionally supplies a block-level model configuration dict;
    all other attributes are fixed test values.
    """
    out_param = make_output_parameter(f"{label}_output")
    return TaskBlock(
        block_type=BlockType.TASK,
        label=label,
        title="Test Task",
        url="https://example.com",
        navigation_goal="Complete the task",
        output_parameter=out_param,
        model=model,
    )
|
||||
|
||||
|
||||
class TestCacheInvalidation:
    """Tests for the _get_workflow_definition_core_data function."""

    def test_model_field_excluded_from_block_comparison(self) -> None:
        """
        SKY-7016: Verify that block-level model changes don't trigger cache invalidation.

        The model field should be excluded from the comparison data.
        """
        # Two workflow definitions whose single block differs only in `model`.
        plain_definition = WorkflowDefinition(
            parameters=[],
            blocks=[make_task_block("task1", model=None)],
        )
        model_definition = WorkflowDefinition(
            parameters=[],
            blocks=[make_task_block("task1", model={"model_name": "gpt-4o"})],
        )

        plain_core = _get_workflow_definition_core_data(plain_definition)
        model_core = _get_workflow_definition_core_data(model_definition)

        # Stripping `model` means both definitions reduce to the same core data.
        assert plain_core == model_core, (
            "Model field should be excluded from comparison. "
            "Changing block-level model should not trigger cache invalidation."
        )

    def test_model_field_not_in_core_data(self) -> None:
        """Verify that the model field is completely removed from the core data."""
        definition = WorkflowDefinition(
            parameters=[],
            blocks=[make_task_block("task1", model={"model_name": "claude-3-sonnet"})],
        )

        core_data = _get_workflow_definition_core_data(definition)

        # No block entry may retain a "model" key.
        for block_data in core_data.get("blocks", []):
            assert "model" not in block_data, "Model field should be removed from block data"

    def test_other_block_changes_still_detected(self) -> None:
        """Verify that non-model block changes are still detected."""
        goal_a = make_task_block("task1")
        goal_a.navigation_goal = "Goal A"
        goal_b = make_task_block("task1")
        goal_b.navigation_goal = "Goal B"

        core_a = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[goal_a]))
        core_b = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[goal_b]))

        # navigation_goal participates in the comparison, so the cores differ.
        assert core_a != core_b, "Non-model changes should still be detected for cache invalidation"

    def test_different_models_same_core_data(self) -> None:
        """Verify that switching between different models produces same core data."""
        model_variants = [
            None,
            {"model_name": "gpt-4o"},
            {"model_name": "claude-3-opus"},
            {"model_name": "gemini-pro", "extra_param": "value"},
        ]

        core_datas = [
            _get_workflow_definition_core_data(
                WorkflowDefinition(parameters=[], blocks=[make_task_block("task1", model=variant)])
            )
            for variant in model_variants
        ]

        # Every variant must reduce to the same core data as the first one.
        baseline = core_datas[0]
        for i, candidate in enumerate(core_datas[1:], start=1):
            assert baseline == candidate, (
                f"Core data should be identical regardless of model. Definition 0 vs {i} differ."
            )

    def test_timestamps_excluded_from_comparison(self) -> None:
        """Verify that timestamps are properly excluded from comparison."""
        block1 = make_task_block("task1")
        block2 = make_task_block("task1")

        # Give the second block an output parameter with different IDs and timestamps.
        block2.output_parameter = OutputParameter(
            parameter_type=ParameterType.OUTPUT,
            key="task1_output",
            description="Test output parameter",
            output_parameter_id="different-output-id",  # Different ID
            workflow_id="different-workflow-id",  # Different workflow ID
            created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),  # Different timestamp
            modified_at=datetime(2024, 6, 1, tzinfo=timezone.utc),  # Different timestamp
        )

        core1 = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[block1]))
        core2 = _get_workflow_definition_core_data(WorkflowDefinition(parameters=[], blocks=[block2]))

        # These should be identical (timestamps and IDs are excluded)
        assert core1 == core2, "Timestamps and IDs should be excluded from comparison"
|
||||
232
tests/unit/workflow/test_continue_on_failure_cache.py
Normal file
232
tests/unit/workflow/test_continue_on_failure_cache.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for continue_on_failure behavior with caching.
|
||||
|
||||
Verifies that:
|
||||
1. When a block with continue_on_failure=True fails, it's not cached (existing behavior)
|
||||
2. When a cached block with continue_on_failure=True fails during cached execution,
|
||||
it's marked for regeneration so the next run uses AI execution
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
BlockResult,
|
||||
BlockType,
|
||||
NavigationBlock,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||
from skyvern.forge.sdk.workflow.service import BLOCK_TYPES_THAT_SHOULD_BE_CACHED
|
||||
from skyvern.schemas.workflows import BlockStatus
|
||||
|
||||
|
||||
def _output_parameter(key: str) -> OutputParameter:
    """Minimal OutputParameter fixture whose id is derived from ``key``."""
    timestamp = datetime.now(UTC)
    return OutputParameter(
        key=key,
        output_parameter_id=f"{key}_id",
        workflow_id="wf",
        created_at=timestamp,
        modified_at=timestamp,
    )
|
||||
|
||||
|
||||
def _navigation_block(
    label: str,
    continue_on_failure: bool = False,
    next_block_label: str | None = None,
) -> NavigationBlock:
    """NavigationBlock fixture named ``label``.

    ``continue_on_failure`` and ``next_block_label`` pass straight through to
    the block; everything else is a fixed test value.
    """
    out_param = _output_parameter(f"{label}_output")
    return NavigationBlock(
        label=label,
        title=label,
        url="https://example.com",
        navigation_goal="goal",
        output_parameter=out_param,
        continue_on_failure=continue_on_failure,
        next_block_label=next_block_label,
    )
|
||||
|
||||
|
||||
class TestContinueOnFailureWithCache:
    """Tests for cache invalidation when continue_on_failure blocks fail.

    Each scenario replays the invalidation condition from
    skyvern.forge.sdk.workflow.service: a labeled, cacheable, already-cached
    block with continue_on_failure=True whose run did not complete must be
    queued for regeneration. The condition and the BlockResult construction
    were previously copy-pasted into every test; they are centralized in
    _make_result / _should_invalidate so the five scenarios cannot drift
    from each other.
    """

    @staticmethod
    def _make_result(block: NavigationBlock, status: BlockStatus) -> BlockResult:
        """Build a BlockResult for ``block`` in terminal state ``status``.

        Completed results carry a success payload; any other status yields a
        failure result with a "Block <status>" reason.
        """
        completed = status == BlockStatus.completed
        return BlockResult(
            success=completed,
            # e.g. "Block failed" / "Block terminated" — assumes
            # BlockStatus values stringify to their lowercase names.
            failure_reason=None if completed else f"Block {status.value}",
            output_parameter=block.output_parameter,
            output_parameter_value={"result": "success"} if completed else None,
            status=status,
            workflow_run_block_id="wrb-1",
        )

    @staticmethod
    def _should_invalidate(
        block: NavigationBlock,
        result: BlockResult,
        script_blocks_by_label: dict,
    ) -> bool:
        """Replicate the cache-invalidation condition from service.py.

        True only when the block is labeled, opted into continue_on_failure,
        did not complete, is of a cacheable type, and actually has a cached
        script entry to invalidate.
        """
        return bool(
            block.label
            and block.continue_on_failure
            and result.status != BlockStatus.completed
            and block.block_type in BLOCK_TYPES_THAT_SHOULD_BE_CACHED
            and block.label in script_blocks_by_label
        )

    def test_navigation_block_is_cacheable(self) -> None:
        """Verify NavigationBlock is in the cacheable block types."""
        assert BlockType.NAVIGATION in BLOCK_TYPES_THAT_SHOULD_BE_CACHED

    def test_failed_block_without_continue_on_failure_not_added_to_update(self) -> None:
        """
        Test that a failed block without continue_on_failure=True doesn't trigger
        special cache invalidation logic (it would stop the workflow instead).
        """
        block = _navigation_block("nav1", continue_on_failure=False)
        blocks_to_update: set[str] = set()
        script_blocks_by_label = {"nav1": MagicMock()}  # Block is cached

        result = self._make_result(block, BlockStatus.failed)

        if self._should_invalidate(block, result, script_blocks_by_label):
            blocks_to_update.add(block.label)

        # Should NOT be in blocks_to_update because continue_on_failure=False
        assert block.label not in blocks_to_update

    def test_failed_block_with_continue_on_failure_and_cached_added_to_update(self) -> None:
        """
        Test that a cached block with continue_on_failure=True that fails
        is added to blocks_to_update for regeneration.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        blocks_to_update: set[str] = set()
        script_blocks_by_label = {"nav1": MagicMock()}  # Block is cached

        result = self._make_result(block, BlockStatus.failed)

        if self._should_invalidate(block, result, script_blocks_by_label):
            blocks_to_update.add(block.label)

        # SHOULD be in blocks_to_update for regeneration
        assert block.label in blocks_to_update

    def test_failed_uncached_block_with_continue_on_failure_not_added_to_update(self) -> None:
        """
        Test that an uncached block with continue_on_failure=True that fails
        is NOT added to blocks_to_update (there's nothing to invalidate).
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        blocks_to_update: set[str] = set()
        script_blocks_by_label: dict = {}  # Block is NOT cached

        result = self._make_result(block, BlockStatus.failed)

        if self._should_invalidate(block, result, script_blocks_by_label):
            blocks_to_update.add(block.label)

        # Should NOT be in blocks_to_update - nothing to invalidate
        assert block.label not in blocks_to_update

    def test_successful_block_with_continue_on_failure_not_added_to_update_for_invalidation(self) -> None:
        """
        Test that a successful cached block with continue_on_failure=True
        is NOT added to blocks_to_update for invalidation.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        blocks_to_update: set[str] = set()
        script_blocks_by_label = {"nav1": MagicMock()}  # Block is cached

        result = self._make_result(block, BlockStatus.completed)

        if self._should_invalidate(block, result, script_blocks_by_label):
            blocks_to_update.add(block.label)

        # Should NOT be in blocks_to_update - block succeeded
        assert block.label not in blocks_to_update

    @pytest.mark.parametrize(
        "status",
        [BlockStatus.failed, BlockStatus.terminated, BlockStatus.timed_out],
    )
    def test_all_failure_statuses_trigger_cache_invalidation(self, status: BlockStatus) -> None:
        """
        Test that all non-completed statuses (failed, terminated, timed_out)
        trigger cache invalidation when continue_on_failure=True.
        """
        block = _navigation_block("nav1", continue_on_failure=True)
        blocks_to_update: set[str] = set()
        script_blocks_by_label = {"nav1": MagicMock()}  # Block is cached

        result = self._make_result(block, status)

        if self._should_invalidate(block, result, script_blocks_by_label):
            blocks_to_update.add(block.label)

        # SHOULD be in blocks_to_update for all failure statuses
        assert block.label in blocks_to_update, f"Status {status} should trigger cache invalidation"
|
||||
249
tests/unit/workflow/test_dag_engine.py
Normal file
249
tests/unit/workflow/test_dag_engine.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.workflow.exceptions import InvalidWorkflowDefinition
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
BranchCondition,
|
||||
ConditionalBlock,
|
||||
ExtractionBlock,
|
||||
JinjaBranchCriteria,
|
||||
NavigationBlock,
|
||||
PromptBranchCriteria,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||
from skyvern.schemas.workflows import BlockStatus
|
||||
|
||||
|
||||
def _output_parameter(key: str) -> OutputParameter:
|
||||
now = datetime.now(UTC)
|
||||
return OutputParameter(
|
||||
output_parameter_id=f"{key}_id",
|
||||
key=key,
|
||||
workflow_id="wf",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _navigation_block(label: str, next_block_label: str | None = None) -> NavigationBlock:
|
||||
return NavigationBlock(
|
||||
url="https://example.com",
|
||||
label=label,
|
||||
title=label,
|
||||
navigation_goal="goal",
|
||||
output_parameter=_output_parameter(f"{label}_output"),
|
||||
next_block_label=next_block_label,
|
||||
)
|
||||
|
||||
|
||||
def _extraction_block(label: str, next_block_label: str | None = None) -> ExtractionBlock:
|
||||
return ExtractionBlock(
|
||||
url="https://example.com",
|
||||
label=label,
|
||||
title=label,
|
||||
data_extraction_goal="extract data",
|
||||
output_parameter=_output_parameter(f"{label}_output"),
|
||||
next_block_label=next_block_label,
|
||||
)
|
||||
|
||||
|
||||
def _conditional_block(
|
||||
label: str, branch_conditions: list[BranchCondition], next_block_label: str | None = None
|
||||
) -> ConditionalBlock:
|
||||
return ConditionalBlock(
|
||||
label=label,
|
||||
output_parameter=_output_parameter(f"{label}_output"),
|
||||
branch_conditions=branch_conditions,
|
||||
next_block_label=next_block_label,
|
||||
)
|
||||
|
||||
|
||||
class DummyContext:
    """In-memory stand-in for the workflow run context used in these tests.

    Stores block metadata in a plain dict and no-ops the I/O-flavored
    methods; attribute names match what the engine reads.
    """

    def __init__(self, workflow_run_id: str) -> None:
        # Identity fields expected by the engine.
        self.workflow_run_id = workflow_run_id
        self.workflow_title = "test"
        self.workflow_id = "wf"
        self.workflow_permanent_id = "wf-perm"
        self.include_secrets_in_templates = False
        # Mutable per-run state, all empty by default.
        self.blocks_metadata: dict[str, dict] = {}
        self.values: dict[str, object] = {}
        self.secrets: dict[str, object] = {}
        self.parameters: dict[str, object] = {}
        self.workflow_run_outputs: dict[str, object] = {}

    def update_block_metadata(self, label: str, metadata: dict) -> None:
        """Record ``metadata`` for the block named ``label``."""
        self.blocks_metadata[label] = metadata

    def get_block_metadata(self, label: str | None) -> dict:
        """Return recorded metadata; {} for None or unknown labels."""
        return {} if label is None else self.blocks_metadata.get(label, {})

    def mask_secrets_in_data(self, data: object) -> object:
        """Mock method - returns data as-is since no secrets in tests."""
        return data

    async def register_output_parameter_value_post_execution(self, parameter: OutputParameter, value: object) -> None:  # noqa: ARG002
        """No-op: output values are not persisted in these tests."""
        return None

    def build_workflow_run_summary(self) -> dict:
        """Return an empty run summary."""
        return {}
|
||||
|
||||
|
||||
def test_build_workflow_graph_infers_default_edges() -> None:
    """Blocks without explicit next_block_label are chained in array order."""
    service = WorkflowService()
    blocks = [_navigation_block("first"), _navigation_block("second")]

    start_label, label_to_block, default_next_map = service._build_workflow_graph(blocks)

    assert start_label == "first"
    assert set(label_to_block) == {"first", "second"}
    # "first" falls through to "second"; "second" is terminal.
    assert default_next_map["first"] == "second"
    assert default_next_map["second"] is None
|
||||
|
||||
|
||||
def test_build_workflow_graph_rejects_cycles() -> None:
    """A two-block cycle must be rejected as an invalid definition."""
    service = WorkflowService()
    cyclic_blocks = [
        _navigation_block("first", next_block_label="second"),
        _navigation_block("second", next_block_label="first"),
    ]

    with pytest.raises(InvalidWorkflowDefinition):
        service._build_workflow_graph(cyclic_blocks)
|
||||
|
||||
|
||||
def test_build_workflow_graph_requires_single_root() -> None:
    """A graph with more than one entry point must be rejected."""
    service = WorkflowService()
    # "first" and "third" both have no incoming edge -> two roots.
    blocks = [
        _navigation_block("first"),
        _navigation_block("second"),
        _navigation_block("third", next_block_label="second"),
    ]

    with pytest.raises(InvalidWorkflowDefinition):
        service._build_workflow_graph(blocks)
|
||||
|
||||
|
||||
def test_build_workflow_graph_conditional_blocks_no_sequential_defaulting() -> None:
    """
    Test that workflows with conditional blocks do not apply sequential defaulting.

    This prevents cycles when blocks are ordered differently than execution order.
    For example, if a terminal block appears before branch targets in the blocks array,
    sequential defaulting would incorrectly create a cycle.
    """
    service = WorkflowService()

    # Simulate a workflow where execution order differs from block array order.
    # Execution: start -> extract -> conditional -> (branch_a OR branch_b) -> terminal
    # Array order: [start, extract, conditional, terminal, branch_a, branch_b]
    start = _navigation_block("start", next_block_label="extract")
    extract = _extraction_block("extract", next_block_label="conditional")
    conditional = _conditional_block(
        "conditional",
        branch_conditions=[
            BranchCondition(
                criteria=JinjaBranchCriteria(expression="{{ true }}"), next_block_label="branch_a", is_default=False
            ),
            BranchCondition(criteria=None, next_block_label="branch_b", is_default=True),
        ],
        next_block_label="terminal",  # This should be ignored for conditional blocks
    )
    terminal = _extraction_block("terminal", next_block_label=None)  # Terminal block with explicit None
    branch_a = _navigation_block("branch_a", next_block_label="terminal")
    branch_b = _navigation_block("branch_b", next_block_label="terminal")

    # Block array has terminal before branch_a and branch_b.
    ordered_blocks = [start, extract, conditional, terminal, branch_a, branch_b]

    # This should succeed without creating a cycle.
    start_label, label_to_block, default_next_map = service._build_workflow_graph(ordered_blocks)

    assert start_label == "start"
    assert set(label_to_block.keys()) == {"start", "extract", "conditional", "terminal", "branch_a", "branch_b"}

    # Verify that sequential defaulting was NOT applied:
    # terminal should remain None, not be defaulted to branch_a.
    assert default_next_map["terminal"] is None
    assert default_next_map["branch_a"] == "terminal"
    assert default_next_map["branch_b"] == "terminal"
@pytest.mark.asyncio
async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
    """Executing a conditional block records branch metadata in the run context and DB."""
    conditional_output = _output_parameter("conditional_output")
    conditional = ConditionalBlock(
        label="cond",
        output_parameter=conditional_output,
        branch_conditions=[
            BranchCondition(criteria=JinjaBranchCriteria(expression="{{ flag }}"), next_block_label="next"),
            BranchCondition(is_default=True, next_block_label=None),
        ],
    )

    run_context = DummyContext(workflow_run_id="run-1")
    run_context.values["flag"] = True
    monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: run_context)

    # Start from a clean slate so call assertions below see only this execution.
    app.DATABASE.update_workflow_run_block.reset_mock()
    app.DATABASE.create_or_update_workflow_run_output_parameter.reset_mock()

    result = await conditional.execute(
        workflow_run_id="run-1",
        workflow_run_block_id="wrb-1",
        organization_id="org-1",
    )

    metadata = result.output_parameter_value
    assert metadata["branch_taken"] == "next"
    assert metadata["next_block_label"] == "next"
    assert result.status == BlockStatus.completed
    assert run_context.blocks_metadata["cond"]["branch_taken"] == "next"

    # Inspect the actual DB update call.
    update_call = app.DATABASE.update_workflow_run_block.call_args
    assert update_call.kwargs["workflow_run_block_id"] == "wrb-1"
    assert update_call.kwargs["output"] == metadata
    assert update_call.kwargs["status"] == BlockStatus.completed
    assert update_call.kwargs["failure_reason"] is None
    assert update_call.kwargs["organization_id"] == "org-1"

    # The execution-tracking fields must be populated as well.
    assert update_call.kwargs["executed_branch_expression"] == "{{ flag }}"
    assert update_call.kwargs["executed_branch_result"] is True
    assert update_call.kwargs["executed_branch_next_block"] == "next"
    # executed_branch_id should be a UUID string.
    assert isinstance(update_call.kwargs["executed_branch_id"], str)
@pytest.mark.asyncio
async def test_prompt_branch_uses_batched_evaluation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Prompt-based branches are resolved through the batched prompt evaluator."""
    prompt_output = _output_parameter("conditional_output_prompt")
    prompt_branch = BranchCondition(
        criteria=PromptBranchCriteria(expression="Check if urgent"), next_block_label="next"
    )
    fallback_branch = BranchCondition(is_default=True, next_block_label=None)
    conditional = ConditionalBlock(
        label="cond_prompt",
        output_parameter=prompt_output,
        branch_conditions=[prompt_branch, fallback_branch],
    )

    run_context = DummyContext(workflow_run_id="run-2")
    monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: run_context)
    # Return tuple: (results, rendered_expressions, extraction_goal, llm_response)
    prompt_eval_mock = AsyncMock(return_value=([True], ["Check if urgent"], "test prompt", None))
    monkeypatch.setattr(ConditionalBlock, "_evaluate_prompt_branches", prompt_eval_mock)

    result = await conditional.execute(
        workflow_run_id="run-2",
        workflow_run_block_id="wrb-2",
        organization_id="org-2",
    )

    assert result.status == BlockStatus.completed
    metadata = result.output_parameter_value
    assert metadata["branch_taken"] == "next"
    assert metadata["criteria_type"] == "prompt"
    prompt_eval_mock.assert_awaited_once()
232
tests/unit/workflow/test_file_parser_block.py
Normal file
232
tests/unit/workflow/test_file_parser_block.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for FileParserBlock DOCX support.
|
||||
|
||||
Covers file type detection, validation, text extraction (paragraphs + tables),
|
||||
token truncation, and error handling for DOCX files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import docx
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.workflow.exceptions import InvalidFileType
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockType, FileParserBlock
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
|
||||
from skyvern.schemas.workflows import FileType
|
||||
|
||||
|
||||
def _make_output_parameter(key: str) -> OutputParameter:
    """Build a minimal OutputParameter keyed by *key* for use in tests."""
    return OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key=key,
        description="test",
        output_parameter_id="test-output-id",
        workflow_id="test-workflow-id",
        created_at=datetime.now(timezone.utc),
        modified_at=datetime.now(timezone.utc),
    )
def _make_file_parser_block(file_url: str, file_type: FileType) -> FileParserBlock:
    """Build a FileParserBlock pointed at *file_url* with the given *file_type*."""
    test_output = _make_output_parameter("test_output")
    return FileParserBlock(
        label="test_file_parser",
        block_type=BlockType.FILE_URL_PARSER,
        output_parameter=test_output,
        file_url=file_url,
        file_type=file_type,
    )
def _create_docx(
    path: Path,
    paragraphs: list[str] | None = None,
    table_rows: list[list[str]] | None = None,
) -> Path:
    """Create a DOCX file with optional paragraphs and tables."""
    document = docx.Document()
    for paragraph_text in paragraphs or []:
        document.add_paragraph(paragraph_text)
    if table_rows:
        # All rows are assumed to have the same width as the first row.
        table = document.add_table(rows=len(table_rows), cols=len(table_rows[0]))
        for row_index, row_values in enumerate(table_rows):
            for col_index, cell_value in enumerate(row_values):
                table.rows[row_index].cells[col_index].text = cell_value
    document.save(str(path))
    return path
class TestDetectFileTypeFromUrl:
    """Tests for _detect_file_type_from_url with DOCX extensions."""

    def _detect(self, url: str) -> FileType:
        # The declared file_type (CSV) is irrelevant: detection reads only the URL.
        parser_block = _make_file_parser_block(url, FileType.CSV)
        return parser_block._detect_file_type_from_url(url)

    def test_docx_extension(self) -> None:
        assert self._detect("https://example.com/file.docx") == FileType.DOCX

    def test_doc_extension_raises_error(self) -> None:
        # Legacy .doc (Word 97-2003) is not supported by python-docx.
        with pytest.raises(InvalidFileType, match="Legacy .doc format"):
            self._detect("https://example.com/file.doc")

    def test_docx_with_query_params(self) -> None:
        assert self._detect("https://example.com/file.docx?token=abc&v=1") == FileType.DOCX

    def test_docx_case_insensitive(self) -> None:
        assert self._detect("https://example.com/file.DOCX") == FileType.DOCX

    def test_other_extensions_unchanged(self) -> None:
        assert self._detect("https://example.com/file.pdf") == FileType.PDF
        assert self._detect("https://example.com/file.xlsx") == FileType.EXCEL
        assert self._detect("https://example.com/file.csv") == FileType.CSV
        assert self._detect("https://example.com/file.png") == FileType.IMAGE
class TestValidateFileType:
    """Tests for validate_file_type with DOCX files."""

    def test_valid_docx(self, tmp_path: Path) -> None:
        docx_path = _create_docx(tmp_path / "valid.docx", paragraphs=["Hello"])
        parser_block = _make_file_parser_block("https://example.com/valid.docx", FileType.DOCX)
        # A genuine DOCX must validate without raising.
        parser_block.validate_file_type("https://example.com/valid.docx", str(docx_path))

    def test_plain_text_with_docx_extension(self, tmp_path: Path) -> None:
        fake_path = tmp_path / "fake.docx"
        fake_path.write_text("This is plain text, not a DOCX file.")
        parser_block = _make_file_parser_block("https://example.com/fake.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            parser_block.validate_file_type("https://example.com/fake.docx", str(fake_path))

    def test_empty_file(self, tmp_path: Path) -> None:
        empty_path = tmp_path / "empty.docx"
        empty_path.write_bytes(b"")
        parser_block = _make_file_parser_block("https://example.com/empty.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            parser_block.validate_file_type("https://example.com/empty.docx", str(empty_path))
@pytest.mark.asyncio
class TestParseDocxFile:
    """Tests for _parse_docx_file text extraction."""

    async def test_paragraphs_joined_by_newline(self, tmp_path: Path) -> None:
        docx_path = _create_docx(tmp_path / "paras.docx", paragraphs=["Hello", "World"])
        parser_block = _make_file_parser_block("https://example.com/paras.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        assert parsed == "Hello\nWorld"

    async def test_empty_paragraphs_skipped(self, tmp_path: Path) -> None:
        # Blank and whitespace-only paragraphs must be dropped from the output.
        docx_path = _create_docx(tmp_path / "blanks.docx", paragraphs=["Hello", "", " ", "World"])
        parser_block = _make_file_parser_block("https://example.com/blanks.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        assert parsed == "Hello\nWorld"

    async def test_table_rows_formatted_with_pipe(self, tmp_path: Path) -> None:
        docx_path = _create_docx(
            tmp_path / "table.docx",
            table_rows=[["Name", "Age"], ["Alice", "30"]],
        )
        parser_block = _make_file_parser_block("https://example.com/table.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        assert parsed == "Name | Age\nAlice | 30"

    async def test_mixed_paragraphs_and_tables(self, tmp_path: Path) -> None:
        docx_path = _create_docx(
            tmp_path / "mixed.docx",
            paragraphs=["Intro"],
            table_rows=[["Col1", "Col2"], ["A", "B"]],
        )
        parser_block = _make_file_parser_block("https://example.com/mixed.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        assert parsed == "Intro\nCol1 | Col2\nA | B"

    async def test_empty_document(self, tmp_path: Path) -> None:
        docx_path = _create_docx(tmp_path / "empty.docx")
        parser_block = _make_file_parser_block("https://example.com/empty.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        assert parsed == ""

    async def test_empty_table_cells_skipped(self, tmp_path: Path) -> None:
        docx_path = _create_docx(
            tmp_path / "sparse.docx",
            table_rows=[["Name", "", "Age"], ["", "", ""]],
        )
        parser_block = _make_file_parser_block("https://example.com/sparse.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        # First row: "Name" and "Age" (empty cell skipped); second row: all empty -> skipped.
        assert parsed == "Name | Age"

    async def test_multiple_tables(self, tmp_path: Path) -> None:
        # Build the document by hand so it contains two separate tables.
        document = docx.Document()
        first_table = document.add_table(rows=1, cols=2)
        first_table.rows[0].cells[0].text = "T1C1"
        first_table.rows[0].cells[1].text = "T1C2"
        second_table = document.add_table(rows=1, cols=2)
        second_table.rows[0].cells[0].text = "T2C1"
        second_table.rows[0].cells[1].text = "T2C2"
        docx_path = tmp_path / "multi_table.docx"
        document.save(str(docx_path))

        parser_block = _make_file_parser_block("https://example.com/multi_table.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path))
        assert parsed == "T1C1 | T1C2\nT2C1 | T2C2"
@pytest.mark.asyncio
class TestParseDocxFileTokenTruncation:
    """Tests for _parse_docx_file token limit enforcement."""

    async def test_paragraphs_truncated(self, tmp_path: Path) -> None:
        # Many paragraphs should exceed a deliberately tiny token budget.
        source_paragraphs = [f"This is paragraph number {i} with some text content." for i in range(100)]
        docx_path = _create_docx(tmp_path / "long.docx", paragraphs=source_paragraphs)
        parser_block = _make_file_parser_block("https://example.com/long.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path), max_tokens=20)
        kept_lines = parsed.split("\n")
        assert len(kept_lines) < len(source_paragraphs)
        # Truncation must keep whole paragraphs, never partial ones.
        for kept in kept_lines:
            assert kept.startswith("This is paragraph number")

    async def test_tables_truncated(self, tmp_path: Path) -> None:
        source_rows = [[f"R{i}C1", f"R{i}C2", f"R{i}C3"] for i in range(100)]
        docx_path = _create_docx(tmp_path / "big_table.docx", table_rows=source_rows)
        parser_block = _make_file_parser_block("https://example.com/big_table.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path), max_tokens=20)
        assert len(parsed.split("\n")) < len(source_rows)

    async def test_tables_skipped_when_paragraphs_exhaust_budget(self, tmp_path: Path) -> None:
        source_paragraphs = [f"Long paragraph {i} with lots of content to fill tokens." for i in range(100)]
        source_rows = [["Should", "Not", "Appear"]]
        docx_path = _create_docx(tmp_path / "para_heavy.docx", paragraphs=source_paragraphs, table_rows=source_rows)
        parser_block = _make_file_parser_block("https://example.com/para_heavy.docx", FileType.DOCX)
        parsed = await parser_block._parse_docx_file(str(docx_path), max_tokens=20)
        assert "Should" not in parsed
        assert "Not" not in parsed
        assert "Appear" not in parsed
@pytest.mark.asyncio
class TestParseDocxFileErrorHandling:
    """Tests for _parse_docx_file error handling."""

    async def test_corrupt_file(self, tmp_path: Path) -> None:
        corrupt_path = tmp_path / "corrupt.docx"
        corrupt_path.write_bytes(b"\x00\x01\x02\x03random bytes")
        parser_block = _make_file_parser_block("https://example.com/corrupt.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            await parser_block._parse_docx_file(str(corrupt_path))

    async def test_nonexistent_file(self, tmp_path: Path) -> None:
        parser_block = _make_file_parser_block("https://example.com/missing.docx", FileType.DOCX)
        with pytest.raises(InvalidFileType):
            await parser_block._parse_docx_file(str(tmp_path / "nonexistent.docx"))
46
tests/unit/workflow/test_http_request_block.py
Normal file
46
tests/unit/workflow/test_http_request_block.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import json
|
||||
|
||||
|
||||
class TestJsonTextParsingEquivalence:
|
||||
"""Prove JSON/text parsing behavior matches aiohttp semantics.
|
||||
|
||||
The HttpRequestBlock parses responses using:
|
||||
try:
|
||||
response_body = json.loads(response_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
response_body = response_bytes.decode("utf-8", errors="replace")
|
||||
|
||||
This should behave equivalently to aiohttp's:
|
||||
try:
|
||||
response_body = await response.json()
|
||||
except (aiohttp.ContentTypeError, Exception):
|
||||
response_body = await response.text()
|
||||
"""
|
||||
|
||||
def _parse_response(self, response_bytes: bytes) -> str | dict | list:
|
||||
try:
|
||||
return json.loads(response_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
return response_bytes.decode("utf-8", errors="replace")
|
||||
|
||||
def test_valid_json_utf8(self) -> None:
|
||||
data = {"key": "value", "number": 42, "unicode": "日本語"}
|
||||
response_bytes = json.dumps(data).encode("utf-8")
|
||||
result = self._parse_response(response_bytes)
|
||||
assert result == data
|
||||
|
||||
def test_invalid_json_returns_text(self) -> None:
|
||||
response_bytes = b"not json, just text"
|
||||
result = self._parse_response(response_bytes)
|
||||
assert result == "not json, just text"
|
||||
|
||||
def test_non_utf8_bytes_handled_gracefully(self) -> None:
|
||||
response_bytes = "café".encode("latin-1") # b'caf\xe9'
|
||||
result = self._parse_response(response_bytes)
|
||||
assert "caf" in result
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_empty_response(self) -> None:
|
||||
response_bytes = b""
|
||||
result = self._parse_response(response_bytes)
|
||||
assert result == ""
|
||||
Reference in New Issue
Block a user