"""
Pre-treatment period ATT estimation for staggered difference-in-differences.
This module implements treatment effect estimation for pre-treatment periods
in staggered adoption settings. Pre-treatment ATT estimates are used for:
1. Event study visualization with pre-treatment effects
2. Parallel trends assumption testing
3. Detection of anticipation effects or dynamic selection
Under the parallel trends assumption and no anticipation, pre-treatment
ATT estimates should be statistically indistinguishable from zero.
The estimation methodology uses rolling transformations:
- For each cohort g and pre-treatment period t < g, the transformation
uses future pre-treatment periods {t+1, ..., g-1} to compute baselines.
- The anchor point at t = g-1 (event time e = -1) is set to exactly 0
by construction, serving as the reference for pre-treatment dynamics.
- Control groups are defined dynamically from treatment timing; under the
default not-yet-treated strategy they consist of units first treated after
period t plus never-treated units.
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
import numpy as np
import pandas as pd
from .control_groups import (
ControlGroupStrategy,
get_valid_control_units,
)
from .transformations import get_cohorts
from .transformations_pre import get_pre_treatment_periods_for_cohort
from .estimation import run_ols_regression, _estimate_single_effect_ipwra, _estimate_single_effect_psm
[docs]
@dataclass(frozen=True)
class PreTreatmentEffect:
    """
    Container for a pre-treatment period ATT estimate.

    Stores the ATT for treatment cohort g at pre-treatment period t < g,
    along with inference statistics. Under parallel trends, these estimates
    should be approximately zero.

    Attributes
    ----------
    cohort : int
        Treatment cohort identifier (first treatment period g).
    period : int
        Calendar time period t (where t < g).
    event_time : int
        Event time relative to treatment onset (e = t - g, always negative).
    att : float
        Estimated pre-treatment ATT (should be ~0 under parallel trends).
    se : float
        Standard error of the ATT estimate.
    ci_lower : float
        Lower bound of confidence interval.
    ci_upper : float
        Upper bound of confidence interval.
    t_stat : float
        t-statistic for testing H0: ATT = 0. NaN for anchor points, where
        the ratio 0/0 is undefined.
    pvalue : float
        Two-sided p-value. NaN for anchor points.
    n_treated : int
        Number of treated units in estimation sample.
    n_control : int
        Number of control units in estimation sample.
    is_anchor : bool, default False
        True if this is the anchor point (t = g-1, e = -1).
        Anchor points have ATT = 0, SE = 0 by convention.
    rolling_window_size : int, default 0
        Number of future periods used in transformation {t+1, ..., g-1}.
        Always 0 for anchor points.
    df_inference : int, default 0
        Degrees of freedom used for inference; 0 when the estimator did not
        report a usable value. For anchor points produced by
        ``estimate_pre_treatment_effects`` this is
        ``n_treated + n_control - 2`` (0 when that would not be positive).

    Notes
    -----
    The dataclass is frozen (immutable) to ensure estimation results
    cannot be accidentally modified after creation.
    """
    cohort: int
    period: int
    event_time: int
    att: float
    se: float
    ci_lower: float
    ci_upper: float
    t_stat: float
    pvalue: float
    n_treated: int
    n_control: int
    is_anchor: bool = False
    rolling_window_size: int = 0
    df_inference: int = 0
[docs]
# Estimators that ``estimate_pre_treatment_effects`` knows how to dispatch.
_SUPPORTED_PRE_ESTIMATORS = ('ra', 'ipwra', 'psm')


def _safe_int(value, default: int = 0) -> int:
    """Return ``int(value)``, or ``default`` for None, NaN, +/-inf, or non-numeric input."""
    try:
        as_float = float(value)
    except (TypeError, ValueError):
        return default
    if not np.isfinite(as_float):
        return default
    return int(as_float)


def _pre_treatment_control_rows(
    period_data: pd.DataFrame,
    gvar: str,
    ivar: str,
    g,
    t,
    strategy,
    nt_values: list,
) -> pd.Series:
    """Return a boolean mask, aligned with ``period_data`` rows, of valid controls for (g, t).

    Exceptions raised by ``get_valid_control_units`` propagate to the caller.
    """
    unit_control_mask = get_valid_control_units(
        period_data, gvar, ivar,
        cohort=g, period=t,
        strategy=strategy,
        never_treated_values=nt_values,
        is_pre_treatment=True,
    )
    # The helper's result is looked up by unit id; units it does not mention
    # are treated as non-controls.
    return period_data[ivar].map(unit_control_mask).fillna(False).astype(bool)


def _anchor_effect(
    data_transformed: pd.DataFrame,
    gvar: str,
    ivar: str,
    tvar: str,
    g,
    t,
    strategy,
    nt_values: list,
) -> PreTreatmentEffect:
    """Build the anchor-point result (t = g-1) whose ATT and SE are 0 by construction."""
    period_data = data_transformed[data_transformed[tvar] == t]
    n_treat = int((period_data[gvar] == g).sum())
    try:
        n_control = int(
            _pre_treatment_control_rows(
                period_data, gvar, ivar, g, t, strategy, nt_values
            ).sum()
        )
    except Exception:
        # Counts are informational only; failing to compute them must not
        # drop the anchor point from the results.
        n_control = 0
    return PreTreatmentEffect(
        cohort=int(g),
        period=int(t),
        event_time=int(t - g),
        att=0.0,
        se=0.0,
        ci_lower=0.0,
        ci_upper=0.0,
        t_stat=np.nan,  # Undefined (0/0)
        pvalue=np.nan,
        n_treated=n_treat,
        n_control=n_control,
        is_anchor=True,
        rolling_window_size=0,
        df_inference=max(n_treat + n_control - 2, 0),
    )


def _run_pre_treatment_estimator(
    estimator: str,
    sample_data: pd.DataFrame,
    y_col: str,
    *,
    controls: list[str] | None,
    vce: str | None,
    cluster_var: str | None,
    alpha: float,
    propensity_controls: list[str] | None,
    trim_threshold: float,
    se_method: str,
    n_neighbors: int,
    with_replacement: bool,
    caliper: float | None,
) -> dict:
    """Dispatch one (cohort, period) estimation to the selected estimator."""
    ps_controls = propensity_controls or controls or []
    if estimator == 'ra':
        return run_ols_regression(
            data=sample_data,
            y=y_col,
            d='_D_treat',
            controls=controls,
            vce=vce,
            cluster_var=cluster_var,
            alpha=alpha,
        )
    if estimator == 'ipwra':
        return _estimate_single_effect_ipwra(
            data=sample_data,
            y=y_col,
            d='_D_treat',
            controls=controls or [],
            propensity_controls=ps_controls,
            trim_threshold=trim_threshold,
            se_method=se_method,
            alpha=alpha,
        )
    if estimator == 'psm':
        return _estimate_single_effect_psm(
            data=sample_data,
            y=y_col,
            d='_D_treat',
            propensity_controls=ps_controls,
            n_neighbors=n_neighbors,
            with_replacement=with_replacement,
            caliper=caliper,
            trim_threshold=trim_threshold,
            alpha=alpha,
        )
    raise ValueError(f"Unknown estimator: {estimator}")


def estimate_pre_treatment_effects(
    data_transformed: pd.DataFrame,
    gvar: str,
    ivar: str,
    tvar: str,
    controls: list[str] | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
    control_strategy: str = 'not_yet_treated',
    never_treated_values: list | None = None,
    min_obs: int = 3,
    min_treated: int = 1,
    min_control: int = 1,
    alpha: float = 0.05,
    estimator: str = 'ra',
    transform_type: str = 'demean',
    propensity_controls: list[str] | None = None,
    trim_threshold: float = 0.01,
    se_method: str = 'analytical',
    n_neighbors: int = 1,
    with_replacement: bool = True,
    caliper: float | None = None,
) -> list[PreTreatmentEffect]:
    """
    Estimate pre-treatment effects for all valid cohort-period pairs.

    Iterates over all treatment cohorts and their pre-treatment periods,
    estimating the ATT for each (cohort, period) combination. Under the
    parallel trends assumption, these estimates should be approximately zero.

    Parameters
    ----------
    data_transformed : pd.DataFrame
        Panel data containing pre-treatment transformed outcome columns
        generated by ``transform_staggered_demean_pre`` or
        ``transform_staggered_detrend_pre``. Must include columns for
        gvar, ivar, tvar, and transformed outcomes named
        'ydot_pre_g{g}_t{t}' or 'ycheck_pre_g{g}_t{t}'.
    gvar : str
        Name of the cohort variable column indicating first treatment period.
    ivar : str
        Name of the unit identifier column.
    tvar : str
        Name of the time variable column.
    controls : list of str, optional
        Names of time-invariant control variable columns.
    vce : str, optional
        Variance estimation type: None (homoskedastic), 'hc3'
        (heteroskedasticity-robust), or 'cluster' (cluster-robust).
    cluster_var : str, optional
        Name of the cluster variable column. Required when vce='cluster'.
    control_strategy : str, default='not_yet_treated'
        Control group selection strategy: 'never_treated' uses only
        never-treated units; 'not_yet_treated' includes units first
        treated after the current period; 'all_others' uses all units
        not in the treatment cohort (including already-treated units);
        'auto' is passed through as ``ControlGroupStrategy.AUTO``.
    never_treated_values : list, optional
        Values in gvar indicating never-treated units. Defaults to
        [0, np.inf] and NaN values.
    min_obs : int, default=3
        Minimum total sample size required for estimation, checked both
        before and after dropping rows with a missing transformed outcome.
    min_treated : int, default=1
        Minimum number of treated units required (also enforced on rows
        with a non-missing transformed outcome).
    min_control : int, default=1
        Minimum number of control units required (also enforced on rows
        with a non-missing transformed outcome).
    alpha : float, default=0.05
        Significance level for confidence interval construction.
    estimator : str, default='ra'
        Estimation method: 'ra' (regression adjustment), 'ipwra'
        (inverse probability weighted regression adjustment), or
        'psm' (propensity score matching).
    transform_type : str, default='demean'
        Transformation type applied to the data: 'demean' or 'detrend'.
        Determines the column prefix for transformed outcomes: 'demean'
        selects 'ydot_pre'; any other value selects 'ycheck_pre'.
    propensity_controls : list of str, optional
        Control variables for the propensity score model. If None,
        uses the same variables as ``controls``.
    trim_threshold : float, default=0.01
        Propensity score trimming threshold for IPWRA and PSM.
    se_method : str, default='analytical'
        Standard error method for IPWRA: 'analytical' or 'bootstrap'.
    n_neighbors : int, default=1
        Number of nearest neighbors for PSM matching.
    with_replacement : bool, default=True
        Whether PSM matching allows replacement.
    caliper : float, optional
        Maximum propensity score distance for PSM matching.

    Returns
    -------
    list of PreTreatmentEffect
        Estimation results for all valid (cohort, period) pairs,
        sorted by cohort and then by event_time (descending, so
        anchor point comes first). For non-anchor pairs, ``n_treated``
        and ``n_control`` count only rows with a non-missing transformed
        outcome, i.e. the sample actually passed to the estimator.

    Raises
    ------
    ValueError
        If required columns are missing, no valid treatment cohorts
        exist, or parameter values are invalid (unknown
        ``control_strategy`` or ``estimator``, or vce='cluster' without
        ``cluster_var``).

    Warns
    -----
    UserWarning
        When a cohort has no pre-treatment periods, or when some
        (cohort, period) pairs are skipped for insufficient data or
        estimator errors.

    See Also
    --------
    transform_staggered_demean_pre : Pre-treatment demeaning transformation.
    transform_staggered_detrend_pre : Pre-treatment detrending transformation.
    test_parallel_trends : Statistical test for parallel trends assumption.

    Notes
    -----
    The anchor point (t = g-1, event time e = -1) is handled specially:

    - ATT is set to exactly 0.0 (by construction of the transformation)
    - SE is set to 0.0
    - is_anchor flag is set to True

    For pre-treatment periods, the control group (default strategy) is:
    control = {units with gvar > t} ∪ {never-treated units}.
    """
    # =========================================================================
    # Input Validation
    # =========================================================================
    missing = [c for c in (gvar, ivar, tvar) if c not in data_transformed.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    if vce == 'cluster' and cluster_var is None:
        raise ValueError("cluster_var required when vce='cluster'")

    # Validate up front: inside the per-pair loop an unknown estimator would
    # be swallowed as a "skipped pair", silently yielding anchor-only output.
    if estimator not in _SUPPORTED_PRE_ESTIMATORS:
        raise ValueError(
            f"Unknown estimator: {estimator}. "
            f"Must be one of: {list(_SUPPORTED_PRE_ESTIMATORS)}"
        )

    strategy_map = {
        'never_treated': ControlGroupStrategy.NEVER_TREATED,
        'not_yet_treated': ControlGroupStrategy.NOT_YET_TREATED,
        'all_others': ControlGroupStrategy.ALL_OTHERS,
        'auto': ControlGroupStrategy.AUTO,
    }
    if control_strategy not in strategy_map:
        raise ValueError(
            f"Invalid control_strategy: {control_strategy}. "
            f"Must be one of: {list(strategy_map.keys())}"
        )
    strategy = strategy_map[control_strategy]

    # =========================================================================
    # Cohort and Period Extraction
    # =========================================================================
    nt_values = [0, np.inf] if never_treated_values is None else never_treated_values

    cohorts = get_cohorts(data_transformed, gvar, ivar, nt_values)
    if len(cohorts) == 0:
        raise ValueError("No valid treatment cohorts found in data.")

    T_min = int(data_transformed[tvar].min())

    # Column prefix reflects transformation type.
    prefix = 'ydot_pre' if transform_type == 'demean' else 'ycheck_pre'

    # =========================================================================
    # Pre-treatment Effect Estimation
    # =========================================================================
    results: list[PreTreatmentEffect] = []
    skipped_pairs: list[tuple] = []
    n_total_pairs = 0

    for g in cohorts:
        pre_periods = get_pre_treatment_periods_for_cohort(g, T_min)
        n_total_pairs += len(pre_periods)

        if len(pre_periods) == 0:
            warnings.warn(
                f"Cohort {g} has no pre-treatment periods (T_min={T_min}).",
                UserWarning
            )
            continue

        for t in pre_periods:
            # Anchor point (t = g-1, e = -1): ATT = 0 by construction.
            if t == g - 1:
                results.append(_anchor_effect(
                    data_transformed, gvar, ivar, tvar, g, t, strategy, nt_values
                ))
                continue

            ydot_col = f'{prefix}_g{g}_t{t}'
            if ydot_col not in data_transformed.columns:
                skipped_pairs.append((g, t, 'missing_transform_column'))
                continue

            # Cross-section at period t (not mutated; a copy is taken below).
            period_data = data_transformed[data_transformed[tvar] == t]
            if len(period_data) == 0:
                skipped_pairs.append((g, t, 'no_data_in_period'))
                continue

            try:
                control_mask = _pre_treatment_control_rows(
                    period_data, gvar, ivar, g, t, strategy, nt_values
                )
            except Exception as e:
                skipped_pairs.append((g, t, f'control_mask_error: {e}'))
                continue

            # "Treated" units are cohort g, even though at time t they have
            # not been treated yet.
            treat_mask = period_data[gvar] == g
            sample_mask = treat_mask | control_mask

            n_treat = int(treat_mask.sum())
            n_control = int(control_mask.sum())
            n_total = int(sample_mask.sum())

            if n_total < min_obs:
                skipped_pairs.append((g, t, f'insufficient_total: {n_total}<{min_obs}'))
                continue
            if n_treat < min_treated:
                skipped_pairs.append((g, t, f'insufficient_treated: {n_treat}<{min_treated}'))
                continue
            if n_control < min_control:
                skipped_pairs.append((g, t, f'insufficient_control: {n_control}<{min_control}'))
                continue

            # Restrict to rows with a usable transformed outcome so that the
            # estimation sample, the reported counts, and the minimum-size
            # checks all refer to the same observations.
            outcome_ok = period_data[ydot_col].notna()
            est_mask = sample_mask & outcome_ok
            n_valid = int(est_mask.sum())
            if n_valid < min_obs:
                skipped_pairs.append((g, t, f'insufficient_valid_outcomes: {n_valid}<{min_obs}'))
                continue

            n_treat = int((treat_mask & outcome_ok).sum())
            n_control = int((control_mask & outcome_ok).sum())
            if n_treat < min_treated:
                skipped_pairs.append((g, t, f'insufficient_valid_treated: {n_treat}<{min_treated}'))
                continue
            if n_control < min_control:
                skipped_pairs.append((g, t, f'insufficient_valid_control: {n_control}<{min_control}'))
                continue

            sample_data = period_data.loc[est_mask].copy()
            sample_data['_D_treat'] = (sample_data[gvar] == g).astype(int)

            try:
                est_result = _run_pre_treatment_estimator(
                    estimator,
                    sample_data,
                    ydot_col,
                    controls=controls,
                    vce=vce,
                    cluster_var=cluster_var,
                    alpha=alpha,
                    propensity_controls=propensity_controls,
                    trim_threshold=trim_threshold,
                    se_method=se_method,
                    n_neighbors=n_neighbors,
                    with_replacement=with_replacement,
                    caliper=caliper,
                )
            except Exception as e:
                skipped_pairs.append((g, t, f'{estimator}_error: {e}'))
                continue

            results.append(PreTreatmentEffect(
                cohort=int(g),
                period=int(t),
                event_time=int(t - g),  # Always negative for pre-treatment
                att=est_result['att'],
                se=est_result['se'],
                ci_lower=est_result['ci_lower'],
                ci_upper=est_result['ci_upper'],
                t_stat=est_result['t_stat'],
                pvalue=est_result['pvalue'],
                n_treated=n_treat,
                n_control=n_control,
                is_anchor=False,
                rolling_window_size=int(g - t - 1),  # Future periods {t+1, ..., g-1}
                # Tolerates a missing, None, or non-finite df from the estimator.
                df_inference=_safe_int(est_result.get('df_inference', 0)),
            ))

    # =========================================================================
    # Reporting
    # =========================================================================
    if skipped_pairs:
        warnings.warn(
            f"Skipped {len(skipped_pairs)}/{n_total_pairs} pre-treatment (cohort, period) pairs "
            f"due to insufficient data or errors.",
            UserWarning
        )

    # Sort by cohort, then by event_time descending (anchor first).
    results.sort(key=lambda x: (x.cohort, -x.event_time))
    return results
[docs]
def pre_treatment_effects_to_dataframe(
    results: list[PreTreatmentEffect],
) -> pd.DataFrame:
    """
    Convert a list of PreTreatmentEffect objects to a pandas DataFrame.

    Parameters
    ----------
    results : list of PreTreatmentEffect
        Estimation results from ``estimate_pre_treatment_effects``.

    Returns
    -------
    pd.DataFrame
        One row per result, in input order, with columns: cohort, period,
        event_time, att, se, ci_lower, ci_upper, t_stat, pvalue, n_treated,
        n_control, is_anchor, rolling_window_size, df_inference.
        Returns an empty DataFrame with the same columns if the input
        list is empty.

    See Also
    --------
    estimate_pre_treatment_effects : Estimate pre-treatment ATT.
    PreTreatmentEffect : Container class for individual effect estimates.
    """
    # Single source of truth for the output schema, so the empty and
    # non-empty paths cannot drift apart when a field is added.
    columns = [
        'cohort', 'period', 'event_time', 'att', 'se',
        'ci_lower', 'ci_upper', 't_stat', 'pvalue',
        'n_treated', 'n_control', 'is_anchor', 'rolling_window_size',
        'df_inference',
    ]

    if len(results) == 0:
        return pd.DataFrame(columns=columns)

    return pd.DataFrame([
        {col: getattr(r, col) for col in columns}
        for r in results
    ])