Source code for lwdid.selection_diagnostics

"""
Selection mechanism diagnostics for unbalanced panel data.

This module provides diagnostic tools for assessing potential selection bias
in unbalanced panel data for difference-in-differences estimation.

The key assumption is that selection (missing data) may depend on unobserved
time-invariant heterogeneity, but cannot systematically depend on outcome
shocks in the untreated state. This is analogous to the standard fixed effects
assumption and is removed by the rolling transformation.

Main Functions
--------------
diagnose_selection_mechanism : Comprehensive selection mechanism diagnostics.
get_unit_missing_stats : Per-unit missing data statistics.
plot_missing_pattern : Visualize missing data patterns.

Data Classes
------------
SelectionDiagnostics : Complete diagnostic results.
BalanceStatistics : Panel balance metrics.
AttritionAnalysis : Attrition pattern analysis.
UnitMissingStats : Per-unit statistics.
SelectionTestResult : Statistical test results.

Enums
-----
MissingPattern : Missing data pattern classification (MCAR, MAR, MNAR).
SelectionRisk : Selection bias risk level (LOW, MEDIUM, HIGH).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats


# =============================================================================
# Enums
# =============================================================================

class MissingPattern(Enum):
    """
    Classification of the missing-data mechanism (Rubin's taxonomy).

    Members
    -------
    MCAR
        Missing Completely At Random: whether a value is missing is unrelated
        to any data, observed or unobserved. This is the most benign case.
    MAR
        Missing At Random: missingness is related only to observed data.
        Compatible with the selection mechanism assumption once the relevant
        controls are included.
    MNAR
        Missing Not At Random: missingness is related to unobserved data.
        This may violate the selection mechanism assumption if it is driven
        by outcome shocks in the untreated state.
    UNKNOWN
        The available data were not sufficient to determine a pattern.

    Notes
    -----
    Under the selection mechanism assumption, missingness may depend on
    unobserved time-invariant heterogeneity but not systematically on
    time-varying outcome shocks. MCAR and MAR are generally acceptable.
    MNAR can still be acceptable when missingness depends only on
    time-invariant factors, because the rolling transformation removes
    them. It is problematic when missingness depends on time-varying
    outcome shocks.
    """

    # Each member's value is the human-readable label used in reports.
    MCAR = "missing_completely_at_random"
    MAR = "missing_at_random"
    MNAR = "missing_not_at_random"
    UNKNOWN = "unknown"
class SelectionRisk(Enum):
    """
    Assessed risk of selection bias in ATT estimation.

    Members
    -------
    LOW
        The selection mechanism assumption likely holds; proceed with
        estimation.
    MEDIUM
        Some indicators point to potential problems; consider detrending
        and a sensitivity analysis.
    HIGH
        Strong evidence of problematic selection; interpret results with
        caution.
    UNKNOWN
        Risk could not be assessed with the available data.

    Notes
    -----
    The assessment combines several indicators:

    - the missing-data pattern (MCAR < MAR < MNAR),
    - the attrition rate (lower is better),
    - differential attrition before versus after treatment,
    - the panel balance ratio.

    Because the rolling transformation removes unit-specific averages,
    selection is allowed to depend on unobserved time-constant
    heterogeneity, as under the standard fixed effects assumption.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    UNKNOWN = "unknown"
# ============================================================================= # Data Classes # =============================================================================
@dataclass
class AttritionAnalysis:
    """
    Unit dropout and entry patterns in panel data.

    Attributes
    ----------
    n_units_complete : int
        Number of units observed in every period of the panel.
    n_units_partial : int
        Number of units missing at least one period.
    attrition_rate : float
        Share of units with incomplete observations (n_partial / n_total).
    attrition_by_cohort : dict[int, float]
        Attrition rate within each treatment cohort, keyed by cohort
        identifier: the share of the cohort's units that are not observed
        in every period.
    attrition_by_period : dict[int, float]
        Share of units with no row in each time period, keyed by period.
        This is a per-period share, not a cumulative exit rate. Late
        entrants that have not yet appeared and units with temporary gaps
        both count toward the periods in which they are absent.
    early_dropout_rate : float
        Share of units whose last observation precedes the final panel
        period (last_obs < T_max).
    late_entry_rate : float
        Share of units whose first observation follows the first panel
        period (first_obs > T_min).
    dropout_before_treatment : int
        Number of treated units whose last observation precedes their
        treatment period (last_obs < g). High values may indicate
        anticipation effects.
    dropout_after_treatment : int
        Number of treated units whose last observation falls at or after
        their treatment period but before the final panel period
        (g <= last_obs < T_max). High values may indicate
        treatment-induced attrition.

    Notes
    -----
    Differential attrition may indicate selection related to treatment
    effects, which would violate the selection mechanism assumption. An
    example is more dropout after treatment than before.
    """

    # Unit-level completeness counts.
    n_units_complete: int
    n_units_partial: int
    attrition_rate: float
    # Breakdowns; default_factory gives every instance its own dict.
    attrition_by_cohort: dict[int, float] = field(default_factory=dict)
    attrition_by_period: dict[int, float] = field(default_factory=dict)
    # Entry / exit timing relative to the panel's time range.
    early_dropout_rate: float = 0.0
    late_entry_rate: float = 0.0
    # Exit timing of treated units relative to their treatment period.
    dropout_before_treatment: int = 0
    dropout_after_treatment: int = 0
@dataclass
class BalanceStatistics:
    """
    Summary statistics describing how balanced a panel is.

    Attributes
    ----------
    is_balanced : bool
        True when every unit has the same number of observations. This
        compares row counts only; it does not check which periods each
        unit is observed in.
    n_units : int
        Number of distinct units in the panel.
    n_periods : int
        Number of distinct time periods in the panel.
    min_obs_per_unit : int
        Smallest number of observations held by any unit.
    max_obs_per_unit : int
        Largest number of observations held by any unit.
    mean_obs_per_unit : float
        Average number of observations per unit.
    std_obs_per_unit : float
        Standard deviation of the number of observations per unit.
    balance_ratio : float
        min_obs_per_unit / max_obs_per_unit. A value of 1.0 means perfect
        balance; smaller values mean more severe imbalance.
    units_below_demean_threshold : int
        Treated units with fewer than one pre-treatment observation; these
        cannot be demeaned.
    units_below_detrend_threshold : int
        Treated units with fewer than two pre-treatment observations; these
        cannot be detrended.
    pct_usable_demean : float
        Percentage (0-100) of treated units usable for demeaning.
    pct_usable_detrend : float
        Percentage (0-100) of treated units usable for detrending.

    Notes
    -----
    A transformed outcome for cohort g can only be computed from observed
    pre-treatment periods (t < g). Demeaning needs at least one such period
    to form a mean. Detrending needs at least two to fit a linear trend.
    Units that fall short are excluded from the corresponding method.
    """

    # Overall shape of the panel.
    is_balanced: bool
    n_units: int
    n_periods: int
    # Distribution of observation counts across units.
    min_obs_per_unit: int
    max_obs_per_unit: int
    mean_obs_per_unit: float
    std_obs_per_unit: float
    balance_ratio: float
    # Usability of treated units for each rolling transformation.
    units_below_demean_threshold: int = 0
    units_below_detrend_threshold: int = 0
    pct_usable_demean: float = 100.0
    pct_usable_detrend: float = 100.0
@dataclass
class SelectionTestResult:
    """
    Outcome of a single statistical test of the selection mechanism.

    Attributes
    ----------
    test_name : str
        Name of the test that was run.
    statistic : float
        Value of the test statistic.
    pvalue : float
        P-value associated with the statistic.
    reject_null : bool
        Whether the null hypothesis is rejected at alpha = 0.05.
    interpretation : str
        Plain-language reading of the result.
    details : dict[str, Any]
        Extra test-specific information, such as group means or
        correlations.

    Notes
    -----
    Tests produced by this module include:

    - "Simplified Little's MCAR Test": are data missing completely at random?
    - "Selection on Observables (MAR) Test": does missingness depend on
      controls?
    - "Selection on Lagged Outcome (MNAR) Test": does missingness depend on
      past outcomes?
    """

    test_name: str
    statistic: float
    pvalue: float
    reject_null: bool
    interpretation: str
    # Fresh dict per instance so results never share their details.
    details: dict[str, Any] = field(default_factory=dict)
@dataclass
class UnitMissingStats:
    """
    Missing-data statistics for one unit of the panel.

    Attributes
    ----------
    unit_id : Any
        Identifier of the unit.
    cohort : int or None
        Treatment cohort; None for never-treated units.
    is_treated : bool
        Whether the unit is ever treated.
    n_total_periods : int
        Number of periods in the whole panel.
    n_observed : int
        Number of periods in which this unit is observed.
    n_missing : int
        Number of periods in which this unit is not observed.
    missing_rate : float
        n_missing / n_total_periods.
    first_observed : int
        Earliest period with an observation.
    last_observed : int
        Latest period with an observation.
    observation_span : int
        last_observed - first_observed + 1.
    n_pre_treatment : int or None
        Observed pre-treatment periods (treated units only).
    n_post_treatment : int or None
        Observed post-treatment periods (treated units only).
    pre_treatment_missing_rate : float or None
        Share of pre-treatment periods that are missing.
    post_treatment_missing_rate : float or None
        Share of post-treatment periods that are missing.
    can_use_demean : bool
        Whether the unit has enough data for demeaning (>= 1 pre-treatment
        observation).
    can_use_detrend : bool
        Whether the unit has enough data for detrending (>= 2 pre-treatment
        observations).
    reason_if_excluded : str or None
        Why the unit cannot be used, if applicable.
    """

    # Identity and treatment status.
    unit_id: Any
    cohort: Optional[int]
    is_treated: bool
    # Coverage over the full panel.
    n_total_periods: int
    n_observed: int
    n_missing: int
    missing_rate: float
    first_observed: int
    last_observed: int
    observation_span: int
    # Pre/post-treatment breakdown (None for never-treated units).
    n_pre_treatment: Optional[int] = None
    n_post_treatment: Optional[int] = None
    pre_treatment_missing_rate: Optional[float] = None
    post_treatment_missing_rate: Optional[float] = None
    # Method usability flags.
    can_use_demean: bool = True
    can_use_detrend: bool = True
    reason_if_excluded: Optional[str] = None
@dataclass
class SelectionDiagnostics:
    """
    Complete selection mechanism diagnostics for unbalanced panels.

    This class aggregates all diagnostic information about missing data
    patterns and potential selection bias in panel data for DiD estimation.

    Attributes
    ----------
    missing_pattern : MissingPattern
        Classified missing data pattern (MCAR, MAR, MNAR, UNKNOWN).
    missing_pattern_confidence : float
        Confidence level (0-1) in the pattern classification.
    selection_risk : SelectionRisk
        Assessed risk level for selection bias.
    attrition_analysis : AttritionAnalysis
        Detailed attrition pattern analysis.
    balance_statistics : BalanceStatistics
        Panel balance statistics.
    recommendations : list[str]
        Actionable recommendations based on diagnostics.
    warnings : list[str]
        Warning messages about potential issues.
    missing_rate_overall : float
        Overall missing rate across all unit-periods.
    missing_rate_by_period : dict[int, float]
        Missing rate by time period.
    missing_rate_by_cohort : dict[int, float]
        Missing rate by treatment cohort.
    selection_tests : List[SelectionTestResult]
        Results of statistical tests for selection.
    unit_stats : List[UnitMissingStats]
        Per-unit missing data statistics.

    Notes
    -----
    The selection mechanism assumption requires that selection may depend
    on unobserved time-invariant heterogeneity, but cannot systematically
    depend on time-varying outcome shocks. This is analogous to the standard
    fixed effects assumption. The rolling transformation removes
    unit-specific averages (or trends), which eliminates bias from selection
    on time-invariant factors.

    See Also
    --------
    diagnose_selection_mechanism : Function to create this diagnostics object.
    """

    missing_pattern: MissingPattern
    missing_pattern_confidence: float
    selection_risk: SelectionRisk
    attrition_analysis: AttritionAnalysis
    balance_statistics: BalanceStatistics
    recommendations: list[str]
    warnings: list[str]
    missing_rate_overall: float
    missing_rate_by_period: dict[int, float]
    missing_rate_by_cohort: dict[int, float]
    selection_tests: List[SelectionTestResult] = field(default_factory=list)
    unit_stats: List[UnitMissingStats] = field(default_factory=list)

    def summary(self) -> str:
        """
        Generate a human-readable summary of diagnostics.

        Returns
        -------
        str
            Formatted summary string containing key diagnostic information,
            warnings, and recommendations.
        """
        bal = self.balance_statistics
        att = self.attrition_analysis
        lines = [
            "=" * 70,
            "SELECTION MECHANISM DIAGNOSTICS",
            "=" * 70,
            "",
            "PANEL BALANCE:",
            f" Status: {'Balanced' if bal.is_balanced else 'UNBALANCED'}",
            f" Units: {bal.n_units}",
            f" Periods: {bal.n_periods}",
            f" Observations per unit: {bal.min_obs_per_unit} - {bal.max_obs_per_unit}",
            f" Balance ratio: {bal.balance_ratio:.2%}",
            "",
            "MISSING DATA:",
            f" Overall missing rate: {self.missing_rate_overall:.2%}",
            f" Pattern classification: {self.missing_pattern.value}",
            f" Classification confidence: {self.missing_pattern_confidence:.0%}",
            "",
            "ATTRITION:",
            f" Attrition rate: {att.attrition_rate:.2%}",
            f" Complete units: {att.n_units_complete}",
            f" Partial units: {att.n_units_partial}",
            f" Late entry rate: {att.late_entry_rate:.2%}",
            f" Early dropout rate: {att.early_dropout_rate:.2%}",
            "",
            f"SELECTION RISK: {self.selection_risk.value.upper()}",
            "",
            "METHOD USABILITY:",
            f" Demean (≥1 pre-period): {bal.pct_usable_demean:.1f}% of treated units",
            f" Detrend (≥2 pre-periods): {bal.pct_usable_detrend:.1f}% of treated units",
        ]

        if self.selection_tests:
            lines.extend(["", "STATISTICAL TESTS:"])
            for test in self.selection_tests:
                status = "REJECT" if test.reject_null else "FAIL TO REJECT"
                lines.append(f" {test.test_name}:")
                lines.append(
                    f" Statistic: {test.statistic:.4f}, p-value: {test.pvalue:.4f} ({status})"
                )

        if self.warnings:
            lines.extend(["", "⚠️ WARNINGS:"])
            for w in self.warnings:
                lines.append(f" • {w}")

        if self.recommendations:
            lines.extend(["", "📋 RECOMMENDATIONS:"])
            for r in self.recommendations:
                lines.append(f" → {r}")

        lines.extend([
            "",
            "=" * 70,
            "SELECTION MECHANISM ASSUMPTION:",
            " Selection may depend on unobserved time-invariant heterogeneity,",
            " but cannot systematically depend on time-varying outcome shocks.",
            "=" * 70,
        ])
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        """
        Convert diagnostics to a plain dictionary.

        Enum members are replaced by their string values, and the nested
        balance and attrition results become sub-dictionaries. Each
        selection test becomes a dictionary of its fields. Dicts and lists
        held by this object are copied one level deep, so mutating the
        result does not modify the diagnostics object. Per-unit statistics
        (``unit_stats``) are not included.

        Returns
        -------
        dict[str, Any]
            Dictionary containing the diagnostic information. All keys from
            earlier versions are kept with unchanged values. The attrition
            and balance sub-dictionaries gain their remaining fields, and a
            ``'selection_tests'`` list is added.
        """
        bal = self.balance_statistics
        att = self.attrition_analysis
        return {
            'missing_pattern': self.missing_pattern.value,
            'missing_pattern_confidence': self.missing_pattern_confidence,
            'selection_risk': self.selection_risk.value,
            'missing_rate_overall': self.missing_rate_overall,
            'missing_rate_by_period': dict(self.missing_rate_by_period),
            'missing_rate_by_cohort': dict(self.missing_rate_by_cohort),
            'balance_statistics': {
                'is_balanced': bal.is_balanced,
                'n_units': bal.n_units,
                'n_periods': bal.n_periods,
                'balance_ratio': bal.balance_ratio,
                'pct_usable_demean': bal.pct_usable_demean,
                'pct_usable_detrend': bal.pct_usable_detrend,
                'min_obs_per_unit': bal.min_obs_per_unit,
                'max_obs_per_unit': bal.max_obs_per_unit,
                'mean_obs_per_unit': bal.mean_obs_per_unit,
                'std_obs_per_unit': bal.std_obs_per_unit,
                'units_below_demean_threshold': bal.units_below_demean_threshold,
                'units_below_detrend_threshold': bal.units_below_detrend_threshold,
            },
            'attrition_analysis': {
                'attrition_rate': att.attrition_rate,
                'n_units_complete': att.n_units_complete,
                'n_units_partial': att.n_units_partial,
                'attrition_by_cohort': dict(att.attrition_by_cohort),
                'attrition_by_period': dict(att.attrition_by_period),
                'early_dropout_rate': att.early_dropout_rate,
                'late_entry_rate': att.late_entry_rate,
                'dropout_before_treatment': att.dropout_before_treatment,
                'dropout_after_treatment': att.dropout_after_treatment,
            },
            'recommendations': list(self.recommendations),
            'warnings': list(self.warnings),
            # Previously omitted despite being part of the diagnostics.
            'selection_tests': [
                {
                    'test_name': t.test_name,
                    'statistic': t.statistic,
                    'pvalue': t.pvalue,
                    'reject_null': t.reject_null,
                    'interpretation': t.interpretation,
                    'details': dict(t.details),
                }
                for t in self.selection_tests
            ],
        }
# ============================================================================= # Helper Functions # ============================================================================= def _validate_diagnostic_inputs( data: pd.DataFrame, y: str, ivar: str, tvar: str, gvar: str | None, ) -> None: """ Validate inputs for diagnostic functions. Parameters ---------- data : pd.DataFrame Panel data in long format. y : str Outcome variable column name. ivar : str Unit identifier column name. tvar : str Time variable column name. gvar : str or None Cohort variable column name. Raises ------ ValueError If required columns are missing or data is insufficient. TypeError If time variable is not numeric. """ # Check required columns exist required = [y, ivar, tvar] if gvar is not None: required.append(gvar) missing = [c for c in required if c not in data.columns] if missing: raise ValueError(f"Missing required columns: {missing}") # Check data types if not pd.api.types.is_numeric_dtype(data[tvar]): raise TypeError( f"Time variable '{tvar}' must be numeric. " f"Found type: {data[tvar].dtype}" ) # Check minimum data requirements if len(data) < 3: raise ValueError( "Insufficient data: need at least 3 observations for diagnostics." ) if data[ivar].nunique() < 2: raise ValueError( "Need at least 2 unique units for meaningful diagnostics." ) if data[tvar].nunique() < 2: raise ValueError( "Need at least 2 unique time periods for panel diagnostics." ) def _is_never_treated(gvar_value: Any, never_treated_values: List) -> bool: """ Check if a gvar value indicates never-treated status. Parameters ---------- gvar_value : Any Value from the gvar column. never_treated_values : List List of values indicating never-treated status. Returns ------- bool True if the value indicates never-treated status. 
""" if pd.isna(gvar_value): return True if gvar_value in never_treated_values: return True if isinstance(gvar_value, float) and np.isinf(gvar_value): return True return False def _compute_balance_statistics( data: pd.DataFrame, y: str, ivar: str, tvar: str, gvar: str | None, never_treated_values: List, ) -> BalanceStatistics: """ Compute panel balance statistics. Parameters ---------- data : pd.DataFrame Panel data in long format. y : str Outcome variable column name. ivar : str Unit identifier column name. tvar : str Time variable column name. gvar : str or None Cohort variable column name. never_treated_values : List Values indicating never-treated status. Returns ------- BalanceStatistics Panel balance statistics. """ all_periods = sorted(data[tvar].unique()) n_periods = len(all_periods) # Count observations per unit obs_per_unit = data.groupby(ivar).size() n_units = len(obs_per_unit) is_balanced = obs_per_unit.nunique() == 1 min_obs = int(obs_per_unit.min()) max_obs = int(obs_per_unit.max()) mean_obs = float(obs_per_unit.mean()) std_obs = float(obs_per_unit.std()) if len(obs_per_unit) > 1 else 0.0 balance_ratio = min_obs / max_obs if max_obs > 0 else 0.0 # Count units below method thresholds units_below_demean = 0 units_below_detrend = 0 n_treated_units = 0 if gvar is not None: unit_gvar = data.drop_duplicates(subset=[ivar]).set_index(ivar)[gvar] for unit_id in data[ivar].unique(): g = unit_gvar.get(unit_id) # Skip never-treated if _is_never_treated(g, never_treated_values): continue n_treated_units += 1 # Count pre-treatment observations unit_data = data[data[ivar] == unit_id] n_pre = len(unit_data[unit_data[tvar] < g]) if n_pre < 1: units_below_demean += 1 if n_pre < 2: units_below_detrend += 1 # Calculate usability percentages if n_treated_units > 0: pct_demean = 100.0 * (1 - units_below_demean / n_treated_units) pct_detrend = 100.0 * (1 - units_below_detrend / n_treated_units) else: pct_demean = 100.0 pct_detrend = 100.0 return BalanceStatistics( 
is_balanced=is_balanced, n_units=n_units, n_periods=n_periods, min_obs_per_unit=min_obs, max_obs_per_unit=max_obs, mean_obs_per_unit=mean_obs, std_obs_per_unit=std_obs, balance_ratio=balance_ratio, units_below_demean_threshold=units_below_demean, units_below_detrend_threshold=units_below_detrend, pct_usable_demean=pct_demean, pct_usable_detrend=pct_detrend, ) def _compute_attrition_analysis( data: pd.DataFrame, ivar: str, tvar: str, gvar: str | None, never_treated_values: List, ) -> AttritionAnalysis: """ Compute attrition pattern analysis. Parameters ---------- data : pd.DataFrame Panel data in long format. ivar : str Unit identifier column name. tvar : str Time variable column name. gvar : str or None Cohort variable column name. never_treated_values : List Values indicating never-treated status. Returns ------- AttritionAnalysis Attrition pattern analysis. """ all_periods = sorted(data[tvar].unique()) T_min, T_max = min(all_periods), max(all_periods) n_periods = len(all_periods) # Count observations per unit obs_per_unit = data.groupby(ivar).size() n_units = len(obs_per_unit) # Identify complete vs partial units complete_units = obs_per_unit[obs_per_unit == n_periods].index n_complete = len(complete_units) n_partial = n_units - n_complete attrition_rate = n_partial / n_units if n_units > 0 else 0.0 # Attrition by cohort attrition_by_cohort = {} if gvar is not None: unit_gvar = data.drop_duplicates(subset=[ivar]).set_index(ivar)[gvar] for g in data[gvar].dropna().unique(): if _is_never_treated(g, never_treated_values): continue cohort_units = unit_gvar[unit_gvar == g].index cohort_complete = len([u for u in cohort_units if u in complete_units]) cohort_total = len(cohort_units) if cohort_total > 0: attrition_by_cohort[int(g)] = 1 - cohort_complete / cohort_total # Attrition by period (proportion not observed at each time) attrition_by_period = {} for t in all_periods: units_observed_at_t = data[data[tvar] == t][ivar].nunique() attrition_by_period[int(t)] = 1 - 
units_observed_at_t / n_units # Early dropout / late entry first_obs = data.groupby(ivar)[tvar].min() last_obs = data.groupby(ivar)[tvar].max() late_entry_rate = float((first_obs > T_min).mean()) early_dropout_rate = float((last_obs < T_max).mean()) # Dropout before/after treatment dropout_before = 0 dropout_after = 0 if gvar is not None: unit_gvar = data.drop_duplicates(subset=[ivar]).set_index(ivar)[gvar] for unit_id in data[ivar].unique(): g = unit_gvar.get(unit_id) if _is_never_treated(g, never_treated_values): continue unit_last = last_obs.get(unit_id) if unit_last < g: dropout_before += 1 elif unit_last < T_max: dropout_after += 1 return AttritionAnalysis( n_units_complete=n_complete, n_units_partial=n_partial, attrition_rate=attrition_rate, attrition_by_cohort=attrition_by_cohort, attrition_by_period=attrition_by_period, early_dropout_rate=early_dropout_rate, late_entry_rate=late_entry_rate, dropout_before_treatment=dropout_before, dropout_after_treatment=dropout_after, ) def _classify_missing_pattern( data: pd.DataFrame, y: str, ivar: str, tvar: str, controls: Optional[list[str]] = None, ) -> Tuple[MissingPattern, float, List[SelectionTestResult]]: """ Classify missing data pattern using statistical tests. Implements a simplified version of Little's MCAR test and auxiliary regressions to classify the missing data mechanism. Parameters ---------- data : pd.DataFrame Panel data in long format. y : str Outcome variable column name. ivar : str Unit identifier column name. tvar : str Time variable column name. controls : list[str] or None Control variable column names. 
Returns ------- Tuple[MissingPattern, float, List[SelectionTestResult]] - Classified missing pattern - Confidence in classification (0-1) - List of test results """ tests = [] # Create full panel and missing indicator all_periods = sorted(data[tvar].unique()) all_units = data[ivar].unique() full_index = pd.MultiIndex.from_product( [all_units, all_periods], names=[ivar, tvar] ) full_panel = pd.DataFrame(index=full_index).reset_index() merged = full_panel.merge(data, on=[ivar, tvar], how='left') # M_it = 1 if missing merged['_missing'] = merged[y].isna().astype(int) # If no missing data, return MCAR with high confidence if merged['_missing'].sum() == 0: return MissingPattern.MCAR, 1.0, [] # ========================================================================= # Test 1: Simplified Little's MCAR Test # Compare mean outcome between complete and incomplete units # ========================================================================= obs_per_unit = merged.groupby(ivar)['_missing'].sum() complete_units = obs_per_unit[obs_per_unit == 0].index incomplete_units = obs_per_unit[obs_per_unit > 0].index if len(complete_units) >= 5 and len(incomplete_units) >= 5: # Get observed Y values for each group complete_y = data[data[ivar].isin(complete_units)][y].dropna() incomplete_y = data[data[ivar].isin(incomplete_units)][y].dropna() if len(complete_y) > 1 and len(incomplete_y) > 1: complete_mean = complete_y.mean() incomplete_mean = incomplete_y.mean() complete_std = complete_y.std() incomplete_std = incomplete_y.std() n1, n2 = len(complete_units), len(incomplete_units) # Pooled standard error pooled_se = np.sqrt( complete_std**2 / n1 + incomplete_std**2 / n2 ) if pooled_se > 0: t_stat = (complete_mean - incomplete_mean) / pooled_se df = n1 + n2 - 2 pvalue = 2 * (1 - stats.t.cdf(abs(t_stat), df)) tests.append(SelectionTestResult( test_name="Simplified Little's MCAR Test", statistic=float(t_stat), pvalue=float(pvalue), reject_null=pvalue < 0.05, interpretation=( "Reject MCAR: 
significant difference in outcomes between " "complete and incomplete units" if pvalue < 0.05 else "Cannot reject MCAR: no significant difference detected" ), details={ 'complete_mean': float(complete_mean), 'incomplete_mean': float(incomplete_mean), 'n_complete': n1, 'n_incomplete': n2, } )) # ========================================================================= # Test 2: Selection on Observables (MAR test) # Regress unit-level missing rate on controls # ========================================================================= if controls and len(controls) > 0: # Get unit-level controls (first observation per unit) available_controls = [c for c in controls if c in data.columns] if available_controls: unit_controls = data.drop_duplicates(subset=[ivar])[ available_controls + [ivar] ].dropna() # Compute unit-level missing rate unit_missing = merged.groupby(ivar)['_missing'].mean().reset_index() unit_missing.columns = [ivar, '_unit_missing_rate'] test_data = unit_missing.merge(unit_controls, on=ivar) if len(test_data) > len(available_controls) + 2: X = test_data[available_controls].values X = np.column_stack([np.ones(len(X)), X]) # Add constant y_miss = test_data['_unit_missing_rate'].values try: # OLS regression beta, residuals, rank, s = np.linalg.lstsq(X, y_miss, rcond=None) # Compute R-squared y_pred = X @ beta SS_res = np.sum((y_miss - y_pred)**2) SS_tot = np.sum((y_miss - y_miss.mean())**2) R2 = 1 - SS_res / SS_tot if SS_tot > 0 else 0 # F-test for joint significance k = len(available_controls) n = len(y_miss) if R2 < 1 and n > k + 1: F_stat = (R2 / k) / ((1 - R2) / (n - k - 1)) pvalue = 1 - stats.f.cdf(F_stat, k, n - k - 1) tests.append(SelectionTestResult( test_name="Selection on Observables (MAR) Test", statistic=float(F_stat), pvalue=float(pvalue), reject_null=pvalue < 0.05, interpretation=( "Selection depends on observed controls (MAR)" if pvalue < 0.05 else "No evidence of selection on observed controls" ), details={ 'R2': float(R2), 'controls': 
available_controls, } )) except (np.linalg.LinAlgError, ValueError): pass # ========================================================================= # Test 3: Selection on Lagged Outcomes (MNAR indicator) # Test if missingness correlates with lagged Y # ========================================================================= data_sorted = data.sort_values([ivar, tvar]) data_sorted['_y_lag'] = data_sorted.groupby(ivar)[y].shift(1) # Merge with missing indicator lag_test = merged.merge( data_sorted[[ivar, tvar, '_y_lag']], on=[ivar, tvar], how='left' ) lag_test = lag_test.dropna(subset=['_y_lag']) if len(lag_test) > 10: # Point-biserial correlation try: corr, pvalue = stats.pointbiserialr( lag_test['_missing'].values, lag_test['_y_lag'].values ) tests.append(SelectionTestResult( test_name="Selection on Lagged Outcome (MNAR) Test", statistic=float(corr), pvalue=float(pvalue), reject_null=pvalue < 0.05, interpretation=( "WARNING: Missingness correlates with lagged outcomes. " "This suggests potential MNAR and may violate the selection " "mechanism assumption." 
if pvalue < 0.05 else "No evidence of selection on lagged outcomes" ), details={'correlation': float(corr)} )) except (ValueError, TypeError): pass # ========================================================================= # Classify Pattern Based on Test Results # ========================================================================= mcar_rejected = any( t.test_name.startswith("Simplified Little") and t.reject_null for t in tests ) mar_detected = any( t.test_name.startswith("Selection on Observables") and t.reject_null for t in tests ) mnar_detected = any( t.test_name.startswith("Selection on Lagged") and t.reject_null for t in tests ) if mnar_detected: pattern = MissingPattern.MNAR confidence = 0.7 # MNAR is hard to confirm definitively elif mar_detected: pattern = MissingPattern.MAR confidence = 0.8 elif not mcar_rejected: pattern = MissingPattern.MCAR confidence = 0.9 if len(tests) > 0 else 0.5 else: pattern = MissingPattern.UNKNOWN confidence = 0.3 return pattern, confidence, tests def _assess_selection_risk( missing_pattern: MissingPattern, attrition_analysis: AttritionAnalysis, balance_statistics: BalanceStatistics, selection_tests: List[SelectionTestResult], ) -> Tuple[SelectionRisk, list[str], list[str]]: """ Assess overall selection bias risk based on multiple indicators. Risk Assessment Criteria: LOW Risk (acceptable): - Missing pattern is MCAR or MAR - Attrition rate < 10% - No significant selection on lagged outcomes - Balance ratio > 0.8 MEDIUM Risk (caution): - Missing pattern is MAR with moderate attrition - Attrition rate 10-30% - Some evidence of differential attrition by cohort - Balance ratio 0.5-0.8 HIGH Risk (problematic): - Missing pattern is MNAR - Attrition rate > 30% - Strong selection on lagged outcomes - Differential dropout before/after treatment - Balance ratio < 0.5 Parameters ---------- missing_pattern : MissingPattern Classified missing data pattern. attrition_analysis : AttritionAnalysis Attrition pattern analysis. 
balance_statistics : BalanceStatistics Panel balance statistics. selection_tests : List[SelectionTestResult] Statistical test results. Returns ------- Tuple[SelectionRisk, list[str], list[str]] - Assessed risk level - List of recommendations - List of warnings """ recommendations = [] warnings = [] risk_score = 0 # 0-100 scale # ========================================================================= # Factor 1: Missing Pattern (weight: 30%) # ========================================================================= if missing_pattern == MissingPattern.MCAR: risk_score += 0 elif missing_pattern == MissingPattern.MAR: risk_score += 15 elif missing_pattern == MissingPattern.MNAR: risk_score += 30 warnings.append( "Missing data pattern suggests selection on unobservables. " "This may violate the selection mechanism assumption." ) else: # UNKNOWN risk_score += 10 # ========================================================================= # Factor 2: Attrition Rate (weight: 25%) # ========================================================================= attrition = attrition_analysis.attrition_rate if attrition < 0.10: risk_score += 0 elif attrition < 0.30: risk_score += 12 else: risk_score += 25 warnings.append( f"High attrition rate ({attrition:.1%}). Consider using " "detrending which is more robust to selection on trends." ) # ========================================================================= # Factor 3: Differential Attrition (weight: 25%) # ========================================================================= dropout_before = attrition_analysis.dropout_before_treatment dropout_after = attrition_analysis.dropout_after_treatment if dropout_after > 0 and dropout_before > 0: if dropout_after > dropout_before * 2: risk_score += 25 warnings.append( f"Significantly more dropout after treatment ({dropout_after}) " f"than before ({dropout_before}). This may indicate selection " "related to treatment effects." 
) elif dropout_after > dropout_before * 1.5: risk_score += 15 # ========================================================================= # Factor 4: Balance Ratio (weight: 20%) # ========================================================================= balance = balance_statistics.balance_ratio if balance > 0.8: risk_score += 0 elif balance > 0.5: risk_score += 10 else: risk_score += 20 warnings.append( f"Low balance ratio ({balance:.1%}). Some units have much fewer " "observations than others." ) # ========================================================================= # Determine Risk Level and Generate Recommendations # ========================================================================= if risk_score < 25: risk = SelectionRisk.LOW recommendations.append( "Selection risk is low. Proceed with estimation. " "The selection mechanism assumption appears reasonable." ) elif risk_score < 50: risk = SelectionRisk.MEDIUM recommendations.extend([ "Moderate selection risk detected. Consider the following:", "1. Use rolling='detrend' for additional robustness to selection on trends", "2. Compare results with a balanced subsample as sensitivity check", "3. Report both demean and detrend results for transparency", ]) else: risk = SelectionRisk.HIGH recommendations.extend([ "High selection risk detected. Strongly recommend:", "1. Use rolling='detrend' method (more robust to selection on trends)", "2. Conduct sensitivity analysis with balanced subsample", "3. Report diagnostics and discuss potential selection bias", "4. Consider alternative identification strategies if possible", ]) # Add method-specific recommendations if balance_statistics.pct_usable_detrend < 90: recommendations.append( f"Note: Only {balance_statistics.pct_usable_detrend:.1f}% of treated " "units have sufficient pre-treatment periods for detrending. " "Consider using demean if detrending excludes too many units." 
) return risk, recommendations, warnings def _compute_missing_rates( data: pd.DataFrame, y: str, ivar: str, tvar: str, gvar: str | None, never_treated_values: List, ) -> Tuple[float, dict[int, float], dict[int, float]]: """ Compute missing rates overall, by period, and by cohort. Parameters ---------- data : pd.DataFrame Panel data in long format. y : str Outcome variable column name. ivar : str Unit identifier column name. tvar : str Time variable column name. gvar : str or None Cohort variable column name. never_treated_values : List Values indicating never-treated status. Returns ------- Tuple[float, dict[int, float], dict[int, float]] - Overall missing rate - Missing rate by period - Missing rate by cohort """ all_periods = sorted(data[tvar].unique()) all_units = data[ivar].unique() # Create full panel index full_index = pd.MultiIndex.from_product( [all_units, all_periods], names=[ivar, tvar] ) full_panel = pd.DataFrame(index=full_index).reset_index() merged = full_panel.merge(data[[ivar, tvar, y]], on=[ivar, tvar], how='left') # Overall missing rate missing_rate_overall = float(merged[y].isna().mean()) # Missing rate by period missing_rate_by_period = {} for t in all_periods: period_data = merged[merged[tvar] == t] missing_rate_by_period[int(t)] = float(period_data[y].isna().mean()) # Missing rate by cohort missing_rate_by_cohort = {} if gvar is not None: unit_gvar = data.drop_duplicates(subset=[ivar]).set_index(ivar)[gvar] for g in data[gvar].dropna().unique(): if _is_never_treated(g, never_treated_values): continue cohort_units = unit_gvar[unit_gvar == g].index cohort_data = merged[merged[ivar].isin(cohort_units)] if len(cohort_data) > 0: missing_rate_by_cohort[int(g)] = float(cohort_data[y].isna().mean()) return missing_rate_overall, missing_rate_by_period, missing_rate_by_cohort def _compute_unit_stats( data: pd.DataFrame, y: str, ivar: str, tvar: str, gvar: str | None, never_treated_values: List, ) -> List[UnitMissingStats]: """ Compute per-unit missing 
data statistics. Parameters ---------- data : pd.DataFrame Panel data in long format. y : str Outcome variable column name. ivar : str Unit identifier column name. tvar : str Time variable column name. gvar : str or None Cohort variable column name. never_treated_values : List Values indicating never-treated status. Returns ------- List[UnitMissingStats] Per-unit missing data statistics. """ all_periods = sorted(data[tvar].unique()) n_total_periods = len(all_periods) T_min, T_max = min(all_periods), max(all_periods) unit_stats = [] # Get unit-level gvar if available unit_gvar = None if gvar is not None: unit_gvar = data.drop_duplicates(subset=[ivar]).set_index(ivar)[gvar] for unit_id in data[ivar].unique(): unit_data = data[data[ivar] == unit_id] # Basic statistics observed_periods = unit_data[tvar].unique() n_observed = len(observed_periods) n_missing = n_total_periods - n_observed missing_rate = n_missing / n_total_periods if n_total_periods > 0 else 0.0 first_observed = int(unit_data[tvar].min()) last_observed = int(unit_data[tvar].max()) observation_span = last_observed - first_observed + 1 # Cohort information cohort = None is_treated = False n_pre_treatment = None n_post_treatment = None pre_missing_rate = None post_missing_rate = None can_use_demean = True can_use_detrend = True reason_if_excluded = None if unit_gvar is not None: g = unit_gvar.get(unit_id) if not _is_never_treated(g, never_treated_values): cohort = int(g) is_treated = True # Count pre/post treatment observations pre_periods = [t for t in all_periods if t < g] post_periods = [t for t in all_periods if t >= g] n_pre_treatment = len([t for t in observed_periods if t < g]) n_post_treatment = len([t for t in observed_periods if t >= g]) if len(pre_periods) > 0: pre_missing_rate = 1 - n_pre_treatment / len(pre_periods) if len(post_periods) > 0: post_missing_rate = 1 - n_post_treatment / len(post_periods) # Check method usability if n_pre_treatment < 1: can_use_demean = False reason_if_excluded = 
"No pre-treatment observations" if n_pre_treatment < 2: can_use_detrend = False if reason_if_excluded is None: reason_if_excluded = "Fewer than 2 pre-treatment observations" unit_stats.append(UnitMissingStats( unit_id=unit_id, cohort=cohort, is_treated=is_treated, n_total_periods=n_total_periods, n_observed=n_observed, n_missing=n_missing, missing_rate=missing_rate, first_observed=first_observed, last_observed=last_observed, observation_span=observation_span, n_pre_treatment=n_pre_treatment, n_post_treatment=n_post_treatment, pre_treatment_missing_rate=pre_missing_rate, post_treatment_missing_rate=post_missing_rate, can_use_demean=can_use_demean, can_use_detrend=can_use_detrend, reason_if_excluded=reason_if_excluded, )) return unit_stats # ============================================================================= # Main Public Functions # =============================================================================
def diagnose_selection_mechanism(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None = None,
    controls: Optional[list[str]] = None,
    never_treated_values: Optional[List] = None,
    verbose: bool = True,
) -> SelectionDiagnostics:
    """
    Diagnose potential selection mechanism violations in unbalanced panels.

    Runs a battery of checks on the missing-data structure of a long-format
    panel to judge how plausible the selection mechanism assumption is.
    The assumption is:

    - missingness may depend on time-invariant unit heterogeneity, which the
      rolling transformation removes (as a fixed-effects estimator would);
    - missingness must not depend systematically on time-varying outcome
      shocks.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str, optional
        Cohort variable for staggered designs. When omitted, common timing is
        assumed and cohort-specific diagnostics are skipped.
    controls : list of str, optional
        Control variable columns, forwarded to the missing-pattern
        classification.
    never_treated_values : list, optional
        Values of ``gvar`` marking never-treated units. Default: [0, np.inf].
    verbose : bool, default True
        If True, print ``result.summary()`` before returning.

    Returns
    -------
    SelectionDiagnostics
        Bundles the following:

        - the classified missing pattern and its confidence;
        - the selection risk level;
        - the attrition analysis and balance statistics;
        - recommendations and warnings;
        - overall, per-period and per-cohort missing rates;
        - the selection tests;
        - the per-unit statistics.

    Notes
    -----
    The diagnostics run in this order:

    1. **Balance statistics**: panel balance metrics. These are documented as
       identifying units unusable for demeaning (< 1 pre-period) or
       detrending (< 2 pre-periods).
    2. **Attrition analysis**: dropout by cohort and time, separating early
       (pre-treatment) from late (post-treatment) dropout.
    3. **Missing pattern classification**: documented as using Little's MCAR
       test and auxiliary regressions.
    4. **Selection risk assessment**: combines the indicators above into one
       risk level, plus recommendations and warnings.
    5. **Reporting**: missing rates and per-unit statistics.

    See Also
    --------
    plot_missing_pattern : Visualize missing data patterns.
    get_unit_missing_stats : Get per-unit missing statistics as DataFrame.
    """
    _validate_diagnostic_inputs(data, y, ivar, tvar, gvar)
    nt_values = [0, np.inf] if never_treated_values is None else never_treated_values

    # Argument tuple shared by every helper that needs outcome + cohort info.
    panel_args = (data, y, ivar, tvar, gvar, nt_values)

    balance_stats = _compute_balance_statistics(*panel_args)
    attrition = _compute_attrition_analysis(data, ivar, tvar, gvar, nt_values)
    pattern, confidence, tests = _classify_missing_pattern(
        data, y, ivar, tvar, controls
    )
    risk, recommendations, risk_warnings = _assess_selection_risk(
        pattern, attrition, balance_stats, tests
    )
    missing_overall, missing_by_period, missing_by_cohort = _compute_missing_rates(
        *panel_args
    )
    per_unit = _compute_unit_stats(*panel_args)

    diagnostics = SelectionDiagnostics(
        missing_pattern=pattern,
        missing_pattern_confidence=confidence,
        selection_risk=risk,
        attrition_analysis=attrition,
        balance_statistics=balance_stats,
        recommendations=recommendations,
        warnings=risk_warnings,
        missing_rate_overall=missing_overall,
        missing_rate_by_period=missing_by_period,
        missing_rate_by_cohort=missing_by_cohort,
        selection_tests=tests,
        unit_stats=per_unit,
    )

    if verbose:
        print(diagnostics.summary())
    return diagnostics
def get_unit_missing_stats(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None = None,
    never_treated_values: Optional[List] = None,
) -> pd.DataFrame:
    """
    Compute per-unit missing data statistics as a DataFrame.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str, optional
        Cohort variable for staggered designs.
    never_treated_values : list, optional
        Values indicating never-treated units. Default: [0, np.inf].

    Returns
    -------
    pd.DataFrame
        One row per unit, with these columns (always present, in this order,
        even when there are no units):

        - unit_id: Unit identifier
        - cohort: Treatment cohort (missing for untreated units)
        - is_treated: Whether the unit is ever treated
        - n_total_periods: Number of distinct periods in the panel
        - n_observed: Number of observed periods
        - n_missing: Number of missing periods
        - missing_rate: Proportion of panel periods missing
        - first_observed: First observed period
        - last_observed: Last observed period
        - observation_span: last_observed - first_observed + 1
        - n_pre_treatment: Observed periods before the cohort period
        - n_post_treatment: Observed periods from the cohort period on
        - pre_treatment_missing_rate: Missing share of pre-treatment periods
        - post_treatment_missing_rate: Missing share of post-treatment periods
        - can_use_demean: Sufficient data for demeaning (>= 1 pre-period)
        - can_use_detrend: Sufficient data for detrending (>= 2 pre-periods)
        - reason_if_excluded: Why a method is unusable, if any

        The pre/post treatment columns are empty for untreated units.

    See Also
    --------
    diagnose_selection_mechanism : Comprehensive diagnostics.
    """
    _validate_diagnostic_inputs(data, y, ivar, tvar, gvar)

    if never_treated_values is None:
        never_treated_values = [0, np.inf]

    unit_stats = _compute_unit_stats(
        data, y, ivar, tvar, gvar, never_treated_values
    )

    # Output column names coincide with the UnitMissingStats attribute names.
    columns = [
        'unit_id',
        'cohort',
        'is_treated',
        'n_total_periods',
        'n_observed',
        'n_missing',
        'missing_rate',
        'first_observed',
        'last_observed',
        'observation_span',
        'n_pre_treatment',
        'n_post_treatment',
        'pre_treatment_missing_rate',
        'post_treatment_missing_rate',
        'can_use_demean',
        'can_use_detrend',
        'reason_if_excluded',
    ]
    records = [{col: getattr(us, col) for col in columns} for us in unit_stats]

    # Passing ``columns`` keeps a stable schema when there are no units;
    # pd.DataFrame([]) alone would have no columns at all.
    return pd.DataFrame(records, columns=columns)
def plot_missing_pattern(
    data: pd.DataFrame,
    ivar: str,
    tvar: str,
    y: str | None = None,
    gvar: str | None = None,
    sort_by: str = 'cohort',
    figsize: Tuple[float, float] = (12, 8),
    cmap: str = 'RdYlGn',
    show_cohort_lines: bool = True,
    never_treated_values: Optional[List] = None,
    max_units: int = 200,
    ax: Optional[Any] = None,
) -> Any:
    """
    Visualize missing data patterns in panel data.

    Creates a heatmap showing observation availability across units and time.
    Units can be sorted by cohort, missing rate, or unit ID.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    y : str, optional
        Outcome variable. If provided, a cell is observed only when its Y
        value is present. If None, only missing rows count as missing.
    gvar : str, optional
        Cohort variable. If provided, shows treatment timing.
    sort_by : str, default 'cohort'
        How to sort units (ties keep their original relative order):

        - 'cohort': by cohort, with never-treated units last;
        - 'missing_rate': by ascending missing rate;
        - 'unit_id': by unit identifier.

        'cohort' without a usable ``gvar``, or any unrecognized value, falls
        back to 'unit_id'.
    figsize : tuple, default (12, 8)
        Figure size in inches (only used when ``ax`` is None).
    cmap : str, default 'RdYlGn'
        Accepted for compatibility but not used; a fixed red/green colormap
        is applied.
    show_cohort_lines : bool, default True
        Whether to show treatment timing lines.
    never_treated_values : list, optional
        Values indicating never-treated units. Default: [0, np.inf].
    max_units : int, default 200
        Maximum number of units to display. If more units exist, a
        reproducible random sample (seed 42) is shown. The caller's global
        NumPy random state is not modified.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the missing pattern heatmap.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Notes
    -----
    The heatmap uses the following color coding:

    - Green: Observed (Y value present)
    - Red: Missing (Y value missing or row absent)
    - Black line: Treatment timing (if gvar provided)

    See Also
    --------
    diagnose_selection_mechanism : Comprehensive diagnostics.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.patches as mpatches
        from matplotlib.colors import ListedColormap
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for visualization. "
            "Install it with: pip install matplotlib"
        ) from err

    if never_treated_values is None:
        never_treated_values = [0, np.inf]

    # Get all units and periods
    all_units = list(data[ivar].unique())
    all_periods = sorted(data[tvar].unique())

    # Sample units if too many
    if len(all_units) > max_units:
        import warnings
        warnings.warn(
            f"Panel has {len(all_units)} units. Showing random sample of {max_units}.",
            UserWarning,
            stacklevel=2,
        )
        # A private RandomState(42) yields exactly the draws that
        # np.random.seed(42) + np.random.choice would. It does so without
        # resetting the caller's global NumPy random state.
        rng = np.random.RandomState(42)
        all_units = list(rng.choice(all_units, max_units, replace=False))

    # Get unit-level gvar for sorting (each unit's first row)
    unit_gvar = None
    if gvar is not None and gvar in data.columns:
        unit_gvar = data.drop_duplicates(subset=[ivar]).set_index(ivar)[gvar]

    # Build the observation matrix in one vectorized pass.
    # A cell (unit, period) is observed when the unit has at least one row in
    # that period (with a non-missing outcome, if ``y`` is given).
    observed_rows = data if y is None else data[data[y].notna()]
    row_idx = pd.Index(all_units).get_indexer(observed_rows[ivar])
    col_idx = pd.Index(all_periods).get_indexer(observed_rows[tvar])
    in_grid = (row_idx >= 0) & (col_idx >= 0)  # drops units not sampled
    obs_matrix = np.zeros((len(all_units), len(all_periods)))
    obs_matrix[row_idx[in_grid], col_idx[in_grid]] = 1

    # Sort units (stable, so ties keep their original relative order)
    if sort_by == 'cohort' and unit_gvar is not None:
        cohort_keys = []
        for u in all_units:
            g = unit_gvar.get(u)
            if _is_never_treated(g, never_treated_values) or pd.isna(g):
                cohort_keys.append(np.inf)  # never-treated / unknown last
            else:
                cohort_keys.append(g)
        sort_idx = np.argsort(cohort_keys, kind='stable')
    elif sort_by == 'missing_rate':
        sort_idx = np.argsort(1 - obs_matrix.mean(axis=1), kind='stable')
    else:  # 'unit_id' and fallback
        try:
            sort_idx = np.array(
                sorted(range(len(all_units)), key=all_units.__getitem__),
                dtype=int,
            )
        except TypeError:
            # Ids of mixed, mutually unorderable types: keep collected order.
            sort_idx = np.arange(len(all_units))

    obs_matrix = obs_matrix[sort_idx]
    sorted_units = [all_units[i] for i in sort_idx]

    # Create figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Plot heatmap: 0 -> red (missing), 1 -> green (observed)
    custom_cmap = ListedColormap(['#d73027', '#1a9850'])
    ax.imshow(obs_matrix, aspect='auto', cmap=custom_cmap, vmin=0, vmax=1)

    # Add cohort lines at the left edge of each treated unit's first treated period
    if show_cohort_lines and unit_gvar is not None:
        for i, unit in enumerate(sorted_units):
            g = unit_gvar.get(unit)
            if not _is_never_treated(g, never_treated_values) and pd.notna(g):
                try:
                    j = all_periods.index(g)
                except ValueError:
                    continue  # cohort period not among the observed periods
                ax.plot([j - 0.5, j - 0.5], [i - 0.5, i + 0.5],
                        'k-', linewidth=0.8, alpha=0.7)

    # Labels
    ax.set_xlabel('Time Period', fontsize=11)
    ax.set_ylabel('Unit', fontsize=11)
    ax.set_title('Panel Data Observation Pattern\n(Green=Observed, Red=Missing)',
                 fontsize=12)

    # X-axis ticks
    if len(all_periods) <= 25:
        ax.set_xticks(range(len(all_periods)))
        ax.set_xticklabels(all_periods, rotation=45, ha='right')
    else:
        # Show subset of ticks
        tick_step = max(1, len(all_periods) // 10)
        tick_positions = range(0, len(all_periods), tick_step)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([all_periods[i] for i in tick_positions],
                           rotation=45, ha='right')

    # Y-axis: label each unit, unless there are too many to read
    if len(sorted_units) > 50:
        ax.set_yticks([])
    else:
        ax.set_yticks(range(len(sorted_units)))
        ax.set_yticklabels([str(u) for u in sorted_units], fontsize=8)

    # Legend
    legend_elements = [
        mpatches.Patch(facecolor='#1a9850', label='Observed'),
        mpatches.Patch(facecolor='#d73027', label='Missing'),
    ]
    if show_cohort_lines and unit_gvar is not None:
        legend_elements.append(
            plt.Line2D([0], [0], color='black', linewidth=1, label='Treatment Start')
        )
    ax.legend(handles=legend_elements, loc='upper right', fontsize=9)

    # Lay out the figure that owns ``ax``, which is not necessarily the
    # pyplot "current" figure when the caller supplied ``ax``.
    fig.tight_layout()
    return fig