# Source code for lwdid.staggered.randomization

"""
Randomization inference for staggered difference-in-differences estimation.

This module implements permutation-based inference for staggered DiD settings
with multiple treatment cohorts. Randomization inference provides finite-sample
valid p-values without distributional assumptions by comparing the observed
test statistic to its permutation distribution.

Notes
-----
Randomization inference is useful when the number of treated or control units
is small and asymptotic approximations may be unreliable. The permutation
procedure shuffles cohort assignments while preserving the marginal distribution.

For overall effect inference, never-treated units are required as a consistent
reference group. For cohort-specific inference, the target cohort must remain
present after each permutation.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal
import warnings

import numpy as np
from numpy.typing import NDArray
import pandas as pd

from ..exceptions import RandomizationError


@dataclass
class StaggeredRIResult:
    """
    Result record for staggered randomization inference.

    Bundles everything a resampling run (permutation or bootstrap) reports:
    the p-value, how many replications were requested versus usable, the
    statistic under test, and the resampled statistics themselves.

    Attributes
    ----------
    p_value : float
        Two-sided p-value for the null hypothesis that the ATT equals zero.
    ri_method : str
        Resampling scheme that was used: 'permutation' or 'bootstrap'.
    ri_reps : int
        Number of replications that were requested.
    ri_valid : int
        Number of replications that yielded a non-missing statistic.
    ri_failed : int
        Number of replications that did not yield a usable statistic.
    observed_stat : float
        The observed ATT estimate that was tested.
    permutation_stats : NDArray[np.float64]
        ATT statistics from the valid replications only; NaN entries from
        failed replications are not included.

    See Also
    --------
    randomization_inference_staggered : Main function that produces this result.
    ri_overall_effect : Convenience wrapper for overall effect inference.
    ri_cohort_effect : Convenience wrapper for cohort-specific inference.
    """

    p_value: float
    ri_method: str
    ri_reps: int
    ri_valid: int
    ri_failed: int
    observed_stat: float
    permutation_stats: NDArray[np.float64]

    def __repr__(self) -> str:
        """Return a short summary: p-value, method, and valid/requested counts."""
        # The full array of resampled statistics is deliberately left out so
        # the representation stays on one line.
        summary = (
            f"StaggeredRIResult(p_value={self.p_value:.4f}, "
            f"method='{self.ri_method}', valid={self.ri_valid}/{self.ri_reps})"
        )
        return summary
def randomization_inference_staggered(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    tvar: str,
    y: str,
    observed_att: float,
    target: Literal['overall', 'cohort', 'cohort_time'] = 'overall',
    target_cohort: int | None = None,
    target_period: int | None = None,
    ri_method: Literal['permutation', 'bootstrap'] = 'permutation',
    rireps: int = 1000,
    seed: int | None = None,
    rolling: str = 'demean',
    controls: list[str] | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
    n_never_treated: int = 0,
) -> StaggeredRIResult:
    """
    Perform randomization inference for staggered DiD estimation.

    Tests the null hypothesis that the average treatment effect on the treated
    equals zero by permuting or bootstrapping treatment cohort assignments
    across units and computing the empirical distribution of the test
    statistic.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format with unit, time, cohort, and outcome columns.
    gvar : str
        Column name for the treatment cohort variable. Units with missing,
        zero, or infinite values are treated as never-treated.
    ivar : str
        Column name for the unit identifier.
    tvar : str
        Column name for the time period variable.
    y : str
        Column name for the outcome variable.
    observed_att : float
        Observed ATT estimate to be tested against the null hypothesis.
        Must be finite.
    target : {'overall', 'cohort', 'cohort_time'}, default 'overall'
        Aggregation level for the target effect:

        - 'overall': Overall weighted average effect across all cohorts
        - 'cohort': Cohort-specific average effect (requires target_cohort)
        - 'cohort_time': Effect for a specific (g, r) pair (requires both
          target_cohort and target_period)
    target_cohort : int, optional
        Target cohort for 'cohort' or 'cohort_time' targets.
    target_period : int, optional
        Target time period for 'cohort_time' target.
    ri_method : {'permutation', 'bootstrap'}, default 'permutation'
        Resampling method for generating the null distribution:

        - 'permutation': Without-replacement permutation preserving cohort
          sizes (Fisher exact randomization inference)
        - 'bootstrap': With-replacement sampling from unit cohort assignments
    rireps : int, default 1000
        Number of resampling replications.
    seed : int, optional
        Random seed for reproducibility of the resampling procedure.
    rolling : {'demean', 'detrend'}, default 'demean'
        Transformation method for removing pre-treatment variation
        (matched case-insensitively):

        - 'demean': Subtract pre-treatment mean from each unit
        - 'detrend': Remove unit-specific linear time trend
    controls : list of str, optional
        Column names for control variables. Only forwarded to the estimator
        when ``target='cohort_time'``.
    vce : str, optional
        Variance-covariance estimator type for standard errors.
    cluster_var : str, optional
        Column name for clustering standard errors.
    n_never_treated : int, default 0
        Number of never-treated units. Required (non-zero) for overall effect
        inference to ensure a consistent reference group across permutations.

    Returns
    -------
    StaggeredRIResult
        Dataclass containing the p-value, replication counts, observed
        statistic, and array of permutation statistics.

    Raises
    ------
    RandomizationError
        If input parameters are invalid (including an unknown ``target`` or a
        non-finite ``observed_att``), if there are fewer than 4 units, or if
        too few replications produce valid estimates. In the last case the
        most recent replication failure, if any, is attached as
        ``__cause__``.

    Warns
    -----
    UserWarning
        If more than 10% of the replications failed.

    See Also
    --------
    ri_overall_effect : Convenience wrapper for overall effect inference.
    ri_cohort_effect : Convenience wrapper for cohort-specific inference.

    Notes
    -----
    The permutation procedure shuffles cohort assignments while preserving the
    marginal cohort distribution, generating the null distribution under the
    sharp null hypothesis of no treatment effect.

    The two-sided p-value is computed as:

    .. math::

        p = \\frac{1}{R} \\sum_{r=1}^{R}
            \\mathbf{1}\\{|\\hat{\\tau}^{(r)}| \\geq |\\hat{\\tau}^{obs}|\\}

    where :math:`R` is the number of valid replications and
    :math:`\\hat{\\tau}^{(r)}` is the ATT from replication :math:`r`.

    A minimum of 50 valid replications (or 10% of rireps, whichever is larger)
    is required for reliable p-value computation.
    """
    # =========================================================================
    # Input Validation
    # =========================================================================
    if rireps <= 0:
        raise RandomizationError("rireps must be positive")

    # Reject unknown targets up front; otherwise they would silently fall into
    # the 'cohort_time' branch and only fail after every replication has run.
    if target not in ('overall', 'cohort', 'cohort_time'):
        raise RandomizationError(
            f"target must be 'overall', 'cohort', or 'cohort_time', "
            f"got '{target}'"
        )

    if target == 'cohort' and target_cohort is None:
        raise RandomizationError("target_cohort required when target='cohort'")

    if target == 'cohort_time' and (target_cohort is None or target_period is None):
        raise RandomizationError(
            "target_cohort and target_period required when target='cohort_time'"
        )

    if ri_method not in ('permutation', 'bootstrap'):
        raise RandomizationError("ri_method must be 'permutation' or 'bootstrap'")

    # Overall effect requires never-treated units as consistent reference group
    # across all permutations; without them, the control group varies arbitrarily.
    if target == 'overall' and n_never_treated == 0:
        raise RandomizationError(
            "RI for overall effect requires never treated units. "
            "Use target='cohort_time' when no NT units exist."
        )

    rolling_lower = rolling.lower()
    if rolling_lower not in ('demean', 'detrend'):
        raise RandomizationError(
            f"rolling must be 'demean' or 'detrend', got '{rolling}'"
        )

    # A NaN/inf observed statistic makes every ">=" comparison below False,
    # which would be reported as a (spuriously significant) p-value of 0.
    if not np.isfinite(observed_att):
        raise RandomizationError(
            f"observed_att must be a finite number, got {observed_att}"
        )

    rng = np.random.default_rng(seed)

    # Cohort assignment is time-invariant; extract once per unit for efficiency.
    unit_gvar = data.groupby(ivar)[gvar].first()
    unit_ids = unit_gvar.index.tolist()
    unit_cohorts = unit_gvar.values
    n_units = len(unit_ids)

    # Permutation distribution requires sufficient units for meaningful inference.
    if n_units < 4:
        raise RandomizationError(f"Too few units for RI: N={n_units}")

    # Deferred imports to avoid circular dependencies between submodules.
    from .transformations import transform_staggered_demean, transform_staggered_detrend
    from .estimation import estimate_cohort_time_effects
    from .aggregation import aggregate_to_cohort, aggregate_to_overall, get_cohorts

    transform_func = (
        transform_staggered_demean
        if rolling_lower == 'demean'
        else transform_staggered_detrend
    )

    T_max = int(data[tvar].max())

    # =========================================================================
    # Resampling Loop
    # =========================================================================
    # Pre-fill with NaN to distinguish failed replications from valid zeros;
    # NaN entries are excluded when computing the empirical p-value. Any
    # replication that raises, or that cannot produce the target statistic,
    # simply leaves its slot as NaN.
    perm_stats = np.empty(rireps, dtype=float)
    perm_stats.fill(np.nan)

    # Most recent replication failure, kept only for diagnostics.
    last_error: Exception | None = None

    for rep in range(rireps):
        try:
            if ri_method == 'permutation':
                # Without-replacement shuffle preserves marginal cohort distribution.
                draw_idx = rng.permutation(n_units)
            else:
                # With-replacement bootstrap allows cohort frequency variation.
                draw_idx = rng.integers(0, n_units, size=n_units)
            perm_gvar = unit_cohorts[draw_idx]

            # Broadcast the resampled unit-level cohorts back onto every row.
            perm_gvar_mapping = dict(zip(unit_ids, perm_gvar))
            data_perm = data.copy()
            data_perm[gvar] = data_perm[ivar].map(perm_gvar_mapping)

            data_transformed = transform_func(data_perm, y, ivar, tvar, gvar)

            if target == 'overall':
                result = aggregate_to_overall(
                    data_transformed, gvar, ivar, tvar,
                    transform_type=rolling_lower,
                    vce=vce,
                    cluster_var=cluster_var,
                )
                perm_stats[rep] = result.att

            elif target == 'cohort':
                perm_cohorts = get_cohorts(data_transformed, gvar, ivar)
                # Permutation may reassign all units away from target cohort;
                # such a replication stays NaN.
                if target_cohort in perm_cohorts:
                    results = aggregate_to_cohort(
                        data_transformed, gvar, ivar, tvar,
                        cohorts=[target_cohort],
                        T_max=T_max,
                        transform_type=rolling_lower,
                        vce=vce,
                        cluster_var=cluster_var,
                    )
                    if results:
                        perm_stats[rep] = results[0].att

            else:  # target == 'cohort_time'
                ct_results = estimate_cohort_time_effects(
                    data_transformed, gvar, ivar, tvar,
                    controls=controls,
                    vce=vce,
                    cluster_var=cluster_var,
                    transform_type=rolling_lower,
                )
                target_result = [
                    r for r in ct_results
                    if r.cohort == target_cohort and r.period == target_period
                ]
                if target_result:
                    perm_stats[rep] = target_result[0].att

        except Exception as exc:
            # Best-effort by design: a failed replication is excluded via its
            # NaN slot. Remember the cause (without its traceback, so no
            # frames or permuted data copies are kept alive) for diagnostics.
            last_error = exc.with_traceback(None)

    # =========================================================================
    # P-Value Computation
    # =========================================================================
    valid_stats = perm_stats[~np.isnan(perm_stats)]
    n_valid = len(valid_stats)
    n_failed = rireps - n_valid

    # Insufficient valid replications yield unreliable p-value estimates.
    min_valid = max(50, int(0.1 * rireps))
    if n_valid < min_valid:
        # Chain the last failure so the underlying problem is not hidden
        # behind the generic message (``from None`` when nothing raised).
        raise RandomizationError(
            f"Insufficient valid RI replications: {n_valid}/{rireps} "
            f"(need at least {min_valid}). "
            f"Consider increasing rireps or checking data quality."
        ) from last_error

    # Two-sided p-value under sharp null hypothesis of no treatment effect.
    p_value = float((np.abs(valid_stats) >= abs(observed_att)).mean())

    # High failure rate may indicate data quality issues or model misspecification.
    if n_failed / rireps > 0.1:
        warnings.warn(
            f"Staggered RI: {n_failed}/{rireps} replications failed "
            f"({n_failed/rireps:.1%}). P-value computed using {n_valid} "
            f"valid replications.",
            UserWarning,
            stacklevel=2,
        )

    return StaggeredRIResult(
        p_value=p_value,
        ri_method=ri_method,
        ri_reps=rireps,
        ri_valid=n_valid,
        ri_failed=n_failed,
        observed_stat=observed_att,
        permutation_stats=valid_stats,
    )
def ri_overall_effect(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    tvar: str,
    y: str,
    observed_att: float,
    rolling: str = 'demean',
    ri_method: str = 'permutation',
    rireps: int = 1000,
    seed: int | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
) -> StaggeredRIResult:
    """
    Run randomization inference on the overall weighted ATT.

    Thin wrapper over `randomization_inference_staggered` with
    ``target='overall'``. It counts the never-treated units in ``data`` and
    forwards that count, which the underlying routine requires to be non-zero
    for overall-effect inference. The overall effect is a
    cohort-share-weighted average of cohort-specific ATTs.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format with unit, time, cohort, and outcome columns.
    gvar : str
        Column name for the treatment cohort variable.
    ivar : str
        Column name for the unit identifier.
    tvar : str
        Column name for the time period.
    y : str
        Column name for the outcome variable.
    observed_att : float
        Observed overall ATT estimate to test.
    rolling : {'demean', 'detrend'}, default 'demean'
        Transformation method for pre-treatment variation removal.
    ri_method : {'permutation', 'bootstrap'}, default 'permutation'
        Resampling method for null distribution generation.
    rireps : int, default 1000
        Number of resampling replications.
    seed : int, optional
        Random seed for reproducibility.
    vce : str, optional
        Variance-covariance estimator type.
    cluster_var : str, optional
        Column name for clustering standard errors.

    Returns
    -------
    StaggeredRIResult
        Randomization inference results including p-value and diagnostics.

    See Also
    --------
    randomization_inference_staggered : Full-featured inference function.
    ri_cohort_effect : Inference for cohort-specific effects.
    """
    from ..validation import is_never_treated

    # Reduce the panel to one cohort value per unit, then count how many of
    # those units are classified as never-treated.
    cohort_by_unit = data.groupby(ivar)[gvar].first()
    never_treated_flags = cohort_by_unit.apply(is_never_treated)
    never_treated_count = int(never_treated_flags.sum())

    return randomization_inference_staggered(
        data=data,
        gvar=gvar,
        ivar=ivar,
        tvar=tvar,
        y=y,
        observed_att=observed_att,
        target='overall',
        ri_method=ri_method,
        rireps=rireps,
        seed=seed,
        rolling=rolling,
        n_never_treated=never_treated_count,
        vce=vce,
        cluster_var=cluster_var,
    )
def ri_cohort_effect(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    tvar: str,
    y: str,
    target_cohort: int,
    observed_att: float,
    rolling: str = 'demean',
    ri_method: str = 'permutation',
    rireps: int = 1000,
    seed: int | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
) -> StaggeredRIResult:
    """
    Run randomization inference on a single cohort's ATT.

    Thin wrapper over `randomization_inference_staggered` with
    ``target='cohort'``. The cohort-specific ATT averages effects across all
    post-treatment periods for units first treated in the target cohort.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format with unit, time, cohort, and outcome columns.
    gvar : str
        Column name for the treatment cohort variable.
    ivar : str
        Column name for the unit identifier.
    tvar : str
        Column name for the time period.
    y : str
        Column name for the outcome variable.
    target_cohort : int
        Treatment cohort (first treatment period) to test.
    observed_att : float
        Observed cohort-specific ATT estimate to test.
    rolling : {'demean', 'detrend'}, default 'demean'
        Transformation method for pre-treatment variation removal.
    ri_method : {'permutation', 'bootstrap'}, default 'permutation'
        Resampling method for null distribution generation.
    rireps : int, default 1000
        Number of resampling replications.
    seed : int, optional
        Random seed for reproducibility.
    vce : str, optional
        Variance-covariance estimator type.
    cluster_var : str, optional
        Column name for clustering standard errors.

    Returns
    -------
    StaggeredRIResult
        Randomization inference results including p-value and diagnostics.

    See Also
    --------
    randomization_inference_staggered : Full-featured inference function.
    ri_overall_effect : Inference for overall weighted effect.
    """
    from ..validation import is_never_treated

    # Count never-treated units from the per-unit cohort values. The count is
    # forwarded as-is to the underlying routine.
    cohort_by_unit = data.groupby(ivar)[gvar].first()
    never_treated_flags = cohort_by_unit.apply(is_never_treated)
    never_treated_count = int(never_treated_flags.sum())

    return randomization_inference_staggered(
        data=data,
        gvar=gvar,
        ivar=ivar,
        tvar=tvar,
        y=y,
        observed_att=observed_att,
        target='cohort',
        target_cohort=target_cohort,
        ri_method=ri_method,
        rireps=rireps,
        seed=seed,
        rolling=rolling,
        n_never_treated=never_treated_count,
        vce=vce,
        cluster_var=cluster_var,
    )