Files
Dorod-Sky/tests/unit/test_prompt_caching_settings.py

59 lines
2.0 KiB
Python
Raw Normal View History

from unittest.mock import AsyncMock
import pytest
from skyvern.forge import app
from skyvern.forge.agent import (
EXTRACT_ACTION_PROMPT_NAME,
EXTRACT_ACTION_TEMPLATE,
ForgeAgent,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@pytest.mark.asyncio
async def test_prompt_caching_settings_respect_experiment(monkeypatch):
    """Without a factory override, settings follow the experiment flag.

    The experimentation provider is replaced by an ``AsyncMock`` whose
    ``is_feature_enabled_cached`` resolves to ``True``. The test checks that
    the agent enables caching for both extract-action prompt identifiers,
    that the provider is awaited with the run id and organization id taken
    from the context, and that a second lookup on the same context does not
    await the provider again.
    """
    forge_agent = ForgeAgent()
    ctx = SkyvernContext(run_id="wr_123", organization_id="org_456")

    provider = AsyncMock()
    flag_lookup = provider.is_feature_enabled_cached
    flag_lookup.return_value = True
    monkeypatch.setattr(app, "EXPERIMENTATION_PROVIDER", provider)

    expected_settings = {
        EXTRACT_ACTION_PROMPT_NAME: True,
        EXTRACT_ACTION_TEMPLATE: True,
    }

    try:
        # Start with no override so the experiment path is exercised.
        LLMAPIHandlerFactory.set_prompt_caching_settings(None)

        resolved = await forge_agent._get_prompt_caching_settings(ctx)
        assert resolved == expected_settings

        flag_lookup.assert_awaited_once_with(
            "PROMPT_CACHING_OPTIMIZATION",
            "wr_123",
            properties={"organization_id": "org_456"},
        )

        # Cached on context; no second provider call
        await forge_agent._get_prompt_caching_settings(ctx)
        flag_lookup.assert_awaited_once()
    finally:
        # Always reset the factory-level override for subsequent tests.
        LLMAPIHandlerFactory.set_prompt_caching_settings(None)
@pytest.mark.asyncio
async def test_prompt_caching_settings_use_override(monkeypatch):
    """An override set on the factory is returned without asking the provider.

    After ``LLMAPIHandlerFactory.set_prompt_caching_settings`` is given an
    explicit mapping, the agent must return an equal mapping and the mocked
    experimentation provider's ``is_feature_enabled_cached`` must never be
    called.
    """
    forge_agent = ForgeAgent()
    ctx = SkyvernContext(run_id="wr_789", organization_id="org_987")

    provider = AsyncMock()
    monkeypatch.setattr(app, "EXPERIMENTATION_PROVIDER", provider)

    expected_override = {"extract-actions": True}

    try:
        # Hand the factory its own copy so the expectation stays independent.
        LLMAPIHandlerFactory.set_prompt_caching_settings(dict(expected_override))

        resolved = await forge_agent._get_prompt_caching_settings(ctx)
        assert resolved == expected_override

        provider.is_feature_enabled_cached.assert_not_called()
    finally:
        # Always reset the factory-level override for subsequent tests.
        LLMAPIHandlerFactory.set_prompt_caching_settings(None)