{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Nearest Neighbour Matching\n", "\n", "This notebook demonstrates impact estimation with **nearest neighbour matching**, using the [`NearestNeighborMatch`](https://causalml.readthedocs.io/en/latest/methodology.html#matching) class from [causalml](https://causalml.readthedocs.io/).\n", "\n", "The model matches treated and control units on an observed covariate (`price`), then computes the ATT (average treatment effect on the treated), the ATC (average treatment effect on the controls), and the ATE (average treatment effect over all units) from mean outcome differences in the matched samples.\n", "\n", "## Workflow overview\n", "\n", "1. User provides `products.csv`\n", "2. User configures `DATA.ENRICHMENT` for treatment assignment\n", "3. User calls `measure_impact(config.yaml)`\n", "4. Engine handles everything internally (adapter, enrichment, model)" ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Initial setup" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "from impact_engine_measure import measure_impact, load_results\n", "from impact_engine_measure.core.validation import load_config\n", "from impact_engine_measure.models.factory import get_model_adapter\n", "from online_retail_simulator import enrich, simulate" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "## Step 1 — Product catalog\n", "\n", "The catalog is simulated here; in production, this would be your actual product catalog." 
] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "output_path = Path(\"output/demo_nearest_neighbour_matching\")\n", "output_path.mkdir(parents=True, exist_ok=True)\n", "\n", "catalog_job = simulate(\"configs/demo_nearest_neighbour_matching_catalog.yaml\", job_id=\"catalog\")\n", "products = catalog_job.load_df(\"products\")\n", "\n", "print(f\"Generated {len(products)} products\")\n", "print(f\"Products catalog: {catalog_job.get_store().full_path('products.csv')}\")\n", "products.head()" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "## Step 2 — Engine configuration\n", "\n", "The relevant configuration sections are:\n", "\n", "- `DATA.ENRICHMENT` — treatment assignment via a quality boost (50/50 split)\n", "- `MODEL` — `nearest_neighbour_matching`, matching on `price`\n", "\n", "Matching parameters:\n", "\n", "- `caliper: 0.2` — maximum allowed distance between matched pairs; causalml multiplies the caliper by the standard deviation of the matching covariate, so this is 0.2 standard deviations of `price`, and units with no match within the caliper are dropped\n", "- `replace: true` — control units can be reused across matches\n", "- `ratio: 1` — one-to-one matching" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "config_path = \"configs/demo_nearest_neighbour_matching.yaml\"" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "## Step 3 — Impact evaluation\n", "\n", "A single call to `measure_impact()` runs the full pipeline:\n", "\n", "- The engine creates a `CatalogSimulatorAdapter`\n", "- The adapter simulates metrics (single-day, cross-sectional)\n", "- The adapter applies enrichment (treatment assignment + revenue boost)\n", "- `NearestNeighbourMatchingAdapter` matches on `price` and computes ATT, ATC and ATE" ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "job_info = measure_impact(config_path, str(output_path), job_id=\"results\")\n", "print(f\"Job ID: {job_info.job_id}\")" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "## Step 4 — Review 
results" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "result = load_results(job_info)\n", "\n", "data = result.impact_results[\"data\"]\n", "estimates = data[\"impact_estimates\"]\n", "summary = data[\"model_summary\"]\n", "\n", "print(\"=\" * 60)\n", "print(\"NEAREST NEIGHBOUR MATCHING RESULTS\")\n", "print(\"=\" * 60)\n", "\n", "print(f\"\\nModel Type: {result.model_type}\")\n", "\n", "print(\"\\n--- Impact Estimates ---\")\n", "print(f\"ATT (Avg Treatment on Treated): {estimates['att']:.4f} (SE: {estimates['att_se']:.4f})\")\n", "print(f\"ATC (Avg Treatment on Controls): {estimates['atc']:.4f} (SE: {estimates['atc_se']:.4f})\")\n", "print(f\"ATE (Avg Treatment Effect): {estimates['ate']:.4f}\")\n", "\n", "print(\"\\n--- Matching Summary ---\")\n", "print(f\"Observations: {summary['n_observations']}\")\n", "print(f\"Treated: {summary['n_treated']}\")\n", "print(f\"Control: {summary['n_control']}\")\n", "print(f\"Matched (ATT): {summary['n_matched_att']}\")\n", "print(f\"Matched (ATC): {summary['n_matched_atc']}\")\n", "print(f\"Caliper: {summary['caliper']}\")\n", "print(f\"Replace: {summary['replace']}\")\n", "print(f\"Ratio: {summary['ratio']}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "# Covariate balance artifacts\n", "balance_before = result.model_artifacts[\"balance_before\"]\n", "balance_after = result.model_artifacts[\"balance_after\"]\n", "\n", "print(\"--- Covariate Balance Before Matching ---\")\n", "print(balance_before.to_string())\n", "print(\"\\n--- Covariate Balance After Matching ---\")\n", "print(balance_after.to_string())\n", "\n", "print(\"\\n\" + \"=\" * 60)\n", "print(\"Demo Complete!\")\n", "print(\"=\" * 60)" ] }, { "cell_type": "markdown", "id": "12", "metadata": {}, "source": [ "## Step 5 — Model validation\n", "\n", "Compare the model's ATT estimate against the **true causal effect** computed from 
counterfactual vs factual data." ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": {}, "outputs": [], "source": [ "def calculate_true_effect(\n", " baseline_metrics: pd.DataFrame,\n", " enriched_metrics: pd.DataFrame,\n", ") -> dict:\n", " \"\"\"Calculate TRUE ATT by comparing per-product revenue for treated products.\"\"\"\n", " treated_ids = enriched_metrics[enriched_metrics[\"enriched\"]][\"product_id\"].unique()\n", "\n", " enriched_treated = enriched_metrics[enriched_metrics[\"product_id\"].isin(treated_ids)]\n", " baseline_treated = baseline_metrics[baseline_metrics[\"product_id\"].isin(treated_ids)]\n", "\n", " enriched_mean = enriched_treated.groupby(\"product_id\")[\"revenue\"].mean().mean()\n", " baseline_mean = baseline_treated.groupby(\"product_id\")[\"revenue\"].mean().mean()\n", " treatment_effect = enriched_mean - baseline_mean\n", "\n", " return {\n", " \"enriched_mean\": float(enriched_mean),\n", " \"baseline_mean\": float(baseline_mean),\n", " \"treatment_effect\": float(treatment_effect),\n", " }" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "baseline_metrics = catalog_job.load_df(\"metrics\").rename(columns={\"product_identifier\": \"product_id\"})\n", "\n", "enrich(\"configs/demo_nearest_neighbour_matching_enrichment.yaml\", catalog_job)\n", "enriched_metrics = catalog_job.load_df(\"enriched\").rename(columns={\"product_identifier\": \"product_id\"})\n", "\n", "print(f\"Baseline records: {len(baseline_metrics)}\")\n", "print(f\"Enriched records: {len(enriched_metrics)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": {}, "outputs": [], "source": [ "true_effect = calculate_true_effect(baseline_metrics, enriched_metrics)\n", "\n", "true_te = true_effect[\"treatment_effect\"]\n", "model_te = estimates[\"att\"]\n", "\n", "if true_te != 0:\n", " recovery_accuracy = (1 - abs(1 - model_te / true_te)) * 100\n", "else:\n", " 
    recovery_accuracy = 100 if model_te == 0 else 0\n", "\n", "print(\"=\" * 60)\n", "print(\"TRUTH RECOVERY VALIDATION\")\n", "print(\"=\" * 60)\n", "print(f\"True treatment effect: {true_te:.4f}\")\n", "print(f\"Model ATT estimate: {model_te:.4f}\")\n", "print(f\"Recovery accuracy: {max(0, recovery_accuracy):.1f}%\")\n", "print(\"=\" * 60)" ] }, { "cell_type": "markdown", "id": "16", "metadata": {}, "source": [ "### Convergence analysis\n", "\n", "How does the ATT estimate converge to the true effect as the number of products increases?" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "sample_sizes = [20, 50, 100, 200, 300, 500, 1500]\n", "estimates_list = []\n", "truth_list = []\n", "\n", "parsed = load_config(config_path)\n", "measurement_config = parsed[\"MEASUREMENT\"]\n", "all_product_ids = enriched_metrics[\"product_id\"].unique()\n", "\n", "for n in sample_sizes:\n", "    # Restrict both datasets to the first n products\n", "    subset_ids = all_product_ids[:n]\n", "    enriched_sub = enriched_metrics[enriched_metrics[\"product_id\"].isin(subset_ids)]\n", "    baseline_sub = baseline_metrics[baseline_metrics[\"product_id\"].isin(subset_ids)]\n", "\n", "    # Ground-truth ATT on this subset\n", "    true_sub = calculate_true_effect(baseline_sub, enriched_sub)\n", "    truth_list.append(true_sub[\"treatment_effect\"])\n", "\n", "    # Refit a fresh model on the subset; a separate name keeps the\n", "    # `result` loaded in Step 4 from being overwritten\n", "    model = get_model_adapter(\"nearest_neighbour_matching\")\n", "    model.connect(measurement_config[\"PARAMS\"])\n", "    fit_result = model.fit(data=enriched_sub, dependent_variable=\"revenue\")\n", "    estimates_list.append(fit_result.data[\"impact_estimates\"][\"att\"])\n", "\n", "print(\"Convergence analysis complete.\")" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "from notebook_support import plot_convergence\n", "\n", "plot_convergence(\n", "    sample_sizes,\n", "    estimates_list,\n", "    truth_list,\n", "    xlabel=\"Number of Products\",\n", "    ylabel=\"Treatment Effect (ATT)\",\n", "    title=\"Nearest Neighbour Matching: Convergence of Estimate to True Effect\",\n", 
")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }