301 lines
9.7 KiB
Python
301 lines
9.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from types import SimpleNamespace
|
|
from typing import Any, Iterator, Sequence
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from pytest import MonkeyPatch # type: ignore[import-not-found]
|
|
|
|
from skyvern.forge import app
|
|
from skyvern.forge.agent import ForgeAgent
|
|
from skyvern.forge.sdk.api.llm import api_handler_factory
|
|
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
|
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig, LLMRouterModelConfig
|
|
from skyvern.forge.sdk.db.enums import TaskType
|
|
from skyvern.forge.sdk.models import Step, StepStatus
|
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
|
|
|
|
|
class FakeLLMResponse:
    """Canned stand-in for an LLM completion response object.

    Only the attributes assigned here are available: ``model``, ``choices``
    (a single entry whose ``message.content`` is a fixed JSON string with an
    empty ``actions`` list) and ``usage`` (every counter set to zero).
    """

    def __init__(self, model: str) -> None:
        """Store *model* and build the fixed ``choices`` / ``usage`` payloads."""
        self.model = model
        self._content = '{"actions": []}'

        # One choice carrying one message; the content is always the same text.
        message = SimpleNamespace(content=self._content)
        self.choices = [SimpleNamespace(message=message)]

        # Token accounting is zeroed out across the board.
        reasoning_details = SimpleNamespace(reasoning_tokens=0)
        cache_details = SimpleNamespace(cached_tokens=0)
        self.usage = SimpleNamespace(
            prompt_tokens=0,
            completion_tokens=0,
            completion_tokens_details=reasoning_details,
            prompt_tokens_details=cache_details,
            cache_read_input_tokens=0,
        )

    def model_dump_json(self, indent: int = 2) -> str:
        """Return the model name and message content serialized as JSON text.

        Usage counters are not part of the dump. *indent* is forwarded to
        ``json.dumps`` unchanged.
        """
        payload = {
            "model": self.model,
            "choices": [{"message": {"content": self._content}}],
        }
        return json.dumps(payload, indent=indent)
