# Source code for lwdid.staggered.estimation_pre

"""
Pre-treatment period ATT estimation for staggered difference-in-differences.

This module implements treatment effect estimation for pre-treatment periods
in staggered adoption settings. Pre-treatment ATT estimates are used for:

1. Event study visualization with pre-treatment effects
2. Parallel trends assumption testing
3. Detection of anticipation effects or dynamic selection

Under the parallel trends assumption and no anticipation, pre-treatment
ATT estimates should be statistically indistinguishable from zero.

The estimation methodology uses rolling transformations:

- For each cohort g and pre-treatment period t < g, the transformation
  uses future pre-treatment periods {t+1, ..., g-1} to compute baselines.
- The anchor point at t = g-1 (event time e = -1) is set to exactly 0
  by construction, serving as the reference for pre-treatment dynamics.
- Control groups are dynamically defined based on treatment timing:
  units first treated after period t plus never-treated units.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass

import numpy as np
import pandas as pd

from .control_groups import (
    ControlGroupStrategy,
    get_valid_control_units,
)
from .transformations import get_cohorts
from .transformations_pre import get_pre_treatment_periods_for_cohort
from .estimation import run_ols_regression, _estimate_single_effect_ipwra, _estimate_single_effect_psm


@dataclass(frozen=True)
class PreTreatmentEffect:
    """
    Container for a pre-treatment period ATT estimate.

    Stores the ATT for treatment cohort g at pre-treatment period t < g,
    along with inference statistics. Under parallel trends, these estimates
    should be approximately zero.

    Attributes
    ----------
    cohort : int
        Treatment cohort identifier (first treatment period g).
    period : int
        Calendar time period t (where t < g).
    event_time : int
        Event time relative to treatment onset (e = t - g, always negative).
    att : float
        Estimated pre-treatment ATT (should be ~0 under parallel trends).
    se : float
        Standard error of the ATT estimate.
    ci_lower : float
        Lower bound of confidence interval.
    ci_upper : float
        Upper bound of confidence interval.
    t_stat : float
        t-statistic for testing H0: ATT = 0.
    pvalue : float
        Two-sided p-value.
    n_treated : int
        Number of treated units in estimation sample.
    n_control : int
        Number of control units in estimation sample.
    is_anchor : bool
        True if this is the anchor point (t = g-1, e = -1).
        Anchor points have ATT = 0, SE = 0 by convention.
    rolling_window_size : int
        Number of future periods used in transformation {t+1, ..., g-1}.
    df_inference : int
        Degrees of freedom associated with the inference statistics.
        Defaults to 0 when no value is supplied.

    Notes
    -----
    The dataclass is frozen (immutable) to ensure estimation results
    cannot be accidentally modified after creation.
    """

    cohort: int
    period: int
    event_time: int
    att: float
    se: float
    ci_lower: float
    ci_upper: float
    t_stat: float
    pvalue: float
    n_treated: int
    n_control: int
    is_anchor: bool = False
    rolling_window_size: int = 0
    df_inference: int = 0

def estimate_pre_treatment_effects(
    data_transformed: pd.DataFrame,
    gvar: str,
    ivar: str,
    tvar: str,
    controls: list[str] | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
    control_strategy: str = 'not_yet_treated',
    never_treated_values: list | None = None,
    min_obs: int = 3,
    min_treated: int = 1,
    min_control: int = 1,
    alpha: float = 0.05,
    estimator: str = 'ra',
    transform_type: str = 'demean',
    propensity_controls: list[str] | None = None,
    trim_threshold: float = 0.01,
    se_method: str = 'analytical',
    n_neighbors: int = 1,
    with_replacement: bool = True,
    caliper: float | None = None,
) -> list[PreTreatmentEffect]:
    """
    Estimate pre-treatment effects for all valid cohort-period pairs.

    Iterates over all treatment cohorts and their pre-treatment periods,
    estimating the ATT for each (cohort, period) combination. Under the
    parallel trends assumption, these estimates should be approximately zero.

    Parameters
    ----------
    data_transformed : pd.DataFrame
        Panel data containing pre-treatment transformed outcome columns
        generated by ``transform_staggered_demean_pre`` or
        ``transform_staggered_detrend_pre``. Must include columns for gvar,
        ivar, tvar, and transformed outcomes named 'ydot_pre_g{g}_t{t}' or
        'ycheck_pre_g{g}_t{t}'.
    gvar : str
        Name of the cohort variable column indicating first treatment period.
    ivar : str
        Name of the unit identifier column.
    tvar : str
        Name of the time variable column.
    controls : list of str, optional
        Names of time-invariant control variable columns.
    vce : str, optional
        Variance estimation type: None (homoskedastic), 'hc3'
        (heteroskedasticity-robust), or 'cluster' (cluster-robust).
    cluster_var : str, optional
        Name of the cluster variable column. Required when vce='cluster'.
    control_strategy : str, default='not_yet_treated'
        Control group selection strategy: 'never_treated' uses only
        never-treated units; 'not_yet_treated' includes units first treated
        after the current period; 'all_others' uses all units not in the
        treatment cohort (including already-treated units); 'auto' maps to
        ``ControlGroupStrategy.AUTO``.
    never_treated_values : list, optional
        Values in gvar indicating never-treated units. When None,
        ``[0, np.inf]`` is passed to the cohort and control-group helpers.
    min_obs : int, default=3
        Minimum total sample size required for estimation.
    min_treated : int, default=1
        Minimum number of treated units required.
    min_control : int, default=1
        Minimum number of control units required.
    alpha : float, default=0.05
        Significance level for confidence interval construction.
    estimator : str, default='ra'
        Estimation method: 'ra' (regression adjustment), 'ipwra' (inverse
        probability weighted regression adjustment), or 'psm' (propensity
        score matching).
    transform_type : str, default='demean'
        Transformation type applied to the data. 'demean' selects the
        'ydot_pre' column prefix; any other value selects 'ycheck_pre'.
    propensity_controls : list of str, optional
        Control variables for the propensity score model. If None, uses the
        same variables as ``controls``.
    trim_threshold : float, default=0.01
        Propensity score trimming threshold for IPWRA and PSM.
    se_method : str, default='analytical'
        Standard error method for IPWRA: 'analytical' or 'bootstrap'.
    n_neighbors : int, default=1
        Number of nearest neighbors for PSM matching.
    with_replacement : bool, default=True
        Whether PSM matching allows replacement.
    caliper : float, optional
        Maximum propensity score distance for PSM matching.

    Returns
    -------
    list of PreTreatmentEffect
        Estimation results for all valid (cohort, period) pairs, sorted by
        cohort and then by event_time (descending, so anchor point comes
        first).

    Raises
    ------
    ValueError
        If required columns are missing, ``vce='cluster'`` is given without
        ``cluster_var``, ``estimator`` or ``control_strategy`` is not a
        recognised value, or no valid treatment cohorts exist.

    Warns
    -----
    UserWarning
        When a cohort has no pre-treatment periods, and once at the end if
        any (cohort, period) pairs were skipped; the latter includes a tally
        of skip reasons.

    See Also
    --------
    transform_staggered_demean_pre : Pre-treatment demeaning transformation.
    transform_staggered_detrend_pre : Pre-treatment detrending transformation.
    test_parallel_trends : Statistical test for parallel trends assumption.

    Notes
    -----
    The anchor point (t = g-1, event time e = -1) is handled specially:

    - ATT is set to exactly 0.0 (by construction of the transformation)
    - SE is set to 0.0
    - is_anchor flag is set to True

    For pre-treatment periods, the control group is defined as:
    control = {units with gvar > t} ∪ {never-treated units}.
    """
    # =========================================================================
    # Input Validation
    # =========================================================================
    required_cols = [gvar, ivar, tvar]
    missing = [c for c in required_cols if c not in data_transformed.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    if vce == 'cluster' and cluster_var is None:
        raise ValueError("cluster_var required when vce='cluster'")

    # Validate the estimator here: inside the per-pair try/except below the
    # error would be swallowed and every pair silently skipped.
    valid_estimators = ('ra', 'ipwra', 'psm')
    if estimator not in valid_estimators:
        raise ValueError(
            f"Unknown estimator: {estimator}. "
            f"Must be one of: {list(valid_estimators)}"
        )

    strategy_map = {
        'never_treated': ControlGroupStrategy.NEVER_TREATED,
        'not_yet_treated': ControlGroupStrategy.NOT_YET_TREATED,
        'all_others': ControlGroupStrategy.ALL_OTHERS,
        'auto': ControlGroupStrategy.AUTO,
    }
    if control_strategy not in strategy_map:
        raise ValueError(
            f"Invalid control_strategy: {control_strategy}. "
            f"Must be one of: {list(strategy_map.keys())}"
        )
    strategy = strategy_map[control_strategy]

    # =========================================================================
    # Cohort and Period Extraction
    # =========================================================================
    if never_treated_values is None:
        nt_values = [0, np.inf]
    else:
        nt_values = never_treated_values

    cohorts = get_cohorts(data_transformed, gvar, ivar, nt_values)
    if len(cohorts) == 0:
        raise ValueError("No valid treatment cohorts found in data.")

    T_min = int(data_transformed[tvar].min())

    # =========================================================================
    # Pre-treatment Effect Estimation
    # =========================================================================
    results = []
    skipped_pairs = []
    n_total_pairs = 0  # all (cohort, period) pairs considered, anchors included

    # Column prefix reflects transformation type
    prefix = 'ydot_pre' if transform_type == 'demean' else 'ycheck_pre'

    for g in cohorts:
        pre_periods = get_pre_treatment_periods_for_cohort(g, T_min)
        n_total_pairs += len(pre_periods)

        if len(pre_periods) == 0:
            warnings.warn(
                f"Cohort {g} has no pre-treatment periods (T_min={T_min}).",
                UserWarning,
                stacklevel=2,
            )
            continue

        for t in pre_periods:
            event_time = t - g  # Always negative for pre-treatment
            rolling_window_size = g - t - 1  # Number of future periods
            ydot_col = f'{prefix}_g{g}_t{t}'

            # -----------------------------------------------------------------
            # Handle Anchor Point (t = g-1, e = -1)
            # -----------------------------------------------------------------
            if t == g - 1:
                # Anchor point: ATT = 0 by construction.
                # Count treated and control units for reporting only.
                period_data = data_transformed[data_transformed[tvar] == t]
                n_treat = (period_data[gvar] == g).sum()

                try:
                    unit_control_mask = get_valid_control_units(
                        period_data, gvar, ivar,
                        cohort=g, period=t,
                        strategy=strategy,
                        never_treated_values=nt_values,
                        is_pre_treatment=True,
                    )
                    control_mask = (
                        period_data[ivar].map(unit_control_mask)
                        .fillna(False).astype(bool)
                    )
                    n_control = control_mask.sum()
                except Exception:
                    # Best effort: the anchor row carries no estimate, so a
                    # failed control lookup only zeroes the reported count.
                    n_control = 0

                results.append(PreTreatmentEffect(
                    cohort=int(g),
                    period=int(t),
                    event_time=int(event_time),
                    att=0.0,
                    se=0.0,
                    ci_lower=0.0,
                    ci_upper=0.0,
                    t_stat=np.nan,  # Undefined (0/0)
                    pvalue=np.nan,
                    n_treated=int(n_treat),
                    n_control=int(n_control),
                    is_anchor=True,
                    rolling_window_size=0,
                    df_inference=(
                        int(n_treat + n_control - 2)
                        if n_treat + n_control > 2 else 0
                    ),
                ))
                continue

            # -----------------------------------------------------------------
            # Check for Transformed Outcome Column
            # -----------------------------------------------------------------
            if ydot_col not in data_transformed.columns:
                skipped_pairs.append((g, t, 'missing_transform_column'))
                continue

            # -----------------------------------------------------------------
            # Extract Period Cross-Section
            # -----------------------------------------------------------------
            period_data = data_transformed[data_transformed[tvar] == t].copy()

            if len(period_data) == 0:
                skipped_pairs.append((g, t, 'no_data_in_period'))
                continue

            # -----------------------------------------------------------------
            # Identify Valid Control Units (Pre-treatment)
            # -----------------------------------------------------------------
            try:
                unit_control_mask = get_valid_control_units(
                    period_data, gvar, ivar,
                    cohort=g, period=t,
                    strategy=strategy,
                    never_treated_values=nt_values,
                    is_pre_treatment=True,
                )
            except Exception as e:
                skipped_pairs.append((g, t, f'control_mask_error: {e}'))
                continue

            control_mask = (
                period_data[ivar].map(unit_control_mask)
                .fillna(False).astype(bool)
            )

            # -----------------------------------------------------------------
            # Construct Estimation Sample
            # -----------------------------------------------------------------
            # For pre-treatment: treated units are those in cohort g
            # (even though they haven't been treated yet at time t)
            treat_mask = (period_data[gvar] == g)
            sample_mask = treat_mask | control_mask

            n_treat = treat_mask.sum()
            n_control = control_mask.sum()
            n_total = sample_mask.sum()

            if n_total < min_obs:
                skipped_pairs.append(
                    (g, t, f'insufficient_total: {n_total}<{min_obs}'))
                continue
            if n_treat < min_treated:
                skipped_pairs.append(
                    (g, t, f'insufficient_treated: {n_treat}<{min_treated}'))
                continue
            if n_control < min_control:
                skipped_pairs.append(
                    (g, t, f'insufficient_control: {n_control}<{min_control}'))
                continue

            sample_data = period_data[sample_mask].copy()
            sample_data['_D_treat'] = (sample_data[gvar] == g).astype(int)

            # -----------------------------------------------------------------
            # Check for Valid Transformed Outcomes
            # -----------------------------------------------------------------
            n_valid_outcomes = sample_data[ydot_col].notna().sum()
            if n_valid_outcomes < min_obs:
                skipped_pairs.append(
                    (g, t,
                     f'insufficient_valid_outcomes: {n_valid_outcomes}<{min_obs}'))
                continue

            # -----------------------------------------------------------------
            # Run Estimator
            # -----------------------------------------------------------------
            # Estimation failures are recorded and skipped so one bad pair
            # does not abort the whole run.
            try:
                if estimator == 'ra':
                    est_result = run_ols_regression(
                        data=sample_data,
                        y=ydot_col,
                        d='_D_treat',
                        controls=controls,
                        vce=vce,
                        cluster_var=cluster_var,
                        alpha=alpha,
                    )
                elif estimator == 'ipwra':
                    est_result = _estimate_single_effect_ipwra(
                        data=sample_data,
                        y=ydot_col,
                        d='_D_treat',
                        controls=controls or [],
                        propensity_controls=propensity_controls or controls or [],
                        trim_threshold=trim_threshold,
                        se_method=se_method,
                        alpha=alpha,
                    )
                else:  # 'psm' (estimator was validated above)
                    est_result = _estimate_single_effect_psm(
                        data=sample_data,
                        y=ydot_col,
                        d='_D_treat',
                        propensity_controls=propensity_controls or controls or [],
                        n_neighbors=n_neighbors,
                        with_replacement=with_replacement,
                        caliper=caliper,
                        trim_threshold=trim_threshold,
                        alpha=alpha,
                    )
            except Exception as e:
                skipped_pairs.append((g, t, f'{estimator}_error: {e}'))
                continue

            results.append(PreTreatmentEffect(
                cohort=int(g),
                period=int(t),
                event_time=int(event_time),
                att=est_result['att'],
                se=est_result['se'],
                ci_lower=est_result['ci_lower'],
                ci_upper=est_result['ci_upper'],
                t_stat=est_result['t_stat'],
                pvalue=est_result['pvalue'],
                n_treated=int(n_treat),
                n_control=int(n_control),
                is_anchor=False,
                rolling_window_size=rolling_window_size,
                df_inference=_coerce_df_inference(
                    est_result.get('df_inference', 0)),
            ))

    # =========================================================================
    # Reporting
    # =========================================================================
    if skipped_pairs:
        n_skipped = len(skipped_pairs)

        # Tally reasons by category (the text before the first ':') so the
        # collected diagnostics actually reach the user.
        reason_counts = {}
        for _, _, reason in skipped_pairs:
            category = reason.split(':', 1)[0]
            reason_counts[category] = reason_counts.get(category, 0) + 1
        breakdown = ", ".join(
            f"{name}={count}" for name, count in sorted(reason_counts.items())
        )

        warnings.warn(
            f"Skipped {n_skipped}/{n_total_pairs} pre-treatment (cohort, period) pairs "
            f"due to insufficient data or errors. Reasons: {breakdown}.",
            UserWarning,
            stacklevel=2,
        )

    # Sort by cohort, then by event_time descending (anchor first)
    results.sort(key=lambda x: (x.cohort, -x.event_time))

    return results


def _coerce_df_inference(value) -> int:
    """Return ``value`` as an integer degrees-of-freedom, or 0 if unusable.

    Treats None, NaN, infinities and non-numeric values as "not available"
    so a missing estimator field cannot crash result construction.
    """
    if value is None:
        return 0
    try:
        if np.isnan(value):
            return 0
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return 0

def pre_treatment_effects_to_dataframe(
    results: list[PreTreatmentEffect],
) -> pd.DataFrame:
    """
    Convert a list of PreTreatmentEffect objects to a pandas DataFrame.

    Parameters
    ----------
    results : list of PreTreatmentEffect
        Estimation results from ``estimate_pre_treatment_effects``.

    Returns
    -------
    pd.DataFrame
        DataFrame with one row per result and columns: cohort, period,
        event_time, att, se, ci_lower, ci_upper, t_stat, pvalue, n_treated,
        n_control, is_anchor, rolling_window_size, df_inference.
        Returns an empty DataFrame with the same columns if the input list
        is empty.

    See Also
    --------
    estimate_pre_treatment_effects : Estimate pre-treatment ATT.
    PreTreatmentEffect : Container class for individual effect estimates.
    """
    # Single source of truth for column names and order, shared by the
    # empty and non-empty paths so they cannot drift apart.
    columns = (
        'cohort', 'period', 'event_time', 'att', 'se',
        'ci_lower', 'ci_upper', 't_stat', 'pvalue',
        'n_treated', 'n_control', 'is_anchor', 'rolling_window_size',
        'df_inference',
    )

    if not results:
        return pd.DataFrame(columns=list(columns))

    return pd.DataFrame([
        {name: getattr(r, name) for name in columns}
        for r in results
    ])