diff --git a/skyvern/forge/sdk/experimentation/providers.py b/skyvern/forge/sdk/experimentation/providers.py index d78b8ddc..c6ee4bca 100644 --- a/skyvern/forge/sdk/experimentation/providers.py +++ b/skyvern/forge/sdk/experimentation/providers.py @@ -11,7 +11,7 @@ EXPERIMENTATION_CACHE_MAX_SIZE = 100000 # Max entries per cache class BaseExperimentationProvider(ABC): def __init__(self) -> None: - # feature_name -> distinct_id -> result with TTL-based expiration + # Cache with composite key (feature_name, distinct_id) for per-entry TTL expiration self.result_map: TTLCache = TTLCache(maxsize=EXPERIMENTATION_CACHE_MAX_SIZE, ttl=EXPERIMENTATION_CACHE_TTL) self.variant_map: TTLCache = TTLCache(maxsize=EXPERIMENTATION_CACHE_MAX_SIZE, ttl=EXPERIMENTATION_CACHE_TTL) self.payload_map: TTLCache = TTLCache(maxsize=EXPERIMENTATION_CACHE_MAX_SIZE, ttl=EXPERIMENTATION_CACHE_TTL) @@ -23,15 +23,14 @@ class BaseExperimentationProvider(ABC): async def is_feature_enabled_cached( self, feature_name: str, distinct_id: str, properties: dict | None = None ) -> bool: - if feature_name not in self.result_map: - self.result_map[feature_name] = {} - if distinct_id not in self.result_map[feature_name]: + cache_key = (feature_name, distinct_id) + if cache_key not in self.result_map: feature_flag_value = await self.is_feature_enabled(feature_name, distinct_id, properties) - self.result_map[feature_name][distinct_id] = feature_flag_value + self.result_map[cache_key] = feature_flag_value if feature_flag_value: LOG.info("Feature flag is enabled", flag=feature_name, distinct_id=distinct_id) - return self.result_map[feature_name][distinct_id] + return self.result_map[cache_key] @abstractmethod async def get_value(self, feature_name: str, distinct_id: str, properties: dict | None = None) -> str | None: @@ -43,27 +42,25 @@ class BaseExperimentationProvider(ABC): async def get_value_cached(self, feature_name: str, distinct_id: str, properties: dict | None = None) -> str | None: """Get the value of a feature.""" - if feature_name not in self.variant_map: - self.variant_map[feature_name] = {} - if distinct_id not in self.variant_map[feature_name]: + cache_key = (feature_name, distinct_id) + if cache_key not in self.variant_map: variant = await self.get_value(feature_name, distinct_id, properties) - self.variant_map[feature_name][distinct_id] = variant + self.variant_map[cache_key] = variant if variant: LOG.info("Feature is found", flag=feature_name, distinct_id=distinct_id, variant=variant) - return self.variant_map[feature_name][distinct_id] + return self.variant_map[cache_key] async def get_payload_cached( self, feature_name: str, distinct_id: str, properties: dict | None = None ) -> str | None: """Get the payload for a feature flag if it exists.""" - if feature_name not in self.payload_map: - self.payload_map[feature_name] = {} - if distinct_id not in self.payload_map[feature_name]: + cache_key = (feature_name, distinct_id) + if cache_key not in self.payload_map: payload = await self.get_payload(feature_name, distinct_id, properties) - self.payload_map[feature_name][distinct_id] = payload + self.payload_map[cache_key] = payload if payload: LOG.info("Feature payload is found", flag=feature_name, distinct_id=distinct_id, payload=payload) - return self.payload_map[feature_name][distinct_id] + return self.payload_map[cache_key] class NoOpExperimentationProvider(BaseExperimentationProvider):