"""
Impact-based enrichment registry for custom user-defined enrichment functions.
This module provides a registration system that allows users to register their own
impact-based enrichment functions.
"""
from typing import Callable, List
from online_retail_simulator.core import FunctionRegistry
def _load_enrichment_defaults(registry: FunctionRegistry) -> None:
    """Register the built-in enrichment functions on ``registry``.

    Args:
        registry: Registry that receives the default functions through its
            ``register(name, func)`` method.
    """
    # Imported inside the function so the library module is only loaded
    # when this loader is actually invoked.
    from .enrichment_library import (
        probability_boost,
        product_detail_boost,
        quantity_boost,
    )

    # Each function is registered under its own name; the tuple order is the
    # registration order.
    builtin_functions = (
        ("quantity_boost", quantity_boost),
        ("probability_boost", probability_boost),
        ("product_detail_boost", product_detail_boost),
    )
    for registered_name, enrichment_func in builtin_functions:
        registry.register(registered_name, enrichment_func)
# Module-level registry instance used by the public helper functions in this
# file. ``_load_enrichment_defaults`` is handed over as the default loader.
# NOTE(review): ``required_params={"metrics"}`` presumably means registered
# functions must accept a ``metrics`` parameter -- confirm in FunctionRegistry.
_enrichment_registry = FunctionRegistry(
    name="enrichment",
    required_params={"metrics"},
    default_loader=_load_enrichment_defaults,
)
# Public API functions
def register_enrichment_function(name: str, func: Callable) -> None:
    """Register an enrichment function in the module-level registry.

    Args:
        name: Key under which ``func`` is stored in the registry.
        func: Callable to register. Any validation is performed by
            ``FunctionRegistry.register`` (presumably against the registry's
            required ``metrics`` parameter -- confirm in FunctionRegistry).
    """
    _enrichment_registry.register(name, func)
def register_enrichment_module(module_name: str) -> None:
    """Register all compatible functions from a module.

    Delegates to ``FunctionRegistry.register_from_module`` on the
    module-level enrichment registry; that method decides which functions
    count as compatible.

    Args:
        module_name: Name of the module to scan (presumably an importable
            dotted path -- confirm against FunctionRegistry).
    """
    _enrichment_registry.register_from_module(module_name)
def list_enrichment_functions() -> List[str]:
    """List all registered enrichment functions.

    Returns:
        Whatever ``FunctionRegistry.list()`` reports for the module-level
        enrichment registry, annotated here as a list of names.
    """
    return _enrichment_registry.list()
def clear_enrichment_registry() -> None:
    """Clear all registered enrichment functions.

    Delegates to ``FunctionRegistry.clear()`` on the module-level enrichment
    registry.
    """
    _enrichment_registry.clear()
def load_effect_function(module_name: str, function_name: str) -> Callable:
    """Load a treatment effect function from the enrichment registry.

    Args:
        module_name: Name of module (ignored, kept for backward
            compatibility).
        function_name: Name of the function to look up in the registry.

    Returns:
        The treatment effect function returned by the registry for
        ``function_name``.
    """
    # ``module_name`` is intentionally unused: the lookup goes through the
    # registry by function name only.
    return _enrichment_registry.get(function_name)