Source code for impact_engine_measure.models.experiment.adapter

"""Experiment Model Adapter - thin wrapper around statsmodels OLS with R-style formulas."""

import logging
from typing import Any, Dict, List

import pandas as pd
import statsmodels.formula.api as smf

from ..base import ModelInterface, ModelResult
from ..factory import MODEL_REGISTRY


[docs] @MODEL_REGISTRY.register_decorator("experiment") class ExperimentAdapter(ModelInterface): """Estimates treatment effects via OLS regression with R-style formulas. Constraints: - formula parameter required in MEASUREMENT.PARAMS - DataFrame must contain all variables referenced in the formula """
[docs] def __init__(self): """Initialize the ExperimentAdapter.""" self.logger = logging.getLogger(__name__) self.is_connected = False self.config = None
[docs] def connect(self, config: Dict[str, Any]) -> bool: """Initialize model with configuration parameters. Config is pre-validated with defaults merged via process_config(). """ formula = config.get("formula") if not formula or not isinstance(formula, str): raise ValueError( "formula is required and must be a string (e.g., 'y ~ treatment + x1'). " "Specify in MEASUREMENT.PARAMS configuration." ) self.config = {"formula": formula} self.is_connected = True return True
[docs] def validate_connection(self) -> bool: """Validate that the model is properly initialized and ready to use.""" if not self.is_connected: return False try: import statsmodels.formula.api # noqa: F401 return True except ImportError: return False
[docs] def validate_params(self, params: Dict[str, Any]) -> None: """Validate experiment-specific parameters. Parameters ---------- params : dict Parameters dict with formula, etc. Raises ------ ValueError If formula is missing. """ formula = params.get("formula") if not formula or not isinstance(formula, str): raise ValueError("formula is required for ExperimentAdapter. Specify in MEASUREMENT.PARAMS configuration.")
_CONFIG_PARAMS = frozenset( { "dependent_variable", "intervention_date", "order", "seasonal_order", "n_strata", "estimand", "treatment_column", "covariate_columns", "formula", "metric_before_column", "metric_after_column", "baseline_column", "RESPONSE", # Nearest neighbour matching params "caliper", "replace", "ratio", "shuffle", "random_state", "n_jobs", } )
[docs] def get_fit_params(self, params: Dict[str, Any]) -> Dict[str, Any]: """Exclude known config keys, pass library kwargs through to statsmodels.""" return {k: v for k, v in params.items() if k not in self._CONFIG_PARAMS}
[docs] def fit(self, data: pd.DataFrame, **kwargs) -> ModelResult: """Fit OLS model using statsmodels formula API and return results. Parameters ---------- data : pd.DataFrame DataFrame containing all variables referenced in the formula. **kwargs Passed through to statsmodels OLS .fit() (e.g., cov_type='HC3' for robust standard errors). Returns ------- ModelResult Standardized result container. Raises ------ ConnectionError If model not connected. RuntimeError If model fitting fails. """ if not self.is_connected: raise ConnectionError("Model not connected. Call connect() first.") if not self.validate_data(data): raise ValueError("Data validation failed. DataFrame must not be empty.") formula = self.config["formula"] try: model = smf.ols(formula, data=data) results = model.fit(**kwargs) # Extract confidence intervals as a nested dict conf_int_df = results.conf_int() conf_int = { var: [float(conf_int_df.loc[var, 0]), float(conf_int_df.loc[var, 1])] for var in conf_int_df.index } impact_estimates = { "params": {k: float(v) for k, v in results.params.items()}, "bse": {k: float(v) for k, v in results.bse.items()}, "tvalues": {k: float(v) for k, v in results.tvalues.items()}, "pvalues": {k: float(v) for k, v in results.pvalues.items()}, "conf_int": conf_int, } model_summary = { "rsquared": float(results.rsquared), "rsquared_adj": float(results.rsquared_adj), "fvalue": float(results.fvalue), "f_pvalue": float(results.f_pvalue), "nobs": int(results.nobs), "df_resid": float(results.df_resid), } self.logger.info( f"Experiment model fit complete: formula='{formula}', " f"nobs={int(results.nobs)}, R²={results.rsquared:.4f}" ) return ModelResult( model_type="experiment", data={ "model_params": {"formula": formula}, "impact_estimates": impact_estimates, "model_summary": model_summary, }, ) except Exception as e: self.logger.error(f"Error fitting ExperimentAdapter: {e}") raise RuntimeError(f"Model fitting failed: {e}") from e
[docs] def get_required_columns(self) -> List[str]: """Get required column names. Returns empty list; statsmodels validates formula variables against the DataFrame natively. """ return []