"""Base interface for impact models."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List
import pandas as pd
@dataclass
class ModelResult:
    """Standardized model result container.

    All models return this structure, allowing the manager to handle
    storage uniformly while models remain storage-agnostic.

    The ``data`` dict must use three standardized keys:

    - model_params: Input parameters used (formula, intervention_date, etc.)
    - impact_estimates: The treatment effect measurements
    - model_summary: Fit diagnostics, sample sizes, configuration echo

    This class does not itself validate that those keys are present.

    Attributes
    ----------
    model_type : str
        Identifier for the model that produced this result.
    data : dict
        Primary result data with keys: model_params, impact_estimates,
        model_summary.
    metadata : dict
        Metadata about the model run (populated by the manager).
    artifacts : dict of str to pd.DataFrame
        Supplementary DataFrames to persist (e.g., per-product details).
        Keys are format-agnostic names; the manager prefixes with model_type
        and appends the file extension.
    """

    model_type: str
    data: Dict[str, Any]
    # default_factory gives every instance its own dict instead of a shared one.
    metadata: Dict[str, Any] = field(default_factory=dict)
    artifacts: Dict[str, pd.DataFrame] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage/serialization.

        Returns an envelope with model_type, data, and metadata.
        The ``data`` key contains the model-specific payload (nested, not
        spread). ``artifacts`` is not part of the envelope.

        Returns
        -------
        dict
            Envelope with keys ``model_type``, ``data`` and ``metadata``.
            ``data`` and ``metadata`` are the instance's own dict objects,
            not copies.
        """
        return {
            "model_type": self.model_type,
            "data": self.data,
            "metadata": self.metadata,
        }
class ModelInterface(ABC):
    """Abstract base class for impact models.

    Defines the unified interface that all impact models must implement.
    This ensures consistent behavior across different modeling approaches
    (interrupted time series, causal inference, metrics approximation, etc.).

    Required methods (must override):

    - connect: Initialize model with configuration
    - fit: Fit model to data
    - validate_params: Validate model-specific parameters before fitting

    Optional methods (have sensible defaults):

    - validate_connection: Check if model is ready
    - validate_data: Check if input data is valid
    - get_required_columns: Return list of required columns
    - get_fit_params: Filter parameters down to those accepted by fit()
    - transform_outbound: Transform data to external format
    - transform_inbound: Transform results from external format
    """

    @abstractmethod
    def connect(self, config: Dict[str, Any]) -> bool:
        """Initialize model with configuration parameters.

        Parameters
        ----------
        config : dict
            Dictionary containing model configuration parameters.

        Returns
        -------
        bool
            True if initialization successful, False otherwise.
        """

    @abstractmethod
    def fit(self, data: pd.DataFrame, **kwargs) -> Any:
        """Fit the model to the provided data.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing data for model fitting.
        **kwargs
            Model-specific parameters (e.g., intervention_date,
            dependent_variable).

        Returns
        -------
        Any
            Model-specific results (Dict, str path, etc.)

        Raises
        ------
        ValueError
            If data validation fails or required columns are missing.
        RuntimeError
            If model fitting fails.
        """

    def validate_connection(self) -> bool:
        """Validate that the model is properly initialized and ready to use.

        Default implementation returns True. Override for custom validation.

        Returns
        -------
        bool
            True if model is ready, False otherwise.
        """
        return True

    def validate_data(self, data: pd.DataFrame) -> bool:
        """Validate that the input data meets model requirements.

        Default implementation checks if data is non-empty.
        Override for custom validation.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame to validate.

        Returns
        -------
        bool
            True if data is valid, False otherwise.
        """
        # The None check comes first so that ``.empty`` is never read on None.
        return data is not None and not data.empty

    def get_required_columns(self) -> List[str]:
        """Get the list of required columns for this model.

        Default implementation returns empty list.
        Override if model requires specific columns.

        Returns
        -------
        list of str
            Column names that must be present in input data.
        """
        return []

    @abstractmethod
    def validate_params(self, params: Dict[str, Any]) -> None:
        """Validate model-specific parameters before fitting.

        This method is called by ModelsManager before fit() to perform
        early validation of required parameters. All model implementations
        MUST override this method to validate their specific parameters.

        Centralized config validation (process_config) handles known models,
        but this method ensures custom/user-defined models also validate.

        Parameters
        ----------
        params : dict
            Dictionary containing parameters that will be passed to fit().
            Typical keys: intervention_date, dependent_variable.

        Raises
        ------
        ValueError
            If required parameters are missing or invalid.
        """

    def get_fit_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Filter parameters to only those accepted by this adapter's fit().

        Called by ModelsManager before fit() to prevent cross-model param
        pollution. Default returns all params (backward compatible).
        Built-in adapters override.

        Parameters
        ----------
        params : dict
            Full params dict (config PARAMS merged with caller overrides).

        Returns
        -------
        dict
            Filtered dict for fit(). The default returns a new dict with the
            same items (a shallow copy), not the caller's own dict object.
        """
        return dict(params)