Remove setup.sh in favor of skyvern CLI (#4737)
This commit is contained in:
300
tests/unit/helpers.py
Normal file
300
tests/unit/helpers.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Iterator, Sequence
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from pytest import MonkeyPatch # type: ignore[import-not-found]
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.api.llm import api_handler_factory
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig, LLMRouterModelConfig
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
||||
|
||||
class FakeLLMResponse:
|
||||
def __init__(self, model: str) -> None:
|
||||
self.model = model
|
||||
self._content = '{"actions": []}'
|
||||
self.choices = [
|
||||
SimpleNamespace(
|
||||
message=SimpleNamespace(
|
||||
content=self._content,
|
||||
)
|
||||
)
|
||||
]
|
||||
self.usage = SimpleNamespace(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
completion_tokens_details=SimpleNamespace(reasoning_tokens=0),
|
||||
prompt_tokens_details=SimpleNamespace(cached_tokens=0),
|
||||
cache_read_input_tokens=0,
|
||||
)
|
||||
|
||||
def model_dump_json(self, indent: int = 2) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"model": self.model,
|
||||
"choices": [
|
||||
{"message": {"content": self._content}},
|
||||
],
|
||||
},
|
||||
indent=indent,
|
||||
)
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
def info(self, event: str, **kwargs: dict[str, Any]) -> None:
|
||||
self.events.append((event, kwargs))
|
||||
|
||||
def warning(self, *args, **kwargs) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
def exception(self, *args, **kwargs) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
def debug(self, *args, **kwargs) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouterTestContext:
|
||||
llm_key: str
|
||||
router_config: LLMRouterConfig
|
||||
logger: DummyLogger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def router_test_context(
|
||||
monkeypatch: MonkeyPatch,
|
||||
*,
|
||||
llm_key: str,
|
||||
primary_group: str,
|
||||
fallback_group: str,
|
||||
routing_strategy: str = "simple-shuffle",
|
||||
) -> Iterator[RouterTestContext]:
|
||||
router_config = LLMRouterConfig(
|
||||
model_name="test-router",
|
||||
required_env_vars=[],
|
||||
supports_vision=False,
|
||||
add_assistant_prefix=False,
|
||||
model_list=[
|
||||
LLMRouterModelConfig(model_name=primary_group, litellm_params={"model": primary_group}),
|
||||
LLMRouterModelConfig(model_name=fallback_group, litellm_params={"model": fallback_group}),
|
||||
],
|
||||
redis_host="localhost",
|
||||
redis_port=6379,
|
||||
redis_password="",
|
||||
main_model_group=primary_group,
|
||||
fallback_model_group=fallback_group,
|
||||
routing_strategy=routing_strategy,
|
||||
num_retries=0,
|
||||
disable_cooldowns=True,
|
||||
temperature=None,
|
||||
)
|
||||
|
||||
LLMConfigRegistry._configs.pop(llm_key, None) # type: ignore[attr-defined]
|
||||
LLMConfigRegistry.register_config(llm_key, router_config)
|
||||
|
||||
logger = DummyLogger()
|
||||
monkeypatch.setattr(api_handler_factory, "LOG", logger)
|
||||
|
||||
async def fake_llm_messages_builder(prompt: str, screenshots, add_assistant_prefix: bool) -> list[dict[str, str]]:
|
||||
return [{"role": "user", "content": prompt}]
|
||||
|
||||
monkeypatch.setattr(api_handler_factory, "llm_messages_builder", fake_llm_messages_builder)
|
||||
monkeypatch.setattr(api_handler_factory.skyvern_context, "current", lambda: None)
|
||||
monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda completion_response: 0.0)
|
||||
|
||||
try:
|
||||
yield RouterTestContext(llm_key=llm_key, router_config=router_config, logger=logger)
|
||||
finally:
|
||||
LLMConfigRegistry._configs.pop(llm_key, None) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def make_organization(now: datetime) -> Organization:
|
||||
return Organization(
|
||||
organization_id="org-123",
|
||||
organization_name="Org",
|
||||
webhook_callback_url=None,
|
||||
max_steps_per_run=None,
|
||||
max_retries_per_step=None,
|
||||
domain=None,
|
||||
bw_organization_id=None,
|
||||
bw_collection_ids=None,
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
)
|
||||
|
||||
|
||||
def make_task(now: datetime, organization: Organization, **overrides: Any) -> Task:
|
||||
base: dict[str, Any] = {
|
||||
"title": "Task",
|
||||
"url": "https://example.com",
|
||||
"webhook_callback_url": None,
|
||||
"webhook_failure_reason": None,
|
||||
"totp_verification_url": None,
|
||||
"totp_identifier": None,
|
||||
"navigation_goal": "Find the quote",
|
||||
"data_extraction_goal": "Extract the quote",
|
||||
"navigation_payload": None,
|
||||
"error_code_mapping": None,
|
||||
"proxy_location": None,
|
||||
"extracted_information_schema": None,
|
||||
"extra_http_headers": None,
|
||||
"complete_criterion": None,
|
||||
"terminate_criterion": None,
|
||||
"task_type": TaskType.general,
|
||||
"application": None,
|
||||
"include_action_history_in_verification": False,
|
||||
"max_screenshot_scrolls": None,
|
||||
"browser_address": None,
|
||||
"download_timeout": None,
|
||||
"created_at": now,
|
||||
"modified_at": now,
|
||||
"task_id": "task-123",
|
||||
"status": TaskStatus.running,
|
||||
"extracted_information": None,
|
||||
"failure_reason": None,
|
||||
"organization_id": organization.organization_id,
|
||||
"workflow_run_id": None,
|
||||
"workflow_permanent_id": None,
|
||||
"browser_session_id": None,
|
||||
"order": 0,
|
||||
"retry": 0,
|
||||
"max_steps_per_run": None,
|
||||
"errors": [],
|
||||
"model": None,
|
||||
"queued_at": now,
|
||||
"started_at": now,
|
||||
"finished_at": None,
|
||||
}
|
||||
base.update(overrides)
|
||||
return Task(**base)
|
||||
|
||||
|
||||
def make_step(
|
||||
now: datetime,
|
||||
task: Task,
|
||||
*,
|
||||
step_id: str,
|
||||
status: StepStatus,
|
||||
order: int,
|
||||
output,
|
||||
is_last: bool = False,
|
||||
retry_index: int = 0,
|
||||
organization_id: str | None = None,
|
||||
**overrides: Any,
|
||||
) -> Step:
|
||||
base: dict[str, Any] = {
|
||||
"created_at": now,
|
||||
"modified_at": now,
|
||||
"task_id": task.task_id,
|
||||
"step_id": step_id,
|
||||
"status": status,
|
||||
"output": output,
|
||||
"order": order,
|
||||
"is_last": is_last,
|
||||
"retry_index": retry_index,
|
||||
"organization_id": organization_id or task.organization_id,
|
||||
}
|
||||
base.update(overrides)
|
||||
return Step(**base)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelVerificationMocks:
|
||||
create_step: AsyncMock
|
||||
get_task_steps: AsyncMock
|
||||
sleep: AsyncMock
|
||||
check_user_goal_complete: AsyncMock
|
||||
handle_action: AsyncMock
|
||||
create_extract_action: AsyncMock | None
|
||||
speculate_next_step_plan: AsyncMock
|
||||
persist_speculative_metadata: AsyncMock
|
||||
cancel_speculative_step: AsyncMock
|
||||
record_artifacts_after_action: AsyncMock
|
||||
update_step: AsyncMock
|
||||
update_task: AsyncMock
|
||||
|
||||
|
||||
def setup_parallel_verification_mocks(
|
||||
agent: ForgeAgent,
|
||||
*,
|
||||
step: Step,
|
||||
task: Task,
|
||||
monkeypatch: MonkeyPatch,
|
||||
next_step: Step | None,
|
||||
complete_action,
|
||||
handle_action_responses: Sequence[Any],
|
||||
extract_action: Any | None = None,
|
||||
) -> ParallelVerificationMocks:
|
||||
create_step_mock = AsyncMock(return_value=next_step)
|
||||
monkeypatch.setattr(app.DATABASE, "create_step", create_step_mock)
|
||||
|
||||
get_task_steps_mock = AsyncMock(return_value=[step])
|
||||
monkeypatch.setattr(app.DATABASE, "get_task_steps", get_task_steps_mock)
|
||||
|
||||
sleep_mock = AsyncMock(return_value=None)
|
||||
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", sleep_mock)
|
||||
|
||||
check_user_goal_complete_mock = AsyncMock(return_value=complete_action)
|
||||
monkeypatch.setattr(agent, "check_user_goal_complete", check_user_goal_complete_mock)
|
||||
|
||||
handle_action_mock = AsyncMock(side_effect=handle_action_responses)
|
||||
monkeypatch.setattr("skyvern.forge.agent.ActionHandler.handle_action", handle_action_mock)
|
||||
|
||||
speculate_mock = AsyncMock(return_value=None)
|
||||
monkeypatch.setattr(agent, "_speculate_next_step_plan", speculate_mock)
|
||||
|
||||
persist_mock = AsyncMock()
|
||||
monkeypatch.setattr(agent, "_persist_speculative_metadata_for_discarded_plan", persist_mock)
|
||||
|
||||
cancel_mock = AsyncMock()
|
||||
monkeypatch.setattr(agent, "_cancel_speculative_step", cancel_mock)
|
||||
|
||||
record_artifacts_mock = AsyncMock()
|
||||
monkeypatch.setattr(agent, "record_artifacts_after_action", record_artifacts_mock)
|
||||
|
||||
update_step_mock = AsyncMock(return_value=step)
|
||||
monkeypatch.setattr(agent, "update_step", update_step_mock)
|
||||
|
||||
update_task_mock = AsyncMock(return_value=task)
|
||||
monkeypatch.setattr(agent, "update_task", update_task_mock)
|
||||
|
||||
if extract_action is not None:
|
||||
create_extract_action_mock = AsyncMock(return_value=extract_action)
|
||||
monkeypatch.setattr(agent, "create_extract_action", create_extract_action_mock)
|
||||
else:
|
||||
create_extract_action_mock = None
|
||||
|
||||
return ParallelVerificationMocks(
|
||||
create_step=create_step_mock,
|
||||
get_task_steps=get_task_steps_mock,
|
||||
sleep=sleep_mock,
|
||||
check_user_goal_complete=check_user_goal_complete_mock,
|
||||
handle_action=handle_action_mock,
|
||||
create_extract_action=create_extract_action_mock,
|
||||
speculate_next_step_plan=speculate_mock,
|
||||
persist_speculative_metadata=persist_mock,
|
||||
cancel_speculative_step=cancel_mock,
|
||||
record_artifacts_after_action=record_artifacts_mock,
|
||||
update_step=update_step_mock,
|
||||
update_task=update_task_mock,
|
||||
)
|
||||
|
||||
|
||||
def make_browser_state() -> tuple[MagicMock, MagicMock, MagicMock]:
|
||||
return MagicMock(), MagicMock(), MagicMock()
|
||||
Reference in New Issue
Block a user