# Source code for online_retail_simulator.simulate.rule_registry

"""
Rule-based simulation registry for custom user-defined simulation functions.

This module provides a registration system that allows users to register their own
rule-based simulation functions for both products and sales metrics.
"""

from typing import Callable, Dict, List

from online_retail_simulator.core import FunctionRegistry


def _load_products_defaults(registry: FunctionRegistry) -> None:
    """
    Register the built-in rule-based products simulator on *registry*.

    The implementation is imported inside the function, so the
    ``products_rule_based`` module is only loaded when this loader runs.
    """
    from .products_rule_based import simulate_products_rule_based as builtin_impl

    builtin_name = "simulate_products_rule_based"
    registry.register(builtin_name, builtin_impl)


def _load_metrics_defaults(registry: FunctionRegistry) -> None:
    """
    Register the built-in rule-based metrics simulator on *registry*.

    The implementation is imported inside the function, so the
    ``metrics_rule_based`` module is only loaded when this loader runs.
    """
    from .metrics_rule_based import simulate_metrics_rule_based as builtin_impl

    builtin_name = "simulate_metrics_rule_based"
    registry.register(builtin_name, builtin_impl)


# Registry instances: one per simulation function category.
# ``required_params`` names the parameters a function in that category is
# expected to accept (presumably enforced by FunctionRegistry on
# registration — confirm). ``default_loader`` registers the built-in
# rule-based implementation for the category.
_products_registry = FunctionRegistry(
    name="products",
    default_loader=_load_products_defaults,
    required_params={"config"},
)

_metrics_registry = FunctionRegistry(
    name="metrics",
    default_loader=_load_metrics_defaults,
    required_params={"products", "config"},
)


# Public API functions
def register_products_function(name: str, func: Callable) -> None:
    """
    Add *func* to the products registry under *name*.

    Args:
        name: Key under which the function can later be retrieved.
        func: Products generation callable. Any signature validation is
            delegated to the products ``FunctionRegistry.register``.
    """
    registry = _products_registry
    registry.register(name, func)
def register_metrics_function(name: str, func: Callable) -> None:
    """
    Add *func* to the metrics registry under *name*.

    Args:
        name: Key under which the function can later be retrieved.
        func: Metrics generation callable. Any signature validation is
            delegated to the metrics ``FunctionRegistry.register``.
    """
    registry = _metrics_registry
    registry.register(name, func)
def register_simulation_module(module_name: str, prefix: str = "") -> None:
    """
    Register every compatible function found in a module.

    Functions are routed to a registry based on their parameter names:

    - Products functions: have a ``config`` parameter and NO ``products``
      parameter.
    - Metrics functions: have both ``products`` and ``config`` parameters.

    The two rules are mutually exclusive, so a given function lands in at
    most one registry.

    Args:
        module_name: Module to scan; forwarded to each registry's
            ``register_from_module``.
        prefix: Forwarded unchanged to ``register_from_module``
            (presumably prepended to the registered names — confirm).
    """

    def _looks_like_products_fn(params) -> bool:
        # Exclude anything taking ``products`` so metrics functions are
        # not also registered as products functions.
        return "config" in params and "products" not in params

    def _looks_like_metrics_fn(params) -> bool:
        return "products" in params and "config" in params

    # Products first, then metrics, matching the original registration order.
    _products_registry.register_from_module(
        module_name,
        prefix,
        signature_filter=_looks_like_products_fn,
    )
    _metrics_registry.register_from_module(
        module_name,
        prefix,
        signature_filter=_looks_like_metrics_fn,
    )
def get_simulation_function(func_type: str, name: str) -> Callable:
    """
    Look up a registered simulation function by category and name.

    Args:
        func_type: Registry to search; must be ``'products'`` or ``'metrics'``.
        name: Name the function was registered under.

    Returns:
        Whatever the selected registry's ``get(name)`` returns.

    Raises:
        ValueError: If ``func_type`` is neither ``'products'`` nor
            ``'metrics'``. Errors raised by the registry's ``get`` (for
            example, for an unknown ``name``) propagate unchanged.
    """
    if func_type == "products":
        registry = _products_registry
    elif func_type == "metrics":
        registry = _metrics_registry
    else:
        raise ValueError(f"Invalid function type: {func_type}. Must be 'products' or 'metrics'")
    return registry.get(name)
def list_simulation_functions() -> Dict[str, List[str]]:
    """
    Return the names registered in each simulation registry.

    Returns:
        A dict with keys ``"products"`` and ``"metrics"``, each mapped to
        that registry's ``list()`` result.
    """
    products_names = _products_registry.list()
    metrics_names = _metrics_registry.list()
    return {"products": products_names, "metrics": metrics_names}
def clear_simulation_registry() -> None:
    """
    Clear both simulation registries (useful for testing).

    The products registry is cleared first, then the metrics registry.
    """
    for registry in (_products_registry, _metrics_registry):
        registry.clear()