|
|
|
|
|
|
class DummyLogger:
    """Logger double that records ``info`` calls and discards every other level.

    ``events`` holds one ``(event, keyword_arguments)`` tuple per ``info``
    call, in call order, so tests can assert on what was logged.
    """

    def __init__(self) -> None:
        """Start with an empty event list."""
        self.events: list[tuple[str, dict[str, Any]]] = []

    def info(self, event: str, **kwargs: Any) -> None:
        """Record *event* together with the keyword arguments passed alongside it."""
        # ``**kwargs: Any`` types each keyword *value*; the collected mapping
        # is therefore ``dict[str, Any]``, matching the ``events`` annotation.
        self.events.append((event, kwargs))

    def warning(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        """Accept and ignore a warning-level call."""
        pass

    def error(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        """Accept and ignore an error-level call."""
        # Present so code that logs an error fails (or passes) on its own
        # merits instead of raising AttributeError on this double.
        pass

    def exception(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        """Accept and ignore an exception-level call."""
        pass

    def debug(self, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
        """Accept and ignore a debug-level call."""
        pass
|
|
|
|
|
|
@dataclass
class RouterTestContext:
    """Values handed to a test by the ``router_test_context`` context manager."""

    # Registry key the router configuration was registered under.
    llm_key: str
    # The router configuration object that was registered for ``llm_key``.
    router_config: LLMRouterConfig
    # Recording logger installed in place of the handler factory's ``LOG``.
    logger: DummyLogger
|
|
|
|
|
|
@contextmanager
def router_test_context(
    monkeypatch: MonkeyPatch,
    *,
    llm_key: str,
    primary_group: str,
    fallback_group: str,
    routing_strategy: str = "simple-shuffle",
) -> Iterator[RouterTestContext]:
    """Register a two-group router config under *llm_key* and stub the handler factory.

    While the context is active:

    * ``LLMConfigRegistry`` holds a fresh ``LLMRouterConfig`` for *llm_key*
      whose model list contains *primary_group* and *fallback_group*;
    * ``api_handler_factory.LOG`` is a ``DummyLogger``;
    * ``api_handler_factory.llm_messages_builder`` returns a single user
      message containing the prompt;
    * ``api_handler_factory.skyvern_context.current`` returns ``None``;
    * ``api_handler_factory.litellm.completion_cost`` returns ``0.0``.

    On exit the registry entry for *llm_key* is removed and, if something was
    registered under that key beforehand, that earlier entry is put back. The
    attribute patches are not reverted here; they are left to the supplied
    *monkeypatch* object.

    Args:
        monkeypatch: pytest ``MonkeyPatch`` used for every attribute patch.
        llm_key: Registry key to (temporarily) bind the router config to.
        primary_group: Model-group name used as ``main_model_group``.
        fallback_group: Model-group name used as ``fallback_model_group``.
        routing_strategy: Forwarded to ``LLMRouterConfig`` unchanged.

    Yields:
        A ``RouterTestContext`` carrying the key, the config and the logger.
    """
    router_config = LLMRouterConfig(
        model_name="test-router",
        required_env_vars=[],
        supports_vision=False,
        add_assistant_prefix=False,
        model_list=[
            LLMRouterModelConfig(model_name=primary_group, litellm_params={"model": primary_group}),
            LLMRouterModelConfig(model_name=fallback_group, litellm_params={"model": fallback_group}),
        ],
        redis_host="localhost",
        redis_port=6379,
        redis_password="",
        main_model_group=primary_group,
        fallback_model_group=fallback_group,
        routing_strategy=routing_strategy,
        num_retries=0,
        disable_cooldowns=True,
        temperature=None,
    )

    async def fake_llm_messages_builder(
        prompt: str, screenshots: Any, add_assistant_prefix: bool
    ) -> list[dict[str, str]]:
        # Screenshots and the assistant-prefix flag are deliberately ignored.
        return [{"role": "user", "content": prompt}]

    # Take out any existing entry first (so registration sees a free key) and
    # remember it, so the registry can be left exactly as it was found.
    previous_config = LLMConfigRegistry._configs.pop(llm_key, None)  # type: ignore[attr-defined]

    # Everything after the pop sits inside the try so that a failure while
    # registering or patching still goes through the cleanup below.
    try:
        LLMConfigRegistry.register_config(llm_key, router_config)

        logger = DummyLogger()
        monkeypatch.setattr(api_handler_factory, "LOG", logger)
        monkeypatch.setattr(api_handler_factory, "llm_messages_builder", fake_llm_messages_builder)
        monkeypatch.setattr(api_handler_factory.skyvern_context, "current", lambda: None)
        # NOTE(review): this sets the attribute on whatever object
        # ``api_handler_factory.litellm`` is — presumably the shared litellm
        # module, so the stub is visible to other importers until undone.
        monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda completion_response: 0.0)

        yield RouterTestContext(llm_key=llm_key, router_config=router_config, logger=logger)
    finally:
        LLMConfigRegistry._configs.pop(llm_key, None)  # type: ignore[attr-defined]
        if previous_config is not None:
            LLMConfigRegistry._configs[llm_key] = previous_config  # type: ignore[attr-defined]
|
|
|
|
|
|
def make_organization(now: datetime) -> Organization:
    """Build an ``Organization`` fixture with fixed identifiers.

    The id is ``"org-123"`` and the name ``"Org"``; both timestamps are set to
    *now*, and every other field passed here is ``None``.
    """
    fields: dict[str, Any] = {
        "organization_id": "org-123",
        "organization_name": "Org",
        "webhook_callback_url": None,
        "max_steps_per_run": None,
        "max_retries_per_step": None,
        "domain": None,
        "bw_organization_id": None,
        "bw_collection_ids": None,
        "created_at": now,
        "modified_at": now,
    }
    return Organization(**fields)
|
|
|
|
|
|
def make_task(now: datetime, organization: Organization, **overrides: Any) -> Task:
    """Build a running ``Task`` fixture owned by *organization*.

    Args:
        now: Timestamp used for ``created_at``, ``modified_at``, ``queued_at``
            and ``started_at``.
        organization: Supplies the task's ``organization_id``.
        **overrides: Field values that replace (or add to) the defaults below
            before the ``Task`` is constructed.

    Returns:
        The ``Task`` built from the merged field values.
    """
    defaults: dict[str, Any] = {
        # What the task is asked to do.
        "title": "Task",
        "url": "https://example.com",
        "webhook_callback_url": None,
        "webhook_failure_reason": None,
        "totp_verification_url": None,
        "totp_identifier": None,
        "navigation_goal": "Find the quote",
        "data_extraction_goal": "Extract the quote",
        "navigation_payload": None,
        "error_code_mapping": None,
        "proxy_location": None,
        "extracted_information_schema": None,
        "extra_http_headers": None,
        "complete_criterion": None,
        "terminate_criterion": None,
        "task_type": TaskType.general,
        "application": None,
        "include_action_history_in_verification": False,
        "max_screenshot_scrolls": None,
        "browser_address": None,
        "download_timeout": None,
        # Identity, state and bookkeeping.
        "created_at": now,
        "modified_at": now,
        "task_id": "task-123",
        "status": TaskStatus.running,
        "extracted_information": None,
        "failure_reason": None,
        "organization_id": organization.organization_id,
        "workflow_run_id": None,
        "workflow_permanent_id": None,
        "browser_session_id": None,
        "order": 0,
        "retry": 0,
        "max_steps_per_run": None,
        "errors": [],
        "model": None,
        "queued_at": now,
        "started_at": now,
        "finished_at": None,
    }
    # Later mapping wins, so caller-supplied overrides replace the defaults.
    merged = {**defaults, **overrides}
    return Task(**merged)
|
|
|
|
|
|
def make_step(
    now: datetime,
    task: Task,
    *,
    step_id: str,
    status: StepStatus,
    order: int,
    output,
    is_last: bool = False,
    retry_index: int = 0,
    organization_id: str | None = None,
    **overrides: Any,
) -> Step:
    """Build a ``Step`` fixture attached to *task*.

    Args:
        now: Timestamp used for both ``created_at`` and ``modified_at``.
        task: Supplies ``task_id`` and the fallback organization id.
        step_id: Identifier of the step.
        status: Status of the step.
        order: Position of the step.
        output: Value stored as the step's ``output``.
        is_last: Stored as ``is_last``.
        retry_index: Stored as ``retry_index``.
        organization_id: Owner of the step; any falsy value (``None`` or an
            empty string) falls back to ``task.organization_id``.
        **overrides: Field values that replace (or add to) the ones above
            before the ``Step`` is constructed.

    Returns:
        The ``Step`` built from the merged field values.
    """
    owner_id = organization_id if organization_id else task.organization_id
    defaults: dict[str, Any] = {
        "created_at": now,
        "modified_at": now,
        "task_id": task.task_id,
        "step_id": step_id,
        "status": status,
        "output": output,
        "order": order,
        "is_last": is_last,
        "retry_index": retry_index,
        "organization_id": owner_id,
    }
    # Later mapping wins, so caller-supplied overrides replace the defaults.
    return Step(**{**defaults, **overrides})
|
|
|
|
|
|
@dataclass
class ParallelVerificationMocks:
    """Handles to the mocks installed by ``setup_parallel_verification_mocks``.

    Each field holds the ``AsyncMock`` that replaced the like-named target, so
    a test can inspect calls after exercising the agent.
    """

    # Replacements for ``app.DATABASE.create_step`` / ``get_task_steps``.
    create_step: AsyncMock
    get_task_steps: AsyncMock
    # Replacement for ``asyncio.sleep`` as reached through the agent module.
    sleep: AsyncMock
    # Replacement for the agent's ``check_user_goal_complete``.
    check_user_goal_complete: AsyncMock
    # Replacement for ``ActionHandler.handle_action``.
    handle_action: AsyncMock
    # ``None`` when no extract action was supplied, i.e. nothing was patched.
    create_extract_action: AsyncMock | None
    # Replacements for the agent's speculative-step helpers.
    speculate_next_step_plan: AsyncMock
    persist_speculative_metadata: AsyncMock
    cancel_speculative_step: AsyncMock
    # Replacements for the agent's artifact / persistence methods.
    record_artifacts_after_action: AsyncMock
    update_step: AsyncMock
    update_task: AsyncMock
|
|
|
|
|
|
def setup_parallel_verification_mocks(
    agent: ForgeAgent,
    *,
    step: Step,
    task: Task,
    monkeypatch: MonkeyPatch,
    next_step: Step | None,
    complete_action,
    handle_action_responses: Sequence[Any],
    extract_action: Any | None = None,
) -> ParallelVerificationMocks:
    """Replace the agent's collaborators with ``AsyncMock`` objects and return them.

    Args:
        agent: Agent whose methods are patched in place.
        step: Returned (in a one-element list) by ``get_task_steps`` and
            directly by ``update_step``.
        task: Returned by ``update_task``.
        monkeypatch: pytest ``MonkeyPatch`` that performs every patch.
        next_step: Returned by ``app.DATABASE.create_step``.
        complete_action: Returned by ``check_user_goal_complete``.
        handle_action_responses: Used as the ``side_effect`` of the
            ``handle_action`` mock.
        extract_action: When not ``None``, ``create_extract_action`` is patched
            to return it; otherwise that method is left untouched.

    Returns:
        A ``ParallelVerificationMocks`` holding every mock that was installed.
    """

    def install(target: Any, attribute: str, **mock_kwargs: Any) -> AsyncMock:
        # Patch ``target.attribute`` with a new AsyncMock and hand the mock back.
        replacement = AsyncMock(**mock_kwargs)
        monkeypatch.setattr(target, attribute, replacement)
        return replacement

    def install_at(dotted_path: str, **mock_kwargs: Any) -> AsyncMock:
        # Same as ``install`` but addressed by import path.
        replacement = AsyncMock(**mock_kwargs)
        monkeypatch.setattr(dotted_path, replacement)
        return replacement

    # Keyword arguments are evaluated top to bottom, so the patches are applied
    # in exactly this order; ``create_extract_action`` is intentionally last.
    return ParallelVerificationMocks(
        create_step=install(app.DATABASE, "create_step", return_value=next_step),
        get_task_steps=install(app.DATABASE, "get_task_steps", return_value=[step]),
        # NOTE(review): the path goes through the agent module's ``asyncio``
        # attribute — presumably the shared asyncio module, so ``sleep`` is
        # replaced for every importer until the patch is undone. Confirm.
        sleep=install_at("skyvern.forge.agent.asyncio.sleep", return_value=None),
        check_user_goal_complete=install(agent, "check_user_goal_complete", return_value=complete_action),
        handle_action=install_at(
            "skyvern.forge.agent.ActionHandler.handle_action",
            side_effect=handle_action_responses,
        ),
        speculate_next_step_plan=install(agent, "_speculate_next_step_plan", return_value=None),
        persist_speculative_metadata=install(agent, "_persist_speculative_metadata_for_discarded_plan"),
        cancel_speculative_step=install(agent, "_cancel_speculative_step"),
        record_artifacts_after_action=install(agent, "record_artifacts_after_action"),
        update_step=install(agent, "update_step", return_value=step),
        update_task=install(agent, "update_task", return_value=task),
        create_extract_action=(
            install(agent, "create_extract_action", return_value=extract_action)
            if extract_action is not None
            else None
        ),
    )
|
|
|
|
|
|
def make_browser_state() -> tuple[MagicMock, MagicMock, MagicMock]:
    """Return three distinct, unconfigured ``MagicMock`` objects as a 3-tuple."""
    first, second, third = (MagicMock() for _ in range(3))
    return first, second, third
|