"""Metrics Approximation Adapter - approximates impact from metric changes.
This model approximates treatment impact by correlating metric changes
(e.g., quality score improvements) with expected outcome changes via
configurable response functions.
"""
import logging
from typing import Any, Dict, List
import pandas as pd
from ..base import ModelInterface, ModelResult
from ..factory import MODEL_REGISTRY
from .response_registry import get_response_function
def _normalize_result(result):
"""Normalize response function output to dict format."""
return result if isinstance(result, dict) else {"impact": result}
@MODEL_REGISTRY.register_decorator("metrics_approximation")
class MetricsApproximationAdapter(ModelInterface):
    """Adapter for metrics-based impact approximation that implements ModelInterface.

    This model takes enriched products with before/after metric values and baseline
    outcomes, then applies a response function to approximate the treatment impact.

    Input DataFrame must contain:

    - metric_before_column: Pre-intervention metric value
    - metric_after_column: Post-intervention metric value
    - baseline_column: Baseline sales/revenue

    Configuration::

        MEASUREMENT:
            MODEL: "metrics_approximation"
            METRIC_BEFORE_COLUMN: "quality_before"
            METRIC_AFTER_COLUMN: "quality_after"
            BASELINE_COLUMN: "baseline_sales"
            RESPONSE:
                FUNCTION: "linear"
                PARAMS:
                    coefficient: 0.5
    """

    def __init__(self):
        """Initialize the MetricsApproximationAdapter in a disconnected state."""
        self.logger = logging.getLogger(__name__)
        self.is_connected = False
        # Populated by connect(); None until then.
        self.config = None

    def connect(self, config: Dict[str, Any]) -> bool:
        """Initialize model with configuration parameters.

        Config is pre-validated with defaults merged via process_config().

        Parameters
        ----------
        config : dict
            Dictionary containing model configuration:

            - metric_before_column: Column name for pre-intervention metric
            - metric_after_column: Column name for post-intervention metric
            - baseline_column: Column name for baseline outcome
            - RESPONSE: Dict with FUNCTION name and optional PARAMS

        Returns
        -------
        bool
            True if initialization successful.

        Raises
        ------
        KeyError
            If one of the required top-level keys is missing from ``config``.
        ValueError
            If RESPONSE is not a dict, has no FUNCTION, or names a response
            function that ``get_response_function`` rejects.
        """
        # Config has defaults merged from process_config()
        metric_before = config["metric_before_column"]
        metric_after = config["metric_after_column"]
        baseline = config["baseline_column"]

        # Response config has defaults from config_defaults.yaml
        response_config = config["RESPONSE"]
        if not isinstance(response_config, dict):
            raise ValueError("RESPONSE must be a dict with FUNCTION key")

        function_name = response_config.get("FUNCTION")
        if not function_name:
            raise ValueError("RESPONSE must have FUNCTION key - FUNCTION is required")

        # Validate that the response function exists; keep the original
        # exception as the cause so the lookup failure stays visible.
        try:
            get_response_function(function_name)
        except ValueError as e:
            raise ValueError(f"Invalid response function: {e}") from e

        self.config = {
            "metric_before_column": metric_before,
            "metric_after_column": metric_after,
            "baseline_column": baseline,
            "response_function": function_name,
            # A PARAMS key that is present but empty yields None; fall back to
            # {} so fit() can always unpack it with **.
            "response_params": response_config.get("PARAMS") or {},
        }
        self.is_connected = True
        return True

    def validate_connection(self) -> bool:
        """Validate that the model is properly initialized and ready to use."""
        return self.is_connected and self.config is not None

    def validate_params(self, params: Dict[str, Any]) -> None:
        """Validate metrics approximation parameters.

        Metrics approximation has no required fit-time parameters beyond what's
        configured in connect(). This implementation satisfies the abstract method
        requirement while allowing all params.

        Parameters
        ----------
        params : dict
            Parameters dict (typically empty for this model).
        """
        # No required fit-time params for metrics approximation

    def get_fit_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Metrics approximation has no fit-time params from config.

        All configuration (column names, response function, response params)
        is stored in self.config during connect().
        """
        return {}

    def fit(self, data: pd.DataFrame, **kwargs) -> ModelResult:
        """Fit the metrics approximation model and return results.

        For each product, computes::

            delta_metric = metric_after - metric_before
            approximated_impact = response_function(delta_metric, baseline, row_attributes)

        Rows with missing values in any required column are dropped before
        the computation and reported in the ``filtered_products`` artifact.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame with enriched products (only treated products).
            Must contain metric_before, metric_after, and baseline columns.
            Additional columns are passed as row_attributes to response function.
        **kwargs
            Additional parameters passed to response function. They override
            same-named entries of the configured response params.

        Returns
        -------
        ModelResult
            Standardized result container (storage handled by manager).

        Raises
        ------
        ConnectionError
            If model not connected.
        ValueError
            If data validation fails.
        """
        if not self.validate_connection():
            raise ConnectionError("Model not connected. Call connect() first.")
        if not self.validate_data(data):
            raise ValueError(f"Data validation failed. Required columns: {self.get_required_columns()}")

        # Get column names from config
        metric_before_col = self.config["metric_before_column"]
        metric_after_col = self.config["metric_after_column"]
        baseline_col = self.config["baseline_column"]

        # Get response function and params (call-time kwargs win over config)
        response_fn = get_response_function(self.config["response_function"])
        response_params = {**self.config["response_params"], **kwargs}

        # Work on a copy to avoid modifying input data
        df = data.copy()

        # Filter rows with missing values in required columns
        required_columns = [metric_before_col, metric_after_col, baseline_col]
        df, filtered_ids_df = self._filter_missing_values(df, required_columns)

        artifacts = {}
        if not filtered_ids_df.empty:
            artifacts["filtered_products"] = filtered_ids_df

        if df.empty:
            # Keep the filtered-products artifact so the warning logged by
            # _filter_missing_values still points at real output.
            return self._empty_result(artifacts)

        # Vectorize delta computation
        df["_delta_metric"] = df[metric_after_col] - df[metric_before_col]

        # Pass row_attributes to enable attribute-based conditioning in response functions
        def compute_impact(row):
            result = response_fn(
                row["_delta_metric"],
                row[baseline_col],
                row_attributes=row.to_dict(),
                **response_params,
            )
            return _normalize_result(result)

        # Expand dict results into multiple DataFrame columns
        raw_results = df.apply(compute_impact, axis=1)
        result_df = pd.DataFrame(raw_results.tolist(), index=df.index)
        impact_keys = result_df.columns.tolist()
        df = pd.concat([df, result_df], axis=1)

        # Build per-product results with dynamic keys
        def build_product_result(row):
            result = {
                # product_id is optional; fall back to the row's index label.
                "product_id": row.get("product_id", str(row.name)),
                "delta_metric": round(row["_delta_metric"], 4),
                "baseline_outcome": round(row[baseline_col], 2),
            }
            for key in impact_keys:
                result[key] = round(row[key], 2)
            return result

        per_product_df = pd.DataFrame(df.apply(build_product_result, axis=1).tolist())
        artifacts["product_level_impacts"] = per_product_df

        # Compute aggregates from vectorized columns
        n_products = len(df)

        # Build aggregate estimates with dynamic keys
        impact_estimates = {key: round(df[key].sum(), 2) for key in impact_keys}

        # n_products is in model_summary, not impact_estimates
        self.logger.info(
            f"Metrics approximation complete: {n_products} products, "
            f"impact_estimates={impact_estimates}"
        )

        return ModelResult(
            model_type="metrics_approximation",
            data={
                "model_params": {
                    "response_function": self.config["response_function"],
                    "response_params": self.config["response_params"],
                },
                "impact_estimates": impact_estimates,
                "model_summary": {
                    "n_products": n_products,
                },
            },
            artifacts=artifacts,
        )

    def validate_data(self, data: pd.DataFrame) -> bool:
        """Validate that the input data meets model requirements.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame to validate.

        Returns
        -------
        bool
            True if data is valid, False otherwise. A warning is logged
            when the data is empty or required columns are missing.
        """
        if data is None or data.empty:
            self.logger.warning("Data is empty")
            return False

        required_cols = self.get_required_columns()
        missing_cols = [col for col in required_cols if col not in data.columns]
        if missing_cols:
            self.logger.warning(f"Missing required columns: {missing_cols}")
            return False

        return True

    def get_required_columns(self) -> List[str]:
        """Get the list of required columns for this model.

        Returns
        -------
        list of str
            Column names that must be present in input data. Before
            connect() has stored a config, a fixed default set of names
            is returned.
        """
        if not self.config:
            return ["quality_before", "quality_after", "baseline_sales"]
        return [
            self.config["metric_before_column"],
            self.config["metric_after_column"],
            self.config["baseline_column"],
        ]

    def _filter_missing_values(
        self,
        df: pd.DataFrame,
        required_columns: List[str],
    ) -> tuple:
        """Filter rows with missing values in required columns and log them.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame to filter.
        required_columns : list of str
            Columns to check for NaN/None values.

        Returns
        -------
        tuple
            Tuple of (filtered DataFrame, DataFrame of filtered product IDs).
            The second DataFrame is empty when no rows were filtered.
        """
        mask = df[required_columns].notna().all(axis=1)
        dropped = df.loc[~mask]

        if dropped.empty:
            return df[mask].copy(), pd.DataFrame()

        # product_id is optional in fit(); when the column is absent, identify
        # dropped rows by their index label, matching the per-product fallback.
        if "product_id" in dropped.columns:
            filtered_ids = dropped["product_id"].tolist()
        else:
            filtered_ids = [str(label) for label in dropped.index]

        self.logger.warning(
            f"Filtered {len(filtered_ids)} rows with missing values in columns "
            f"{required_columns}. See filtered_products.parquet for details."
        )
        filtered_ids_df = pd.DataFrame({"product_id": filtered_ids})
        return df[mask].copy(), filtered_ids_df

    def _empty_result(self, artifacts=None) -> ModelResult:
        """Return zero-impact result when no valid data remains after filtering.

        This enables pipeline testing with mock/incomplete data without errors.

        Parameters
        ----------
        artifacts : dict, optional
            Artifacts to attach to the result (e.g. the filtered-products
            table). When omitted or empty, ModelResult is built without an
            ``artifacts`` argument.

        Returns
        -------
        ModelResult
            Result with ``impact`` set to 0.0 and ``n_products`` set to 0.
        """
        result_kwargs = {
            "model_type": "metrics_approximation",
            "data": {
                "model_params": {
                    "response_function": self.config["response_function"],
                    "response_params": self.config["response_params"],
                },
                "impact_estimates": {
                    "impact": 0.0,
                },
                "model_summary": {
                    "n_products": 0,
                },
            },
        }
        if artifacts:
            result_kwargs["artifacts"] = artifacts
        return ModelResult(**result_kwargs)