Files
Dorod-Sky/tests/unit/helpers.py
2026-02-12 20:43:27 -08:00

301 lines
9.7 KiB
Python

from __future__ import annotations
import json
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from types import SimpleNamespace
from typing import Any, Iterator, Sequence
from unittest.mock import AsyncMock, MagicMock
from pytest import MonkeyPatch # type: ignore[import-not-found]
from skyvern.forge import app
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.api.llm import api_handler_factory
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig, LLMRouterModelConfig
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
class FakeLLMResponse:
def __init__(self, model: str) -> None:
self.model = model
self._content = '{"actions": []}'
self.choices = [
SimpleNamespace(
message=SimpleNamespace(
content=self._content,
)
)
]
self.usage = SimpleNamespace(
prompt_tokens=0,
completion_tokens=0,
completion_tokens_details=SimpleNamespace(reasoning_tokens=0),
prompt_tokens_details=SimpleNamespace(cached_tokens=0),
cache_read_input_tokens=0,
)
def model_dump_json(self, indent: int = 2) -> str:
return json.dumps(
{
"model": self.model,
"choices": [
{"message": {"content": self._content}},
],
},
indent=indent,
)
class DummyLogger:
def __init__(self) -> None:
self.events: list[tuple[str, dict[str, Any]]] = []
def info(self, event: str, **kwargs: dict[str, Any]) -> None:
self.events.append((event, kwargs))
def warning(self, *args, **kwargs) -> None: # pragma: no cover
pass
def exception(self, *args, **kwargs) -> None: # pragma: no cover
pass
def debug(self, *args, **kwargs) -> None: # pragma: no cover
pass
@dataclass
class RouterTestContext:
llm_key: str
router_config: LLMRouterConfig
logger: DummyLogger
@contextmanager
def router_test_context(
monkeypatch: MonkeyPatch,
*,
llm_key: str,
primary_group: str,
fallback_group: str,
routing_strategy: str = "simple-shuffle",
) -> Iterator[RouterTestContext]:
router_config = LLMRouterConfig(
model_name="test-router",
required_env_vars=[],
supports_vision=False,
add_assistant_prefix=False,
model_list=[
LLMRouterModelConfig(model_name=primary_group, litellm_params={"model": primary_group}),
LLMRouterModelConfig(model_name=fallback_group, litellm_params={"model": fallback_group}),
],
redis_host="localhost",
redis_port=6379,
redis_password="",
main_model_group=primary_group,
fallback_model_group=fallback_group,
routing_strategy=routing_strategy,
num_retries=0,
disable_cooldowns=True,
temperature=None,
)
LLMConfigRegistry._configs.pop(llm_key, None) # type: ignore[attr-defined]
LLMConfigRegistry.register_config(llm_key, router_config)
logger = DummyLogger()
monkeypatch.setattr(api_handler_factory, "LOG", logger)
async def fake_llm_messages_builder(prompt: str, screenshots, add_assistant_prefix: bool) -> list[dict[str, str]]:
return [{"role": "user", "content": prompt}]
monkeypatch.setattr(api_handler_factory, "llm_messages_builder", fake_llm_messages_builder)
monkeypatch.setattr(api_handler_factory.skyvern_context, "current", lambda: None)
monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda completion_response: 0.0)
try:
yield RouterTestContext(llm_key=llm_key, router_config=router_config, logger=logger)
finally:
LLMConfigRegistry._configs.pop(llm_key, None) # type: ignore[attr-defined]
def make_organization(now: datetime) -> Organization:
return Organization(
organization_id="org-123",
organization_name="Org",
webhook_callback_url=None,
max_steps_per_run=None,
max_retries_per_step=None,
domain=None,
bw_organization_id=None,
bw_collection_ids=None,
created_at=now,
modified_at=now,
)
def make_task(now: datetime, organization: Organization, **overrides: Any) -> Task:
base: dict[str, Any] = {
"title": "Task",
"url": "https://example.com",
"webhook_callback_url": None,
"webhook_failure_reason": None,
"totp_verification_url": None,
"totp_identifier": None,
"navigation_goal": "Find the quote",
"data_extraction_goal": "Extract the quote",
"navigation_payload": None,
"error_code_mapping": None,
"proxy_location": None,
"extracted_information_schema": None,
"extra_http_headers": None,
"complete_criterion": None,
"terminate_criterion": None,
"task_type": TaskType.general,
"application": None,
"include_action_history_in_verification": False,
"max_screenshot_scrolls": None,
"browser_address": None,
"download_timeout": None,
"created_at": now,
"modified_at": now,
"task_id": "task-123",
"status": TaskStatus.running,
"extracted_information": None,
"failure_reason": None,
"organization_id": organization.organization_id,
"workflow_run_id": None,
"workflow_permanent_id": None,
"browser_session_id": None,
"order": 0,
"retry": 0,
"max_steps_per_run": None,
"errors": [],
"model": None,
"queued_at": now,
"started_at": now,
"finished_at": None,
}
base.update(overrides)
return Task(**base)
def make_step(
now: datetime,
task: Task,
*,
step_id: str,
status: StepStatus,
order: int,
output,
is_last: bool = False,
retry_index: int = 0,
organization_id: str | None = None,
**overrides: Any,
) -> Step:
base: dict[str, Any] = {
"created_at": now,
"modified_at": now,
"task_id": task.task_id,
"step_id": step_id,
"status": status,
"output": output,
"order": order,
"is_last": is_last,
"retry_index": retry_index,
"organization_id": organization_id or task.organization_id,
}
base.update(overrides)
return Step(**base)
@dataclass
class ParallelVerificationMocks:
create_step: AsyncMock
get_task_steps: AsyncMock
sleep: AsyncMock
check_user_goal_complete: AsyncMock
handle_action: AsyncMock
create_extract_action: AsyncMock | None
speculate_next_step_plan: AsyncMock
persist_speculative_metadata: AsyncMock
cancel_speculative_step: AsyncMock
record_artifacts_after_action: AsyncMock
update_step: AsyncMock
update_task: AsyncMock
def setup_parallel_verification_mocks(
agent: ForgeAgent,
*,
step: Step,
task: Task,
monkeypatch: MonkeyPatch,
next_step: Step | None,
complete_action,
handle_action_responses: Sequence[Any],
extract_action: Any | None = None,
) -> ParallelVerificationMocks:
create_step_mock = AsyncMock(return_value=next_step)
monkeypatch.setattr(app.DATABASE, "create_step", create_step_mock)
get_task_steps_mock = AsyncMock(return_value=[step])
monkeypatch.setattr(app.DATABASE, "get_task_steps", get_task_steps_mock)
sleep_mock = AsyncMock(return_value=None)
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", sleep_mock)
check_user_goal_complete_mock = AsyncMock(return_value=complete_action)
monkeypatch.setattr(agent, "check_user_goal_complete", check_user_goal_complete_mock)
handle_action_mock = AsyncMock(side_effect=handle_action_responses)
monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", handle_action_mock)
speculate_mock = AsyncMock(return_value=None)
monkeypatch.setattr(agent, "_speculate_next_step_plan", speculate_mock)
persist_mock = AsyncMock()
monkeypatch.setattr(agent, "_persist_speculative_metadata_for_discarded_plan", persist_mock)
cancel_mock = AsyncMock()
monkeypatch.setattr(agent, "_cancel_speculative_step", cancel_mock)
record_artifacts_mock = AsyncMock()
monkeypatch.setattr(agent, "record_artifacts_after_action", record_artifacts_mock)
update_step_mock = AsyncMock(return_value=step)
monkeypatch.setattr(agent, "update_step", update_step_mock)
update_task_mock = AsyncMock(return_value=task)
monkeypatch.setattr(agent, "update_task", update_task_mock)
if extract_action is not None:
create_extract_action_mock = AsyncMock(return_value=extract_action)
monkeypatch.setattr(agent, "create_extract_action", create_extract_action_mock)
else:
create_extract_action_mock = None
return ParallelVerificationMocks(
create_step=create_step_mock,
get_task_steps=get_task_steps_mock,
sleep=sleep_mock,
check_user_goal_complete=check_user_goal_complete_mock,
handle_action=handle_action_mock,
create_extract_action=create_extract_action_mock,
speculate_next_step_plan=speculate_mock,
persist_speculative_metadata=persist_mock,
cancel_speculative_step=cancel_mock,
record_artifacts_after_action=record_artifacts_mock,
update_step=update_step_mock,
update_task=update_task_mock,
)
def make_browser_state() -> tuple[MagicMock, MagicMock, MagicMock]:
return MagicMock(), MagicMock(), MagicMock()