Source code for lwdid.staggered.parallel_trends

"""
Statistical tests for the parallel trends assumption in staggered DiD.

This module provides hypothesis tests to assess the parallel trends assumption
using pre-treatment ATT estimates. Under the null hypothesis of parallel
trends and no anticipation, pre-treatment ATT estimates should be
statistically indistinguishable from zero.

Two complementary testing approaches are implemented:

1. **Individual t-tests**: Test H0: ATT_e = 0 for each pre-treatment event
   time e < -1, excluding the anchor point which is zero by construction.

2. **Joint F-test (or Wald test)**: Test H0: all pre-treatment ATT = 0
   simultaneously. This is the primary diagnostic for parallel trends.

The anchor point at event time e = -1 is excluded from testing because
it is set to zero by construction of the rolling transformation.

Notes
-----
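Both joint tests implemented here are built from the sum of squared
individual t-statistics and therefore assume independence across
pre-treatment periods. For settings with substantial serial correlation,
a test based on the full variance-covariance matrix of the pre-treatment
estimates would provide more accurate inference. The Wald (chi-squared)
variant may be preferred when sample sizes are large. Both tests are
asymptotically valid under standard regularity conditions.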

Rejection of the null hypothesis suggests potential violation of the
parallel trends assumption. However, failure to reject does not prove
parallel trends holds, as the test may simply lack statistical power.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import scipy.stats as stats

if TYPE_CHECKING:
    from .estimation_pre import PreTreatmentEffect


@dataclass
class ParallelTrendsTestResult:
    """
    Container for parallel trends test results.

    Stores both the individual t-tests for each pre-treatment period and the
    joint F-test for the null hypothesis that all pre-treatment ATT estimates
    equal zero.

    Attributes
    ----------
    individual_tests : pd.DataFrame
        One row per tested pre-treatment period (anchor point excluded).
        When produced by ``_compute_individual_t_tests`` the columns are:
        event_time, cohort, period, att, se, t_stat, pvalue, significant,
        n_treated, n_control.
    joint_f_stat : float
        F-statistic for joint test H0: all pre-treatment ATT = 0.
    joint_pvalue : float
        P-value for joint F-test.
    joint_df1 : int
        Numerator degrees of freedom (number of pre-treatment periods
        included in the test).
    joint_df2 : int
        Denominator degrees of freedom.
    n_pre_periods : int
        Number of pre-treatment periods included in test.
    excluded_periods : list of int
        Event times excluded due to missing SE, anchor point, or other
        issues. Defaults to a fresh empty list per instance.
    reject_null : bool
        True if joint test rejects H0 at the specified alpha level.
    alpha : float
        Significance level used for the test.

    Notes
    -----
    Rejection of the null hypothesis suggests potential violation of the
    parallel trends assumption. However, failure to reject does not prove
    parallel trends holds - it may simply reflect low power.
    """

    individual_tests: pd.DataFrame
    joint_f_stat: float
    joint_pvalue: float
    joint_df1: int
    joint_df2: int
    n_pre_periods: int
    # default_factory gives each instance its own list (no shared mutable default).
    excluded_periods: list[int] = field(default_factory=list)
    reject_null: bool = False
    alpha: float = 0.05
def _compute_individual_t_tests(
    pre_treatment_effects: list[PreTreatmentEffect],
    alpha: float = 0.05,
) -> tuple[pd.DataFrame, list[int]]:
    """
    Compute individual t-tests for each pre-treatment ATT estimate.

    Tests H0: ATT_e = 0 for each pre-treatment event time, using a
    two-sided t-test with ``n_treated + n_control - 2`` degrees of freedom.

    Parameters
    ----------
    pre_treatment_effects : list of PreTreatmentEffect
        Pre-treatment effect estimates from estimate_pre_treatment_effects.
    alpha : float, default=0.05
        Significance level for determining the 'significant' flag.

    Returns
    -------
    individual_tests : pd.DataFrame
        DataFrame with columns: event_time, cohort, period, att, se,
        t_stat, pvalue, significant, n_treated, n_control. Each row
        represents test results for one pre-treatment period. Empty (but
        with these columns) when no period could be tested.
    excluded_periods : list of int
        Event times excluded from testing due to anchor point status,
        missing or non-positive standard errors, or missing ATT values.
    """
    test_results = []
    excluded_periods = []

    for effect in pre_treatment_effects:
        # Anchor points are zero by construction and provide no test
        # information.
        if effect.is_anchor:
            excluded_periods.append(effect.event_time)
            continue

        # Invalid SE prevents reliable t-test computation.
        if np.isnan(effect.se) or effect.se <= 0:
            excluded_periods.append(effect.event_time)
            continue

        # Missing ATT indicates estimation failure for this period.
        if np.isnan(effect.att):
            excluded_periods.append(effect.event_time)
            continue

        t_stat = effect.att / effect.se

        # Approximate df using pooled sample size; default to 1000 for
        # asymptotic approximation when sample is too small.
        df = effect.n_treated + effect.n_control - 2
        if df <= 0:
            df = 1000

        # Use the survival function rather than 1 - cdf: the subtraction
        # cancels catastrophically for large |t| and would report the
        # p-value as exactly 0.
        pvalue = 2 * stats.t.sf(abs(t_stat), df)

        test_results.append({
            'event_time': effect.event_time,
            'cohort': effect.cohort,
            'period': effect.period,
            'att': effect.att,
            'se': effect.se,
            't_stat': t_stat,
            'pvalue': pvalue,
            'significant': pvalue < alpha,
            'n_treated': effect.n_treated,
            'n_control': effect.n_control,
        })

    if not test_results:
        # Keep the column schema stable so downstream code can index the
        # frame even when every period was excluded.
        return pd.DataFrame(columns=[
            'event_time', 'cohort', 'period', 'att', 'se',
            't_stat', 'pvalue', 'significant', 'n_treated', 'n_control'
        ]), excluded_periods

    return pd.DataFrame(test_results), excluded_periods


def _compute_joint_f_test(
    individual_tests: pd.DataFrame,
    alpha: float = 0.05,
) -> tuple[float, float, int, int]:
    """
    Compute joint F-test for H0: all pre-treatment ATT = 0.

    Uses a Wald-type F-statistic computed as the average of squared
    t-statistics across pre-treatment periods.

    Parameters
    ----------
    individual_tests : pd.DataFrame
        DataFrame with columns: att, se, t_stat, n_treated, n_control.
        Only t_stat, n_treated and n_control are read here.
    alpha : float, default=0.05
        Significance level for hypothesis testing. Currently unused but
        retained for API consistency with related functions.

    Returns
    -------
    f_stat : float
        F-statistic computed as (1/q) * sum(t_j^2). NaN when
        ``individual_tests`` is empty.
    pvalue : float
        Upper-tail p-value from the F(q, df2) distribution. NaN when
        ``individual_tests`` is empty.
    df1 : int
        Numerator degrees of freedom (number of restrictions q); 0 when
        ``individual_tests`` is empty.
    df2 : int
        Denominator degrees of freedom (based on average sample size,
        floored at 1); 0 when ``individual_tests`` is empty.

    Notes
    -----
    This is a simplified F-test that assumes independence across
    pre-treatment periods. For settings with substantial correlation
    between period-specific estimates, a test based on the full
    variance-covariance matrix would provide more accurate inference.
    """
    if len(individual_tests) == 0:
        return np.nan, np.nan, 0, 0

    q = len(individual_tests)
    t_stats = individual_tests['t_stat'].values
    sum_t2 = np.sum(t_stats ** 2)

    # F-statistic as average of squared t-statistics follows F(q, df2).
    f_stat = sum_t2 / q
    df1 = q

    # Approximate denominator df using average sample size across periods.
    avg_n = (
        individual_tests['n_treated'].mean()
        + individual_tests['n_control'].mean()
    )
    df2 = max(int(avg_n - 2), 1)

    # Survival function keeps precision in the far upper tail, where
    # 1 - cdf would round to exactly 0.
    pvalue = stats.f.sf(f_stat, df1, df2)

    return f_stat, pvalue, df1, df2


def _compute_joint_wald_test(
    individual_tests: pd.DataFrame,
    alpha: float = 0.05,
) -> tuple[float, float, int]:
    """
    Compute joint Wald test for H0: all pre-treatment ATT = 0.

    The Wald statistic is the sum of squared t-statistics, which follows a
    chi-squared distribution with q degrees of freedom under the null.

    Parameters
    ----------
    individual_tests : pd.DataFrame
        DataFrame with columns: att, se, t_stat. Only t_stat is read here.
    alpha : float, default=0.05
        Significance level for hypothesis testing. Currently unused but
        retained for API consistency with related functions.

    Returns
    -------
    wald_stat : float
        Wald statistic computed as sum(t_j^2). NaN when
        ``individual_tests`` is empty.
    pvalue : float
        Upper-tail p-value from the chi-squared(q) distribution. NaN when
        ``individual_tests`` is empty.
    df : int
        Degrees of freedom (number of pre-treatment periods tested); 0
        when ``individual_tests`` is empty.

    Notes
    -----
    The Wald test is asymptotically equivalent to the F-test but uses the
    chi-squared distribution rather than the F-distribution. This test may
    be preferred when sample sizes are large or when comparing results
    with other software implementations.
    """
    if len(individual_tests) == 0:
        return np.nan, np.nan, 0

    q = len(individual_tests)
    t_stats = individual_tests['t_stat'].values

    # Sum of squared t-statistics follows chi-squared(q) under the null.
    wald_stat = np.sum(t_stats ** 2)

    # Survival function keeps precision in the far upper tail, where
    # 1 - cdf would round to exactly 0.
    pvalue = stats.chi2.sf(wald_stat, q)

    return wald_stat, pvalue, q


# Alternative function name following common API naming conventions.
parallel_trends_test = run_parallel_trends_test