# Source code for impact_engine_measure.models.manager

"""Models manager for coordinating model operations."""

from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import pandas as pd

from .base import ModelInterface, ModelResult


@dataclass
class FitOutput:
    """Structured output from fit_model().

    Provides programmatic access to the results path and all artifact
    paths, so callers do not need to reconstruct file paths from model
    internals.

    Attributes
    ----------
    results_path : str
        Full path/URL to impact_results.json.
    artifact_paths : dict of str to str
        Mapping of artifact name to full path/URL. Defaults to a fresh
        empty dict per instance.
    model_type : str
        The model type that produced this output. Defaults to "".
    """

    results_path: str
    # default_factory gives every instance its own dict (no shared mutable default).
    artifact_paths: Dict[str, str] = field(default_factory=dict)
    model_type: str = ""
class ModelsManager:
    """Central coordinator for model management.

    Uses dependency injection - the model is passed in via constructor,
    making the manager easy to test with mock implementations.

    Note: measurement_config is expected to be pre-validated via
    process_config().
    """

    def __init__(
        self,
        measurement_config: Dict[str, Any],
        model: ModelInterface,
    ):
        """Initialize the ModelsManager with injected model.

        Parameters
        ----------
        measurement_config : dict
            MEASUREMENT configuration block (pre-validated, with defaults
            merged). Must contain a "PARAMS" key.
        model : ModelInterface
            The model implementation to use for fitting.

        Raises
        ------
        ConnectionError
            If ``model.connect()`` returns a falsy value.
        KeyError
            If ``measurement_config`` has no "PARAMS" entry.
        """
        self.measurement_config = measurement_config
        self.model = model

        # Connect the injected model with configuration (PARAMS guaranteed to exist)
        model_config = measurement_config["PARAMS"]
        if not self.model.connect(model_config):
            raise ConnectionError("Failed to connect to model")

    def fit_model(
        self,
        data: pd.DataFrame,
        storage=None,
        **overrides,
    ) -> FitOutput:
        """Fit model using configuration parameters.

        All PARAMS from config are forwarded to validate_params(); the subset
        returned by the model's get_fit_params() is then passed as kwargs to
        fit(). Callers can override any config param via ``**overrides``
        (overrides whose value is ``None`` are ignored).

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing data for model fitting.
        storage : StorageManager
            Storage backend for artifacts. Required, despite the ``None``
            default; must provide ``write_parquet``, ``write_json`` and
            ``full_path``.
        **overrides
            Override any MEASUREMENT.PARAMS value (e.g., intervention_date,
            dependent_variable).

        Returns
        -------
        FitOutput
            FitOutput with paths to all persisted files.

        Raises
        ------
        ValueError
            If no storage backend is provided.
        TypeError
            If any artifact returned by the model is not a DataFrame. All
            artifacts are checked before anything is written, so a bad
            artifact never leaves partial output in storage.
        """
        # Copy so overrides never mutate the shared configuration dict.
        params = dict(self.measurement_config["PARAMS"])

        # Apply caller overrides on top of config values
        params.update({k: v for k, v in overrides.items() if v is not None})

        # Delegate parameter validation to the model
        self.model.validate_params(params)

        # Storage backend is required
        if not storage:
            raise ValueError("Storage backend is required but not provided")

        # Filter params to only those accepted by this adapter
        fit_params = self.model.get_fit_params(params)

        # Fit model - all models return ModelResult (storage-agnostic)
        result: ModelResult = self.model.fit(
            data=data,
            **fit_params,
        )

        # Populate metadata at the manager level (R5)
        result.metadata = {
            "executed_at": datetime.now(timezone.utc).isoformat(),
        }

        # Validate every artifact up front: failing midway through the write
        # loop would leave some parquet files persisted without a matching
        # impact_results.json.
        for name, df in result.artifacts.items():
            if not isinstance(df, pd.DataFrame):
                raise TypeError(f"Artifact '{name}' must be a DataFrame, got {type(df).__name__}")

        # Persist artifacts to storage (centralized here, not in models)
        # Prefix artifact filenames with model_type for namespace hygiene (R2)
        artifact_paths = {}
        for name, df in result.artifacts.items():
            filename = f"{result.model_type}__{name}.parquet"
            storage.write_parquet(filename, df)
            artifact_paths[name] = storage.full_path(filename)

        storage.write_json("impact_results.json", result.to_dict())
        results_path = storage.full_path("impact_results.json")

        return FitOutput(
            results_path=results_path,
            artifact_paths=artifact_paths,
            model_type=result.model_type,
        )

    def get_current_config(self) -> Optional[Dict[str, Any]]:
        """Get the currently loaded configuration.

        Returns
        -------
        dict
            The same ``measurement_config`` object passed to the constructor
            (not a copy).
        """
        return self.measurement_config