Source code for impact_engine_measure.models.base

"""Base interface for impact models."""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List

import pandas as pd


@dataclass
class ModelResult:
    """Uniform container for the output of an impact model.

    Every model hands back this same structure, so the manager can persist
    results in one way while the models themselves know nothing about storage.

    The ``data`` dict must use three standardized keys:

    - model_params: input parameters used (formula, intervention_date, etc.)
    - impact_estimates: the treatment effect measurements
    - model_summary: fit diagnostics, sample sizes, configuration echo

    Attributes
    ----------
    model_type:
        Identifier of the model that produced this result.
    data:
        Primary result payload with keys model_params, impact_estimates
        and model_summary.
    metadata:
        Information about the model run (populated by the manager).
        Defaults to a fresh empty dict per instance.
    artifacts:
        Supplementary DataFrames to persist (e.g., per-product details).
        Keys are format-agnostic names; the manager prefixes them with
        model_type and appends the file extension. Defaults to a fresh
        empty dict per instance.
    """

    model_type: str
    data: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)
    artifacts: Dict[str, pd.DataFrame] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Build the storage/serialization envelope for this result.

        Returns
        -------
        dict
            Mapping with the keys ``model_type``, ``data`` and ``metadata``,
            in that order. The model-specific payload stays nested under
            ``data`` (it is not spread into the envelope), and ``artifacts``
            is not included. Values are the instance's own objects, not
            copies.
        """
        # Order of this tuple fixes the key order of the envelope.
        envelope_keys = ("model_type", "data", "metadata")
        return {key: getattr(self, key) for key in envelope_keys}
class ModelInterface(ABC):
    """Abstract base class shared by all impact models.

    Provides one unified interface so that different modeling approaches
    (interrupted time series, causal inference, metrics approximation, etc.)
    behave consistently.

    Required methods (must override):

    - connect: initialize model with configuration
    - fit: fit model to data
    - validate_params: validate model-specific parameters before fitting

    Optional methods (have sensible defaults):

    - validate_connection: check if model is ready
    - validate_data: check if input data is valid
    - get_required_columns: return list of required columns
    - get_fit_params: select the parameters passed on to fit()
    - transform_outbound: transform data to external format
    - transform_inbound: transform results from external format
    """

    # ------------------------------------------------------------------
    # Abstract methods: every concrete model must implement these.
    # ------------------------------------------------------------------

    @abstractmethod
    def connect(self, config: Dict[str, Any]) -> bool:
        """Initialize the model from configuration parameters.

        Parameters
        ----------
        config : dict
            Dictionary containing model configuration parameters.

        Returns
        -------
        bool
            True if initialization was successful, False otherwise.
        """
        ...

    @abstractmethod
    def fit(self, data: pd.DataFrame, **kwargs) -> Any:
        """Fit the model to the supplied data.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing data for model fitting.
        **kwargs
            Model-specific parameters (e.g., intervention_date,
            dependent_variable).

        Returns
        -------
        Any
            Model-specific results (Dict, str path, etc.).

        Raises
        ------
        ValueError
            If data validation fails or required columns are missing.
        RuntimeError
            If model fitting fails.
        """
        ...

    @abstractmethod
    def validate_params(self, params: Dict[str, Any]) -> None:
        """Validate model-specific parameters before fitting.

        ModelsManager calls this ahead of fit() so that missing or invalid
        parameters are reported early. Every model implementation MUST
        override it. Centralized config validation (process_config) covers
        the known models; this hook makes sure custom/user-defined models
        validate as well.

        Parameters
        ----------
        params : dict
            Parameters that will be passed to fit(). Typical keys:
            intervention_date, dependent_variable.

        Raises
        ------
        ValueError
            If required parameters are missing or invalid.
        """
        ...

    # ------------------------------------------------------------------
    # Optional hooks: defaults below, override when a model needs more.
    # ------------------------------------------------------------------

    def validate_connection(self) -> bool:
        """Report whether the model is initialized and ready to use.

        The default always reports ready. Override for custom validation.

        Returns
        -------
        bool
            True if the model is ready, False otherwise.
        """
        return True

    def validate_data(self, data: pd.DataFrame) -> bool:
        """Check that the input data meets the model's requirements.

        The default only requires the data to be present and non-empty.
        Override for custom validation.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame to validate.

        Returns
        -------
        bool
            True if data is valid, False otherwise.
        """
        if data is None:
            return False
        return not data.empty

    def get_required_columns(self) -> List[str]:
        """Return the column names this model needs in its input data.

        The default is an empty list. Override if the model requires
        specific columns.

        Returns
        -------
        list of str
            Column names that must be present in input data.
        """
        no_columns: List[str] = []
        return no_columns

    def get_fit_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Select the parameters accepted by this adapter's fit().

        ModelsManager calls this before fit() to prevent cross-model
        parameter pollution. The default keeps every parameter (backward
        compatible) and returns them in a new dict; built-in adapters
        override it.

        Parameters
        ----------
        params : dict
            Full params dict (config PARAMS merged with caller overrides).

        Returns
        -------
        dict
            Filtered dict for fit().
        """
        accepted = dict(params)  # shallow copy, caller's dict is not returned
        return accepted

    def transform_outbound(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
        """Convert impact engine format to the model library's format.

        The default is a pass-through: the DataFrame is placed under the
        ``"data"`` key and the keyword arguments follow it unchanged.
        Override for models that need data transformation.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame in the impact engine standardized format.
        **kwargs
            Additional model-specific parameters.

        Returns
        -------
        dict
            Dictionary with parameters formatted for the model library.
        """
        outbound: Dict[str, Any] = {"data": data}
        outbound.update(kwargs)
        return outbound

    def transform_inbound(self, model_results: Any) -> Dict[str, Any]:
        """Convert model library results to impact engine format.

        The default returns a dict result as-is and wraps anything else
        under the ``"results"`` key. Override for models that need result
        transformation.

        Parameters
        ----------
        model_results : Any
            Raw results from the model library.

        Returns
        -------
        dict
            Dictionary with standardized impact analysis results.
        """
        return (
            model_results
            if isinstance(model_results, dict)
            else {"results": model_results}
        )