"""
Core estimation interface for difference-in-differences with panel data.
This module provides the main entry point for LWDID estimation, supporting
three methodological scenarios based on the rolling transformation approach:
1. **Small-sample common timing**: Exact t-based inference under classical
linear model (CLM) assumptions when the number of cross-sectional units
is small.
2. **Large-sample common timing**: Asymptotic inference with
heteroskedasticity-robust standard errors for moderate to large samples.
3. **Staggered adoption**: Cohort-time specific effect estimation with
flexible control group strategies (never-treated or not-yet-treated)
for settings where treatment timing varies.
The method applies unit-specific time-series transformations that remove
pre-treatment patterns, converting the panel DiD problem into a cross-sectional
treatment effects problem. Under no anticipation and parallel trends
assumptions, standard estimators can be applied to the transformed outcomes.
Notes
-----
Four transformation methods are available:
- **Demeaning** ('demean'): Subtracts the unit-specific pre-treatment mean.
Requires at least one pre-treatment period.
- **Detrending** ('detrend'): Removes unit-specific linear time trends.
Requires at least two pre-treatment periods.
- **Seasonal demeaning** ('demeanq'): Removes unit-specific mean with seasonal
fixed effects. Requires sufficient pre-treatment observations for seasonal
pattern estimation.
- **Seasonal detrending** ('detrendq'): Removes unit-specific linear trends
with seasonal effects. Requires at least two pre-treatment periods plus
adequate seasonal coverage.
Four estimation methods are supported:
- **RA**: Regression adjustment via OLS on transformed outcomes.
- **IPW**: Inverse probability weighting using propensity scores.
- **IPWRA**: Doubly robust combining IPW with regression adjustment.
- **PSM**: Propensity score matching with nearest-neighbor matching.
For staggered adoption, transformations are applied separately for each
treatment cohort using cohort-specific pre-treatment periods.
"""
from __future__ import annotations
import logging
import math
import random
import warnings
import numpy as np
import pandas as pd
import scipy.stats
from . import estimation, transformations, validation
from .exceptions import (
InsufficientDataError,
InvalidParameterError,
NoNeverTreatedError,
RandomizationError,
)
from .randomization import randomization_inference
from .results import LWDIDResults
from .staggered.estimators import (
IPWResult,
IPWRAResult,
PSMResult,
estimate_ipw,
estimate_ipwra,
estimate_psm,
)
from .validation import is_never_treated, validate_staggered_data
logger = logging.getLogger('lwdid')
def _generate_ri_seed() -> int:
    """
    Draw a seed for randomization inference from the global ``random`` state.

    Two successive ``random.random()`` draws are combined: the first is scaled
    by 1e6, the second by 1000, and their sum is rounded up to an integer.

    Returns
    -------
    int
        Seed no larger than 1001000. The value is at least 1 unless both
        draws happen to be exactly 0.0.
    """
    # Order matters for reproducibility under a fixed random.seed():
    # the first draw supplies the coarse part, the second the fine offset.
    coarse_draw = random.random()
    fine_draw = random.random()
    return math.ceil(coarse_draw * 1e6 + 1000 * fine_draw)
def _validate_psm_params(
    n_neighbors: int,
    caliper: float | None,
    with_replacement: bool,
    match_order: str = 'data',
) -> None:
    """
    Validate propensity score matching parameters.

    Checks type and value constraints for the PSM-specific parameters and
    raises on the first violation found. Returns nothing on success.

    Parameters
    ----------
    n_neighbors : int
        Number of nearest neighbors for matching. Must be an integer
        (Python ``int`` or NumPy integer, but not ``bool``) and >= 1.
    caliper : float or None
        Maximum propensity score distance for valid matches. When given it
        must be numeric (not ``bool``), finite, and strictly positive.
    with_replacement : bool
        Whether control units can be matched to multiple treated units.
        Must be a Python ``bool``.
    match_order : {'data', 'random', 'largest', 'smallest'}, default='data'
        Order in which treated units are processed for matching without
        replacement.

    Raises
    ------
    TypeError
        If a parameter has the wrong type.
    ValueError
        If a parameter value is outside its valid range.
    """
    # bool is a subclass of int, so isinstance(True, int) is True. Exclude it
    # explicitly: a boolean neighbor count is almost certainly a caller error.
    if isinstance(n_neighbors, bool) or not isinstance(n_neighbors, (int, np.integer)):
        raise TypeError(
            f"n_neighbors must be an integer, got {type(n_neighbors).__name__}"
        )
    if n_neighbors < 1:
        raise ValueError(
            f"n_neighbors must be >= 1, got {n_neighbors}"
        )

    if caliper is not None:
        # Same bool exclusion as above; True would otherwise act as caliper=1.
        if isinstance(caliper, bool) or not isinstance(caliper, (int, float, np.number)):
            raise TypeError(
                f"caliper must be numeric or None, got {type(caliper).__name__}"
            )
        # Explicit finiteness check required because np.nan <= 0 returns False.
        if not np.isfinite(caliper) or caliper <= 0:
            raise ValueError(
                f"caliper must be a finite positive number, got {caliper}"
            )

    if not isinstance(with_replacement, bool):
        raise TypeError(
            f"with_replacement must be bool, got {type(with_replacement).__name__}"
        )

    # A tuple (not a set) keeps the error message text stable across runs;
    # set iteration order for strings depends on hash randomization.
    valid_match_orders = ('data', 'random', 'largest', 'smallest')
    if not isinstance(match_order, str):
        raise TypeError(
            f"match_order must be str, got {type(match_order).__name__}"
        )
    if match_order not in valid_match_orders:
        listed = ', '.join(repr(order) for order in valid_match_orders)
        raise ValueError(
            f"match_order must be one of {{{listed}}}, got '{match_order}'"
        )
def _convert_ipw_result_to_dict(
    ipw_result: IPWResult,
    alpha: float,
    vce: str | None,
    cluster_var: str | None,
    controls: list[str] | None,
    ps_controls: list[str] | None = None,
) -> dict:
    """
    Repackage an IPW estimate as the flat results dictionary.

    Point estimate, inference statistics, sample counts and propensity-score
    outputs are copied from ``ipw_result``; regression-specific entries
    (``params``, ``bse``, ``vcov``, ``resid``) are filled with None.

    Parameters
    ----------
    ipw_result : IPWResult
        Result object from estimate_ipw().
    alpha : float
        Significance level for confidence intervals. Not read by this
        conversion; the interval bounds are taken from ``ipw_result``.
    vce : str or None
        Variance estimator label stored as metadata; 'ipw' is stored when None.
    cluster_var : str or None
        Cluster variable name stored as metadata.
    controls : list of str or None
        Control variable names recorded in the output.
    ps_controls : list of str or None
        Propensity score control names. Not read by this conversion.

    Returns
    -------
    dict
        Results dictionary compatible with LWDIDResults constructor.
    """
    treated_count = ipw_result.n_treated
    control_count = ipw_result.n_control
    sample_size = treated_count + control_count
    # Degrees of freedom are n - 2 (intercept + treatment indicator).
    residual_df = sample_size - 2
    has_controls = controls is not None and len(controls) > 0

    # Inference block: copied straight from the estimator output.
    converted = {
        'att': ipw_result.att,
        'se_att': ipw_result.se,
        't_stat': ipw_result.t_stat,
        'pvalue': ipw_result.pvalue,
        'ci_lower': ipw_result.ci_lower,
        'ci_upper': ipw_result.ci_upper,
    }
    # Sample size and estimation metadata.
    converted.update({
        'nobs': sample_size,
        'df_resid': residual_df,
        'df_inference': residual_df,
        'vce_type': 'ipw' if vce is None else vce,
        'cluster_var': cluster_var,
        'n_clusters': None,
        'controls_used': has_controls,
        'controls': controls if controls else [],
        'controls_spec': None,
        'n_treated_sample': treated_count,
        'n_control_sample': control_count,
    })
    # Regression outputs have no counterpart here, so they are None.
    converted.update(dict.fromkeys(('params', 'bse', 'vcov', 'resid')))
    converted.update({
        'diagnostics': getattr(ipw_result, 'diagnostics', None),
        'weights_cv': ipw_result.weights_cv,
        'propensity_scores': ipw_result.propensity_scores,
        'estimator': 'ipw',
    })
    return converted
def _convert_ipwra_result_to_dict(
    ipwra_result: IPWRAResult,
    alpha: float,
    vce: str | None,
    cluster_var: str | None,
    controls: list[str] | None,
) -> dict:
    """
    Repackage an IPWRA estimate as the flat results dictionary.

    Besides copying the estimator output, this derives the coefficient of
    variation of the weights (when a ``weights`` attribute is present) and a
    degrees-of-freedom value that accounts for the outcome-model controls.

    Parameters
    ----------
    ipwra_result : IPWRAResult
        Result object from estimate_ipwra().
    alpha : float
        Significance level for confidence intervals. Not read by this
        conversion; the interval bounds are taken from ``ipwra_result``.
    vce : str or None
        Variance estimator label stored as metadata; 'ipwra' is stored when
        None.
    cluster_var : str or None
        Cluster variable name stored as metadata.
    controls : list of str or None
        Control variable names for the outcome model.

    Returns
    -------
    dict
        Results dictionary compatible with LWDIDResults constructor.
    """
    # Weight dispersion: sample CV (ddof=1) needs two or more weights; a
    # single weight has no spread; missing/empty weights give NaN.
    weights = getattr(ipwra_result, 'weights', None)
    weight_count = 0 if weights is None else len(weights)
    if weight_count > 1:
        mean_weight = np.mean(weights)
        spread = np.std(weights, ddof=1)
        # A non-positive mean makes the ratio meaningless.
        weights_cv = spread / mean_weight if mean_weight > 0 else np.nan
    elif weight_count == 1:
        weights_cv = 0.0
    else:
        weights_cv = np.nan

    treated_count = ipwra_result.n_treated
    control_count = ipwra_result.n_control
    sample_size = treated_count + control_count
    # df = n - k with k = intercept + treatment + one per outcome control.
    control_count_in_model = len(controls) if controls else 0
    residual_df = sample_size - (2 + control_count_in_model)
    has_controls = controls is not None and len(controls) > 0

    converted = {
        'att': ipwra_result.att,
        'se_att': ipwra_result.se,
        't_stat': ipwra_result.t_stat,
        'pvalue': ipwra_result.pvalue,
        'ci_lower': ipwra_result.ci_lower,
        'ci_upper': ipwra_result.ci_upper,
    }
    converted.update({
        'nobs': sample_size,
        'df_resid': residual_df,
        'df_inference': residual_df,
        'vce_type': 'ipwra' if vce is None else vce,
        'cluster_var': cluster_var,
        'n_clusters': None,
        'controls_used': has_controls,
        'controls': controls if controls else [],
        'controls_spec': None,
        'n_treated_sample': treated_count,
        'n_control_sample': control_count,
    })
    # Regression outputs have no counterpart here, so they are None.
    converted.update(dict.fromkeys(('params', 'bse', 'vcov', 'resid')))
    converted.update({
        'diagnostics': getattr(ipwra_result, 'diagnostics', None),
        'weights_cv': weights_cv,
        'propensity_scores': ipwra_result.propensity_scores,
        'estimator': 'ipwra',
    })
    return converted
def _convert_psm_result_to_dict(
    psm_result: PSMResult,
    alpha: float,
    vce: str | None,
    cluster_var: str | None,
    controls: list[str] | None,
) -> dict:
    """
    Repackage a PSM estimate as the flat results dictionary.

    In addition to the copied estimator output, the dictionary carries the
    number of matched treated units and the share of treated units matched.

    Parameters
    ----------
    psm_result : PSMResult
        Result object from estimate_psm().
    alpha : float
        Significance level for confidence intervals. Not read by this
        conversion; the interval bounds are taken from ``psm_result``.
    vce : str or None
        Variance estimator label stored as metadata; 'psm' is stored when None.
    cluster_var : str or None
        Cluster variable name stored as metadata.
    controls : list of str or None
        Control variable names for propensity score model.

    Returns
    -------
    dict
        Results dictionary compatible with LWDIDResults constructor.
    """
    treated_count = psm_result.n_treated
    control_count = psm_result.n_control

    # A missing or None n_matched is treated as "every treated unit matched".
    n_matched = getattr(psm_result, 'n_matched', None)
    if n_matched is None:
        n_matched = treated_count

    # Share of treated units matched, confined to [0, 1]; 0.0 when there
    # are no treated units to avoid dividing by zero.
    match_rate = 0.0
    if treated_count > 0:
        capped_above = min(1.0, n_matched / treated_count)
        match_rate = max(0.0, capped_above)

    sample_size = treated_count + control_count
    residual_df = sample_size - 2
    has_controls = controls is not None and len(controls) > 0

    converted = {
        'att': psm_result.att,
        'se_att': psm_result.se,
        't_stat': psm_result.t_stat,
        'pvalue': psm_result.pvalue,
        'ci_lower': psm_result.ci_lower,
        'ci_upper': psm_result.ci_upper,
    }
    converted.update({
        'nobs': sample_size,
        'df_resid': residual_df,
        'df_inference': residual_df,
        'vce_type': 'psm' if vce is None else vce,
        'cluster_var': cluster_var,
        'n_clusters': None,
        'controls_used': has_controls,
        'controls': controls if controls else [],
        'controls_spec': None,
        'n_treated_sample': treated_count,
        'n_control_sample': control_count,
    })
    # Regression outputs have no counterpart here, so they are None.
    converted.update(dict.fromkeys(('params', 'bse', 'vcov', 'resid')))
    converted.update({
        'diagnostics': getattr(psm_result, 'diagnostics', None),
        'propensity_scores': psm_result.propensity_scores,
        # Matching does not produce IPW weights, so there is no CV to report.
        'weights_cv': np.nan,
        'n_matched': n_matched,
        'match_rate': match_rate,
        'estimator': 'psm',
    })
    return converted
# Column layout shared by every row of the period-effects table.
_PERIOD_EFFECT_COLUMNS = (
    'period', 'tindex', 'beta', 'se', 'ci_lower', 'ci_upper', 'tstat', 'pval', 'N'
)

# Estimators that _estimate_period_effects_ipw knows how to dispatch.
_SUPPORTED_PERIOD_ESTIMATORS = ('ipw', 'ipwra', 'psm')


def _nan_period_row(period_label: str, t: int, n_obs: int) -> dict:
    """Build a placeholder row (all statistics NaN) for a period that cannot be estimated."""
    return {
        'period': period_label,
        'tindex': t,
        'beta': np.nan,
        'se': np.nan,
        'ci_lower': np.nan,
        'ci_upper': np.nan,
        'tstat': np.nan,
        'pval': np.nan,
        'N': n_obs,
    }


def _warn_nan_periods(description: str, entries: list[str], max_listed: int) -> None:
    """Emit one consolidated UserWarning listing periods whose results were set to NaN."""
    if not entries:
        return
    warnings.warn(
        f"{len(entries)} period(s) {description}: "
        f"{', '.join(entries[:max_listed])}"
        f"{' ...' if len(entries) > max_listed else ''}. "
        f"Results for these periods set to NaN.",
        UserWarning,
        # One level deeper than a warning issued directly inside
        # _estimate_period_effects_ipw (stacklevel=4), so the warning is
        # attributed to the same caller frame.
        stacklevel=5
    )


def _estimate_period_effects_ipw(
    data: pd.DataFrame,
    ydot: str,
    d: str,
    tindex: str,
    tpost1: int,
    Tmax: int,
    estimator: str,
    controls: list[str] | None,
    ps_controls: list[str] | None,
    trim_threshold: float,
    n_neighbors: int,
    caliper: float | None,
    with_replacement: bool,
    match_order: str,
    period_labels: dict,
    alpha: float = 0.05,
) -> pd.DataFrame:
    """
    Estimate period-specific treatment effects using propensity score methods.

    For every time index from ``tpost1`` through ``Tmax`` the cross-section
    ``data[tindex] == t`` is extracted and passed to the selected estimator
    (``estimate_ipw``, ``estimate_ipwra`` or ``estimate_psm``). Periods that
    have no rows, lack treated or control observations, or whose estimation
    raises ValueError, LinAlgError, RuntimeError or KeyError receive a NaN
    row, and one consolidated warning per problem category is issued after
    the loop.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data containing the transformed outcome variable.
    ydot : str
        Column name of the transformed outcome.
    d : str
        Column name of the binary treatment indicator. The treated count in
        a period is computed as the column sum.
    tindex : str
        Column name of the time period index.
    tpost1 : int
        Index of the first post-treatment period.
    Tmax : int
        Index of the last period in the panel.
    estimator : {'ipw', 'ipwra', 'psm'}
        Estimation method to apply at each period.
    controls : list of str or None
        Control variables for the outcome model (passed to IPWRA only).
    ps_controls : list of str or None
        Control variables for the propensity score model.
    trim_threshold : float
        Propensity score trimming threshold (passed to IPW and IPWRA).
    n_neighbors : int
        Number of nearest neighbors for PSM matching.
    caliper : float or None
        Maximum propensity score distance for PSM matches.
    with_replacement : bool
        Whether PSM allows control unit reuse.
    match_order : str
        Order for processing treated units in PSM without replacement.
    period_labels : dict
        Mapping from time index to human-readable period labels. Indices
        without an entry are labelled with ``str(t)``.
    alpha : float, default=0.05
        Significance level for confidence intervals.

    Returns
    -------
    pd.DataFrame
        Period-specific estimates with columns: 'period', 'tindex', 'beta',
        'se', 'ci_lower', 'ci_upper', 'tstat', 'pval', 'N'. Empty (columns
        only) when ``tpost1 > Tmax``.

    Raises
    ------
    ValueError
        If ``estimator`` is not one of 'ipw', 'ipwra', 'psm' and at least one
        period has both treated and control observations to estimate on.
    """
    # Return empty DataFrame if no post-treatment periods exist.
    if tpost1 > Tmax:
        warnings.warn(
            f"First post-treatment period ({tpost1}) is after the last period ({Tmax}). "
            f"No period-specific effects can be estimated. "
            f"This may indicate data issues or incorrect post indicator.",
            UserWarning,
            stacklevel=4
        )
        return pd.DataFrame(columns=list(_PERIOD_EFFECT_COLUMNS))

    results_list = []
    # Track periods with estimation issues for consolidated warnings.
    empty_periods = []
    insufficient_periods = []
    failed_periods = []

    for t in range(tpost1, Tmax + 1):
        data_t = data[data[tindex] == t].copy()
        period_label = period_labels.get(t, str(t))

        if len(data_t) == 0:
            empty_periods.append(period_label)
            results_list.append(_nan_period_row(period_label, t, 0))
            continue

        n_treated_t = int(data_t[d].sum())
        n_control_t = int(len(data_t) - n_treated_t)
        n_obs_t = n_treated_t + n_control_t
        if n_treated_t == 0 or n_control_t == 0:
            insufficient_periods.append(
                f"{period_label} (N_treated={n_treated_t}, N_control={n_control_t})"
            )
            results_list.append(_nan_period_row(period_label, t, n_obs_t))
            continue

        # Checked outside the try block so it is not swallowed as a "failed
        # period". Without this check an unknown estimator would fall through
        # the dispatch below and surface as an UnboundLocalError on result_t.
        if estimator not in _SUPPORTED_PERIOD_ESTIMATORS:
            raise ValueError(
                f"Unsupported estimator {estimator!r} for period-specific effects; "
                f"expected one of {list(_SUPPORTED_PERIOD_ESTIMATORS)}."
            )

        # Arguments common to all three estimators.
        shared_kwargs = dict(
            data=data_t,
            y=ydot,
            d=d,
            propensity_controls=ps_controls,
            alpha=alpha,
            return_diagnostics=False,
            gvar_col=None,
            ivar_col=None,
            cohort_g=None,
            period_r=None,
        )
        try:
            if estimator == 'ipw':
                result_t = estimate_ipw(
                    trim_threshold=trim_threshold,
                    **shared_kwargs,
                )
            elif estimator == 'ipwra':
                result_t = estimate_ipwra(
                    controls=controls,
                    trim_threshold=trim_threshold,
                    **shared_kwargs,
                )
            else:  # 'psm'
                result_t = estimate_psm(
                    n_neighbors=n_neighbors,
                    caliper=caliper,
                    with_replacement=with_replacement,
                    match_order=match_order,
                    **shared_kwargs,
                )
            results_list.append({
                'period': period_label,
                'tindex': t,
                'beta': result_t.att,
                'se': result_t.se,
                'ci_lower': result_t.ci_lower,
                'ci_upper': result_t.ci_upper,
                'tstat': result_t.t_stat,
                'pval': result_t.pvalue,
                'N': result_t.n_treated + result_t.n_control
            })
        except (ValueError, np.linalg.LinAlgError, RuntimeError, KeyError) as e:
            # Best effort: one bad period must not abort the others.
            failed_periods.append(f"{period_label} ({type(e).__name__})")
            results_list.append(_nan_period_row(period_label, t, n_obs_t))

    _warn_nan_periods("contain no observations", empty_periods, 5)
    _warn_nan_periods(
        "have insufficient treated/control units", insufficient_periods, 3
    )
    if failed_periods:
        _warn_nan_periods(
            f"failed {estimator.upper()} estimation", failed_periods, 3
        )

    return pd.DataFrame(results_list)
# Main LWDID estimation entry point.
def lwdid(
data: pd.DataFrame,
y: str,
d: str | None = None,
ivar: str | None = None,
tvar: str | list[str] | None = None,
post: str | None = None,
rolling: str = 'demean',
*,
gvar: str | None = None,
control_group: str = 'not_yet_treated',
estimator: str = 'ra',
aggregate: str = 'cohort',
balanced_panel: str = 'warn',
ps_controls: list[str] | None = None,
trim_threshold: float = 0.01,
return_diagnostics: bool = False,
n_neighbors: int = 1,
caliper: float | None = None,
with_replacement: bool = True,
match_order: str = 'data',
vce: str | None = None,
controls: list[str] | None = None,
cluster_var: str | None = None,
alpha: float = 0.05,
ri: bool = False,
rireps: int = 1000,
seed: int | None = None,
ri_method: str = 'bootstrap',
graph: bool = False,
gid: str | int | None = None,
graph_options: dict | None = None,
season_var: str | None = None,
Q: int = 4,
auto_detect_frequency: bool = False,
include_pretreatment: bool = False,
pretreatment_test: bool = True,
pretreatment_alpha: float = 0.05,
exclude_pre_periods: int = 0,
**kwargs,
) -> LWDIDResults:
"""
Difference-in-differences estimator with unit-specific transformations.
Implements the rolling transformation approach for DiD estimation,
supporting three methodological scenarios:
1. **Small-sample common timing**: Exact t-based inference under classical
linear model assumptions.
2. **Large-sample common timing**: Asymptotic inference with
heteroskedasticity-robust standard errors.
3. **Staggered adoption**: Cohort-time specific effect estimation with
flexible control group strategies.
The transformation removes unit-specific pre-treatment patterns, converting
panel DiD into a cross-sectional treatment effects problem.
Parameters
----------
data : pd.DataFrame
Panel data in long format with one row per unit-time observation.
Each (unit, time) combination must be unique. Requires at least 3 units.
y : str
Column name of the outcome variable.
d : str, optional
Column name of the unit-level treatment indicator (required for common
timing mode). Must be time-invariant: non-zero for treated units, zero
for control units. Ignored in staggered mode.
ivar : str
Column name of the unit identifier.
tvar : str or list of str
Time variable specification. For annual data, a single column name.
For quarterly data, a list of two column names [year_var, quarter_var]
where quarter_var contains values in {1, 2, 3, 4}.
post : str, optional
Column name of the post-treatment indicator (required for common timing
mode). Internally binarized: non-zero values indicate post-treatment
periods. Must be monotone non-decreasing in time (no treatment reversals).
rolling : {'demean', 'detrend', 'demeanq', 'detrendq'}, default='demean'
Transformation method (case-insensitive):
- 'demean': Remove unit-specific pre-treatment mean.
- 'detrend': Remove unit-specific linear time trend.
- 'demeanq': Demeaning with seasonal fixed effects. Requires ``season_var``
and ``Q`` parameters. Supports quarterly (Q=4), monthly (Q=12), or
weekly (Q=52) data.
- 'detrendq': Detrending with seasonal fixed effects. Requires ``season_var``
and ``Q`` parameters. Supports quarterly (Q=4), monthly (Q=12), or
weekly (Q=52) data.
All four transformation methods are supported for both common timing
and staggered adoption designs.
gvar : str, optional
Column name indicating first treatment period for staggered adoption.
If specified, activates staggered mode and ignores ``d`` and ``post``.
Valid values: positive integers (treatment cohort), 0/inf/NaN (never-treated).
control_group : {'not_yet_treated', 'never_treated', 'all_others'}, default='not_yet_treated'
Control group composition for staggered adoption:
- 'not_yet_treated': Never-treated plus not-yet-treated units.
- 'never_treated': Only never-treated units.
- 'all_others': All units not in the treated cohort (including already-treated units).
This option is mainly intended for replication/diagnostics and may introduce
forbidden comparisons under no-anticipation.
Auto-switched to 'never_treated' for cohort/overall aggregation.
estimator : {'ra', 'ipw', 'ipwra', 'psm'}, default='ra'
Estimation method (case-insensitive):
- 'ra': Regression adjustment via OLS on transformed outcomes.
- 'ipw': Inverse probability weighting. Requires ``controls``.
- 'ipwra': Doubly robust combining IPW with RA. Requires ``controls``.
- 'psm': Propensity score matching. Requires ``controls``.
aggregate : {'none', 'cohort', 'overall'}, default='cohort'
Aggregation level for staggered adoption:
- 'none': Return cohort-time specific effects only.
- 'cohort': Aggregate to cohort-specific effects.
- 'overall': Aggregate to a single weighted overall effect.
balanced_panel : {'warn', 'error', 'ignore'}, default='warn'
How to handle unbalanced panels (units with different observation counts):
- 'warn': Issue a warning with selection mechanism diagnostics (default).
- 'error': Raise UnbalancedPanelError if panel is unbalanced.
- 'ignore': Silently proceed without warnings.
Selection may depend on time-invariant heterogeneity but not on
shocks to the untreated potential outcome. Use
``diagnose_selection_mechanism()`` for detailed diagnostics.
ps_controls : list of str, optional
Control variables for propensity score model. If None, uses ``controls``.
trim_threshold : float, default=0.01
Propensity score trimming threshold. Observations with propensity scores
outside [trim_threshold, 1 - trim_threshold] are excluded.
return_diagnostics : bool, default=False
Whether to include propensity score diagnostics in results.
n_neighbors : int, default=1
Number of nearest neighbors for PSM matching.
caliper : float, optional
Maximum propensity score distance for PSM, in units of PS standard
deviation. Treated units without valid matches are dropped.
with_replacement : bool, default=True
Whether PSM allows control units to be matched multiple times.
match_order : {'data', 'random', 'largest', 'smallest'}, default='data'
Order for processing treated units in without-replacement PSM:
- 'data': Original data order.
- 'random': Randomized order (use ``seed`` for reproducibility).
- 'largest': Prioritize units with extreme propensity scores.
- 'smallest': Prioritize units with propensity scores near 0.5.
vce : {None, 'robust', 'hc0', 'hc1', 'hc2', 'hc3', 'hc4', 'cluster'}, optional
Variance estimator (case-insensitive):
- None: Homoskedastic OLS standard errors.
- 'hc0': White heteroskedasticity-robust.
- 'robust'/'hc1': HC1 with degrees-of-freedom correction N/(N-K).
- 'hc2': Leverage-adjusted using (1 - h_ii)^{-1}.
- 'hc3': Small-sample adjusted using (1 - h_ii)^{-2}.
- 'hc4': Adaptive leverage correction.
- 'cluster': Cluster-robust (requires ``cluster_var``).
controls : list of str, optional
Time-invariant control variables for outcome regression.
cluster_var : str, optional
Column name for clustering (required when vce='cluster').
alpha : float, default=0.05
Significance level for confidence intervals.
ri : bool, default=False
Whether to perform randomization inference for the null H0: ATT=0.
rireps : int, default=1000
Number of randomization inference replications.
seed : int, optional
Random seed for reproducibility in randomization inference.
ri_method : {'bootstrap', 'permutation'}, default='bootstrap'
Resampling method for randomization inference:
- 'bootstrap': With-replacement resampling.
- 'permutation': Fisher's exact permutation test.
graph : bool, default=False
Whether to generate a plot of transformed outcomes over time.
gid : str or int, optional
Specific unit identifier to highlight in the plot.
graph_options : dict, optional
Additional plotting options passed to the visualization function.
season_var : str, optional
Column name of seasonal indicator variable for seasonal transformations
(demeanq, detrendq). Values should be integers from 1 to Q representing
seasonal periods (e.g., quarters 1-4, months 1-12, or weeks 1-52).
This parameter is preferred over the legacy ``quarter`` parameter in
``tvar`` for non-quarterly seasonal data.
Q : int, default=4
Number of seasonal periods per cycle. Used with seasonal transformations
(demeanq, detrendq). Common values:
- 4: Quarterly data (default)
- 12: Monthly data
- 52: Weekly data
Must match the range of values in ``season_var`` (1 to Q).
auto_detect_frequency : bool, default=False
Whether to automatically detect data frequency and set Q accordingly.
When True, the function analyzes the time variable to infer whether
data is quarterly (Q=4), monthly (Q=12), or weekly (Q=52).
- If detection succeeds with high confidence, Q is set automatically.
- If detection fails or has low confidence, a warning is issued and
the explicit Q value is used.
- An explicit Q value always overrides auto-detection when both are
specified (Q != 4 and auto_detect_frequency=True).
This parameter is useful when working with datasets of unknown frequency
or when building generic analysis pipelines.
include_pretreatment : bool, default=False
Whether to compute pre-treatment transformed outcomes and ATT estimates
for parallel trends assessment. Only applicable in staggered mode.
When True:
- Applies rolling transformations to pre-treatment periods using
future pre-treatment periods {t+1, ..., g-1} as reference.
- Estimates pre-treatment ATT for each (cohort, period) pair.
- Stores results in ``att_pre_treatment`` attribute of LWDIDResults.
- Enables extended event study visualization with pre-treatment effects.
Under the parallel trends assumption, pre-treatment ATT estimates
should be statistically indistinguishable from zero.
pretreatment_test : bool, default=True
Whether to perform parallel trends statistical test when
``include_pretreatment=True``. The test includes:
- Individual t-tests for each pre-treatment period ATT.
- Joint F-test for H0: all pre-treatment ATT = 0.
Results are stored in ``parallel_trends_test`` attribute.
pretreatment_alpha : float, default=0.05
Significance level for parallel trends test. Used for determining
``reject_null`` in the test results.
exclude_pre_periods : int, default=0
Number of pre-treatment periods to exclude immediately before treatment.
Used to address potential anticipation effects when the no-anticipation
assumption may be violated.
When ``exclude_pre_periods > 0``:
- The specified number of periods immediately before treatment are
excluded from the pre-treatment sample used for transformation.
- For common timing: excludes the last k pre-treatment periods.
- For staggered adoption: excludes k periods before each cohort's
treatment date.
This implements a robustness check for testing sensitivity to
anticipation effects.
Example: If treatment occurs at t=6 and ``exclude_pre_periods=2``,
periods t=4 and t=5 are excluded from the pre-treatment sample.
Returns
-------
LWDIDResults
Results object with the following key attributes:
- att : Average treatment effect on the treated.
- se_att : Standard error of ATT.
- t_stat : t-statistic for H0: ATT=0.
- pvalue : Two-sided p-value.
- ci_lower, ci_upper : Confidence interval bounds.
- df_inference : Degrees of freedom for inference.
- nobs : Number of observations in estimation sample.
- n_treated, n_control : Unit counts by treatment status.
- att_by_period : Period-specific ATT estimates (DataFrame).
- ri_pvalue : Randomization inference p-value (if ri=True).
- att_pre_treatment : Pre-treatment ATT estimates (if include_pretreatment=True).
- parallel_trends_test : Parallel trends test results (if include_pretreatment=True).
- include_pretreatment : Whether pre-treatment dynamics were computed.
Key methods: summary(), plot(), to_excel(), to_csv(), to_latex(),
get_diagnostics(), plot_event_study().
Raises
------
MissingRequiredColumnError
Required columns not found in data.
InvalidRollingMethodError
Invalid rolling method specified.
InsufficientDataError
Insufficient sample size or pre-/post-treatment observations.
NoTreatedUnitsError
No treated units in data.
NoControlUnitsError
No control units in data.
InsufficientPrePeriodsError
Insufficient pre-treatment periods for the chosen transformation.
NoNeverTreatedError
Cohort/overall aggregation requested but no never-treated units exist.
Notes
-----
Mode selection:
- **Common timing** (gvar=None): Requires ``d``, ``post``, ``rolling``.
- **Staggered adoption** (gvar specified): Requires ``gvar``, ``rolling``.
Confidence intervals use t-distribution critical values with degrees of
freedom N-k (homoskedastic) or G-1 (cluster-robust).
See Also
--------
LWDIDResults : Detailed documentation of the results container.
"""
from .exceptions import UnbalancedPanelError
# Validate unknown kwargs to catch parameter typos.
# Reserved for backward compatibility: riseed is an alias for seed.
_KNOWN_KWARGS = {'riseed'}
unknown_kwargs = set(kwargs.keys()) - _KNOWN_KWARGS
if unknown_kwargs:
warnings.warn(
f"Unknown keyword argument(s) ignored: {sorted(unknown_kwargs)}. "
f"Valid extra arguments: {sorted(_KNOWN_KWARGS)}.",
UserWarning,
stacklevel=2
)
# Validate balanced_panel parameter
if balanced_panel not in ('warn', 'error', 'ignore'):
raise ValueError(
f"balanced_panel must be 'warn', 'error', or 'ignore', got '{balanced_panel}'"
)
if vce is not None:
if not isinstance(vce, str):
raise TypeError(
f"Parameter 'vce' must be a string or None, got {type(vce).__name__}.\n"
f"Valid values: None, 'robust', 'hc0', 'hc1', 'hc2', 'hc3', 'hc4', 'cluster'"
)
vce = vce.lower()
_validate_psm_params(n_neighbors, caliper, with_replacement, match_order)
# Validate Q parameter for seasonal transformations.
if not isinstance(Q, (int, np.integer)):
raise TypeError(
f"Parameter 'Q' must be an integer, got {type(Q).__name__}.\n"
f"Common values: 4 (quarterly), 12 (monthly), 52 (weekly)."
)
if Q < 2:
raise ValueError(
f"Parameter 'Q' must be >= 2, got {Q}.\n"
f"Q represents the number of seasonal periods per cycle.\n"
f"Common values: 4 (quarterly), 12 (monthly), 52 (weekly)."
)
# Validate season_var parameter.
if season_var is not None and not isinstance(season_var, str):
raise TypeError(
f"Parameter 'season_var' must be a string or None, got {type(season_var).__name__}.\n"
f"Specify the column name containing seasonal values (1 to Q)."
)
# Auto-detect frequency if requested and using seasonal transformations.
# Detection only applies when rolling method requires seasonal adjustment.
if auto_detect_frequency:
if not isinstance(auto_detect_frequency, bool):
raise TypeError(
f"Parameter 'auto_detect_frequency' must be a boolean, "
f"got {type(auto_detect_frequency).__name__}."
)
# Only auto-detect for seasonal transformations.
rolling_lower = rolling.lower() if isinstance(rolling, str) else ''
if rolling_lower in ('demeanq', 'detrendq'):
# Determine time variable for detection.
tvar_for_detection = tvar if isinstance(tvar, str) else (tvar[0] if tvar else None)
if tvar_for_detection is not None and ivar is not None:
try:
detection_result = validation.detect_frequency(
data, tvar=tvar_for_detection, ivar=ivar
)
detected_Q = detection_result.get('Q')
confidence = detection_result.get('confidence', 0)
frequency = detection_result.get('frequency')
# Check if user explicitly set Q (not default value).
user_specified_Q = Q != 4
if detected_Q is not None and confidence >= 0.5:
if user_specified_Q and Q != detected_Q:
# User explicitly set Q, warn about mismatch but use user's value.
logger.warning(
f"Auto-detected frequency '{frequency}' (Q={detected_Q}) "
f"differs from explicit Q={Q}. Using explicit Q={Q}."
)
else:
# Use detected Q value.
Q = detected_Q
logger.info(
f"Auto-detected data frequency: {frequency} (Q={Q}, "
f"confidence={confidence:.2f})"
)
elif detected_Q is None or confidence < 0.5:
# Detection failed or low confidence.
warnings.warn(
f"Could not reliably detect data frequency "
f"(confidence={confidence:.2f}). Using Q={Q}. "
f"Consider setting Q explicitly for seasonal transformations.",
UserWarning,
stacklevel=2
)
except Exception as e:
# Detection failed, use default Q.
warnings.warn(
f"Frequency auto-detection failed: {e}. Using Q={Q}.",
UserWarning,
stacklevel=2
)
else:
warnings.warn(
"Cannot auto-detect frequency: tvar or ivar not specified. Using Q={Q}.",
UserWarning,
stacklevel=2
)
# Check panel balance if balanced_panel='error'
if balanced_panel == 'error' and ivar is not None:
tvar_col = tvar if isinstance(tvar, str) else tvar[0]
if tvar_col in data.columns and ivar in data.columns:
panel_counts = data.groupby(ivar)[tvar_col].count()
is_balanced = panel_counts.nunique() == 1
if not is_balanced:
min_obs = int(panel_counts.min())
max_obs = int(panel_counts.max())
n_incomplete = int((panel_counts < max_obs).sum())
raise UnbalancedPanelError(
f"Unbalanced panel detected: {n_incomplete} units have incomplete "
f"observations (range: {min_obs}-{max_obs}). "
f"Set balanced_panel='warn' to proceed with warnings, or create "
f"a balanced subsample. Use diagnose_selection_mechanism() for "
f"detailed diagnostics on selection bias risk.",
min_obs=min_obs,
max_obs=max_obs,
n_incomplete_units=n_incomplete,
)
# Dispatch to staggered or common timing implementation.
if gvar is not None:
if d is not None or post is not None:
warnings.warn(
"Both gvar and d/post parameters provided. Staggered mode takes precedence. "
"The d and post parameters will be ignored in staggered mode.",
UserWarning,
stacklevel=2
)
return _lwdid_staggered(
data=data, y=y, ivar=ivar, tvar=tvar, gvar=gvar,
rolling=rolling, control_group=control_group,
estimator=estimator, aggregate=aggregate,
ps_controls=ps_controls, trim_threshold=trim_threshold,
return_diagnostics=return_diagnostics,
n_neighbors=n_neighbors, caliper=caliper,
with_replacement=with_replacement,
match_order=match_order,
vce=vce, controls=controls, cluster_var=cluster_var,
alpha=alpha,
ri=ri, rireps=rireps, seed=seed, ri_method=ri_method,
graph=graph, gid=gid, graph_options=graph_options,
season_var=season_var, Q=Q,
include_pretreatment=include_pretreatment,
pretreatment_test=pretreatment_test,
pretreatment_alpha=pretreatment_alpha,
exclude_pre_periods=exclude_pre_periods,
**kwargs
)
else:
# Common timing mode: validate estimator parameter.
# Type check required before calling .lower() to avoid AttributeError.
if estimator is not None and not isinstance(estimator, str):
raise TypeError(
f"estimator must be a string, got {type(estimator).__name__}.\n"
f"Valid values: ('ra', 'ipw', 'ipwra', 'psm').\n"
f"Example: lwdid(..., estimator='ipwra')"
)
estimator_lower = estimator.lower() if estimator else 'ra'
VALID_ESTIMATORS_COMMON = ('ra', 'ipw', 'ipwra', 'psm')
if estimator_lower not in VALID_ESTIMATORS_COMMON:
raise ValueError(
f"Invalid estimator='{estimator}'.\n"
f"Valid values for common timing mode: {VALID_ESTIMATORS_COMMON}"
)
# Check for staggered-only parameters (control_group, aggregate).
ignored_staggered_params = []
if control_group != 'not_yet_treated':
ignored_staggered_params.append(f"control_group='{control_group}'")
if aggregate != 'cohort':
ignored_staggered_params.append(f"aggregate='{aggregate}'")
# For RA estimator, IPW-related parameters are ignored.
if estimator_lower == 'ra':
if ps_controls is not None:
ignored_staggered_params.append(f"ps_controls={ps_controls}")
if trim_threshold != 0.01:
ignored_staggered_params.append(f"trim_threshold={trim_threshold}")
if return_diagnostics:
ignored_staggered_params.append("return_diagnostics=True")
if n_neighbors != 1:
ignored_staggered_params.append(f"n_neighbors={n_neighbors}")
if caliper is not None:
ignored_staggered_params.append(f"caliper={caliper}")
if not with_replacement:
ignored_staggered_params.append("with_replacement=False")
if match_order != 'data':
ignored_staggered_params.append(f"match_order='{match_order}'")
# Validate control variables based on estimator type.
# IPWRA: requires controls (outcome model), ps_controls optional (defaults to controls).
# IPW/PSM: requires controls OR ps_controls (propensity score model only).
if estimator_lower == 'ipwra' and not controls:
raise ValueError(
f"estimator='ipwra' requires 'controls' parameter for outcome model.\n"
f"IPWRA (doubly robust) uses controls in both outcome regression and "
f"propensity score model.\n"
f" - controls: Variables for outcome model E[Y|X,D=0]\n"
f" - ps_controls: Variables for propensity score P(D=1|X), defaults to controls"
)
elif estimator_lower in ('ipw', 'psm') and not controls and not ps_controls:
raise ValueError(
f"estimator='{estimator}' requires 'controls' or 'ps_controls' parameter.\n"
f"IPW and PSM estimators need control variables for propensity score model.\n"
f" - Use 'controls' to specify variables (also used as ps_controls by default)\n"
f" - Or use 'ps_controls' to specify propensity score model variables directly"
)
# Validate trim_threshold range for IPW/IPWRA/PSM estimators.
# At threshold=0.5, scores are clipped to [0.5, 0.5], excluding all observations.
if estimator_lower in ('ipw', 'ipwra', 'psm'):
if not (0 < trim_threshold < 0.5):
raise ValueError(
f"Invalid trim_threshold={trim_threshold}.\n"
f"trim_threshold must be in (0, 0.5) for valid propensity score trimming.\n"
f" - trim_threshold=0.01 trims PS to [0.01, 0.99] (default)\n"
f" - trim_threshold=0.05 trims PS to [0.05, 0.95]\n"
f" - trim_threshold=0.5 is invalid as it would trim all observations"
)
# For non-PSM estimators, PSM-specific parameters are ignored.
if estimator_lower in ('ipw', 'ipwra'):
if n_neighbors != 1:
ignored_staggered_params.append(f"n_neighbors={n_neighbors}")
if caliper is not None:
ignored_staggered_params.append(f"caliper={caliper}")
if not with_replacement:
ignored_staggered_params.append("with_replacement=False")
if match_order != 'data':
ignored_staggered_params.append(f"match_order='{match_order}'")
if ignored_staggered_params:
warnings.warn(
f"Common timing mode (gvar=None): the following parameters "
f"are ignored: {', '.join(ignored_staggered_params)}. "
f"To use control_group/aggregate, specify the 'gvar' parameter for staggered adoption.",
UserWarning,
stacklevel=2
)
if d is None:
raise ValueError(
"Common timing mode requires 'd' parameter (unit-level treatment indicator).\n"
"If your data has staggered adoption (units treated at different times), "
"use the 'gvar' parameter to specify the first treatment period column."
)
if post is None:
raise ValueError(
"Common timing mode requires 'post' parameter (post-treatment period indicator).\n"
"If your data has staggered adoption, use the 'gvar' parameter instead."
)
if ivar is None:
raise ValueError("Parameter 'ivar' (unit identifier column) is required.")
if tvar is None:
raise ValueError("Parameter 'tvar' (time variable column) is required.")
if rolling is None:
raise ValueError(
"Common timing mode requires 'rolling' parameter.\n"
"Valid values: 'demean', 'detrend', 'demeanq', 'detrendq'.\n"
" - 'demean': Remove unit-specific pre-treatment mean\n"
" - 'detrend': Remove unit-specific linear time trend\n"
" - 'demeanq': Demeaning with quarterly fixed effects\n"
" - 'detrendq': Detrending with quarterly fixed effects"
)
if not isinstance(rolling, str):
raise TypeError(
f"Parameter 'rolling' must be a string, got {type(rolling).__name__}. "
f"Valid values: 'demean', 'detrend', 'demeanq', 'detrendq'."
)
if vce is not None and vce.lower() == 'cluster' and cluster_var is None:
raise InvalidParameterError(
"vce='cluster' requires cluster_var parameter.\n"
"Specify the column name for cluster-robust standard errors."
)
if not isinstance(alpha, (int, float, np.number)):
raise TypeError(
f"Parameter 'alpha' must be numeric, got {type(alpha).__name__}.\n"
f"Example: alpha=0.05 for 95% confidence interval."
)
if hasattr(alpha, '__float__') and np.isnan(float(alpha)):
raise ValueError("Parameter 'alpha' cannot be NaN.")
if not (0 < alpha < 1):
raise ValueError(
f"Parameter 'alpha' must be between 0 and 1 (exclusive), got {alpha}.\n"
"Common values: 0.05 (95% CI), 0.10 (90% CI), 0.01 (99% CI)."
)
# Validate exclude_pre_periods parameter.
if not isinstance(exclude_pre_periods, (int, np.integer)):
raise TypeError(
f"Parameter 'exclude_pre_periods' must be an integer, "
f"got {type(exclude_pre_periods).__name__}.\n"
f"Example: exclude_pre_periods=2 to exclude 2 periods before treatment."
)
if exclude_pre_periods < 0:
raise ValueError(
f"Parameter 'exclude_pre_periods' must be non-negative, "
f"got {exclude_pre_periods}.\n"
f"Use 0 for no exclusion, or a positive integer to exclude periods."
)
if isinstance(tvar, (list, tuple)):
if len(tvar) != 2:
raise ValueError(
f"Parameter 'tvar' as a list must have exactly 2 elements "
f"[year_column, quarter_column], got {len(tvar)} elements: {tvar}"
)
if ri:
# Accept both Python int and numpy integer types for rireps.
if not isinstance(rireps, (int, np.integer)) or rireps < 1:
raise ValueError(
f"Invalid rireps={rireps}.\n"
f"rireps must be a positive integer >= 1 when ri=True.\n"
f"Recommended: rireps >= 500 for reliable p-values."
)
# Validate ri_method type before calling .lower()
if ri_method is not None and not isinstance(ri_method, str):
raise TypeError(
f"Parameter 'ri_method' must be a string or None, "
f"got {type(ri_method).__name__}.\n"
f"Valid values: 'bootstrap', 'permutation'"
)
VALID_RI_METHODS = ('bootstrap', 'permutation')
ri_method_lower = ri_method.lower() if ri_method else 'bootstrap'
if ri_method_lower not in VALID_RI_METHODS:
raise ValueError(
f"Invalid ri_method='{ri_method}'.\n"
f"Valid values: {VALID_RI_METHODS}\n"
f" - 'bootstrap': Bootstrap resampling (with replacement)\n"
f" - 'permutation': Fisher's exact permutation test (without replacement)"
)
# Normalize ri_method to lowercase for downstream use.
ri_method = ri_method_lower
# Validate seed type to ensure reproducibility.
# Float values would be silently truncated by random.seed(),
# potentially causing unexpected behavior.
if seed is not None:
if not isinstance(seed, (int, np.integer)):
raise TypeError(
f"Parameter 'seed' must be an integer or None, "
f"got {type(seed).__name__}.\n"
f"Example: lwdid(..., ri=True, seed=42)"
)
if seed < 0:
raise ValueError(
f"Parameter 'seed' must be non-negative, got {seed}.\n"
f"Use a non-negative integer for reproducible results."
)
# Data validation and transformation.
data_clean, metadata = validation.validate_and_prepare_data(
data=data,
y=y,
d=d,
ivar=ivar,
tvar=tvar,
post=post,
rolling=rolling,
controls=controls,
season_var=season_var,
)
rolling = metadata['rolling']
# Resolve season_var: prefer explicit season_var over tvar[1] for backward compatibility.
# If season_var is provided, use it; otherwise fall back to tvar[1] for quarterly data.
effective_season_var = season_var
if effective_season_var is None and not isinstance(tvar, str):
# Legacy behavior: use tvar[1] as quarter variable when tvar is a list
effective_season_var = tvar[1]
data_transformed = transformations.apply_rolling_transform(
data=data_clean,
y=y,
ivar=ivar,
tindex='tindex',
post='post_',
rolling=rolling,
tpost1=metadata['tpost1'],
quarter=tvar[1] if not isinstance(tvar, str) else None,
season_var=effective_season_var,
Q=Q,
exclude_pre_periods=exclude_pre_periods,
)
# Extract first post-treatment cross-section for ATT estimation.
# Using firstpost ensures consistent sample across all estimators.
firstpost_data = data_transformed[data_transformed['firstpost']].copy()
# Verify first post-treatment period has observations.
# Empty data would cause unclear errors in downstream estimators.
if len(firstpost_data) == 0:
raise InsufficientDataError(
"No observations found for the first post-treatment period.\n\n"
"Possible causes:\n"
" - No treated units in the data (all units have d=0)\n"
" - The 'post' indicator is never 1 (no post-treatment periods)\n"
" - All first-post observations were filtered during transformation\n"
" - Data filtering removed all eligible observations\n\n"
"How to fix:\n"
" 1. Check treatment indicator: data[d].value_counts()\n"
" 2. Check post indicator: data[post].value_counts()\n"
" 3. Verify data has post-treatment observations for treated units"
)
if estimator_lower == 'ra':
# RA uses OLS on transformed outcomes for unbiased ATT under parallel trends.
results_dict = estimation.estimate_att(
data=data_transformed,
y_transformed='ydot_postavg',
d='d_',
ivar=ivar,
controls=controls,
vce=vce,
cluster_var=cluster_var,
sample_filter=data_transformed['firstpost'],
alpha=alpha,
)
else:
# Non-RA estimators require propensity score model for weighting or matching.
ps_controls_final = ps_controls if ps_controls is not None else controls
if estimator_lower == 'ipw':
ipw_result = estimate_ipw(
data=firstpost_data,
y='ydot_postavg',
d='d_',
propensity_controls=ps_controls_final,
trim_threshold=trim_threshold,
alpha=alpha,
return_diagnostics=return_diagnostics,
# Common timing mode bypasses cohort-specific logic.
gvar_col=None,
ivar_col=None,
cohort_g=None,
period_r=None,
)
results_dict = _convert_ipw_result_to_dict(
ipw_result, alpha, vce, cluster_var, controls, ps_controls_final
)
elif estimator_lower == 'ipwra':
ipwra_result = estimate_ipwra(
data=firstpost_data,
y='ydot_postavg',
d='d_',
controls=controls,
propensity_controls=ps_controls_final,
trim_threshold=trim_threshold,
alpha=alpha,
return_diagnostics=return_diagnostics,
# Common timing mode bypasses cohort-specific logic.
gvar_col=None,
ivar_col=None,
cohort_g=None,
period_r=None,
)
results_dict = _convert_ipwra_result_to_dict(
ipwra_result, alpha, vce, cluster_var, controls
)
elif estimator_lower == 'psm':
psm_result = estimate_psm(
data=firstpost_data,
y='ydot_postavg',
d='d_',
propensity_controls=ps_controls_final,
n_neighbors=n_neighbors,
caliper=caliper,
with_replacement=with_replacement,
match_order=match_order,
alpha=alpha,
return_diagnostics=return_diagnostics,
# Common timing mode bypasses cohort-specific logic.
gvar_col=None,
ivar_col=None,
cohort_g=None,
period_r=None,
)
results_dict = _convert_psm_result_to_dict(
psm_result, alpha, vce, cluster_var, controls
)
# Store estimator-specific diagnostics if requested.
if return_diagnostics:
metadata['estimator_diagnostics'] = results_dict.get('diagnostics')
# Construct human-readable period labels preserving numeric precision.
# Integer years display without decimals for cleaner output formatting.
if isinstance(tvar, str):
period_labels = {}
for t, year in data_transformed.groupby('tindex')[tvar].first().items():
if pd.notna(year):
if year == int(year):
period_labels[t] = str(int(year))
else:
# Preserve decimal part for non-integer years.
period_labels[t] = str(year)
else:
period_labels[t] = f"T{t}"
else:
year_var, quarter_var = tvar[0], tvar[1]
period_labels = {}
for t in data_transformed['tindex'].unique():
row = data_transformed[data_transformed['tindex'] == t].iloc[0]
year_val = row[year_var]
quarter_val = row[quarter_var]
if pd.notna(year_val) and pd.notna(quarter_val):
# Preserve non-integer values for display precision.
year_str = str(int(year_val)) if year_val == int(year_val) else str(year_val)
quarter_str = str(int(quarter_val)) if quarter_val == int(quarter_val) else str(quarter_val)
period_labels[t] = f"{year_str}q{quarter_str}"
else:
period_labels[t] = f"T{t}"
Tmax = int(data_transformed['tindex'].max())
controls_spec = results_dict.get('controls_spec', None)
# Estimate period-specific effects using the appropriate estimator.
if estimator_lower == 'ra':
# RA uses OLS-based period effect estimation.
period_df = estimation.estimate_period_effects(
data=data_transformed,
ydot='ydot',
d='d_',
tindex='tindex',
tpost1=metadata['tpost1'],
Tmax=Tmax,
controls_spec=controls_spec,
vce=vce,
cluster_var=cluster_var,
period_labels=period_labels,
alpha=alpha,
)
else:
# IPW/IPWRA/PSM use propensity-based period effect estimation.
period_df = _estimate_period_effects_ipw(
data=data_transformed,
ydot='ydot',
d='d_',
tindex='tindex',
tpost1=metadata['tpost1'],
Tmax=Tmax,
estimator=estimator_lower,
controls=controls,
ps_controls=ps_controls_final,
trim_threshold=trim_threshold,
n_neighbors=n_neighbors,
caliper=caliper,
with_replacement=with_replacement,
match_order=match_order,
period_labels=period_labels,
alpha=alpha,
)
# Combine average and period-specific effects into unified output.
avg_row = pd.DataFrame([{
'period': 'average',
'tindex': '-',
'beta': results_dict['att'],
'se': results_dict['se_att'],
'ci_lower': results_dict['ci_lower'],
'ci_upper': results_dict['ci_upper'],
'tstat': results_dict['t_stat'],
'pval': results_dict['pvalue'],
'N': results_dict['nobs']
}])
avg_row['is_avg'] = True
period_df['is_avg'] = False
att_by_period = pd.concat([avg_row, period_df], ignore_index=True)
att_by_period = att_by_period.sort_values(
['is_avg', 'tindex'], ascending=[False, True]
)
att_by_period = att_by_period.drop(columns=['is_avg']).reset_index(drop=True)
att_by_period['tindex'] = att_by_period['tindex'].astype(str)
att_by_period = att_by_period[[
'period', 'tindex', 'beta', 'se', 'ci_lower', 'ci_upper', 'tstat', 'pval', 'N'
]]
# Initialize RI variables before conditional execution block.
ri_result = None
actual_seed = None
if ri:
# Handle legacy riseed parameter with explicit type validation.
# Invalid input raises an error rather than silently using a random seed.
if seed is None and 'riseed' in kwargs:
riseed_val = kwargs['riseed']
if riseed_val is not None:
if isinstance(riseed_val, (int, np.integer)):
seed = int(riseed_val)
elif isinstance(riseed_val, str):
try:
seed = int(riseed_val)
except ValueError:
raise TypeError(
f"riseed must be an integer or integer-convertible string, "
f"got '{riseed_val}'. Use seed parameter for explicit control."
)
else:
raise TypeError(
f"riseed must be an integer, got {type(riseed_val).__name__}. "
f"Use seed parameter for explicit control."
)
actual_seed = seed if seed is not None else _generate_ri_seed()
firstpost_df = data_transformed.loc[data_transformed['firstpost']].copy()
if metadata.get('id_mapping') is not None:
firstpost_df.attrs['id_mapping'] = metadata['id_mapping']
try:
ri_result = randomization_inference(
firstpost_df=firstpost_df,
y_col='ydot_postavg',
d_col='d_',
ivar=ivar,
rireps=rireps,
seed=actual_seed,
att_obs=results_dict['att'],
ri_method=ri_method,
controls=controls,
)
except (RandomizationError, ValueError, np.linalg.LinAlgError) as e:
# RandomizationError: RI-specific failures (insufficient data, invalid params)
# ValueError: Data issues during resampling
# LinAlgError: Singular matrix in permuted regressions
warnings.warn(
f"Randomization inference failed: {type(e).__name__}: {e}. "
f"ATT estimation results are still valid.",
UserWarning,
stacklevel=3
)
ri_result = {
'p_value': np.nan,
'ri_method': ri_method,
'ri_valid': 0,
'ri_failed': -1,
'ri_error': str(e),
}
results_dict['alpha'] = alpha
metadata['alpha'] = alpha
results = LWDIDResults(results_dict, metadata, att_by_period)
results.data = data_transformed
if metadata.get('id_mapping') is not None:
results.data.attrs['id_mapping'] = metadata['id_mapping']
if ri:
results.ri_pvalue = ri_result['p_value']
results.rireps = int(rireps)
results.ri_seed = int(actual_seed)
results.ri_method = ri_result['ri_method']
results.ri_valid = ri_result['ri_valid']
results.ri_failed = ri_result['ri_failed']
results.ri_error = ri_result.get('ri_error', None)
if graph:
try:
results.plot(gid=gid, graph_options=graph_options)
except Exception as e:
warnings.warn(
f"Plotting failed: {type(e).__name__}: {str(e)}. "
f"The estimation results are unaffected.",
UserWarning,
stacklevel=3 # _lwdid_classic() is level 2, so +1 to point to user code
)
return results
def _validate_control_group_for_aggregate(
aggregate: str,
control_group: str,
has_never_treated: bool,
n_never_treated: int = 0
) -> tuple:
"""
Validate and adjust control group strategy based on aggregation level.
Cohort and overall aggregation require never-treated units as a consistent
reference group across different treatment cohorts. This function auto-
switches to 'never_treated' when needed and validates that sufficient
never-treated units exist.
Parameters
----------
aggregate : {'none', 'cohort', 'overall'}
Aggregation level for effect estimates.
control_group : {'never_treated', 'not_yet_treated', 'all_others'}
Requested control group composition strategy.
has_never_treated : bool
Whether the data contains any never-treated units.
n_never_treated : int
Number of never-treated units in the data.
Returns
-------
control_group_used : str
The control group strategy that will be used.
warning_msg : str or None
Warning message if auto-switching occurred, None otherwise.
Raises
------
NoNeverTreatedError
If aggregation requires never-treated units but none exist.
"""
warning_msg = None
control_group_used = control_group
if aggregate in ('cohort', 'overall'):
if control_group != 'never_treated':
warning_msg = (
f"{aggregate} effect estimation requires never_treated control group, "
f"automatically switched from '{control_group}' to 'never_treated'."
)
logger.info(warning_msg)
warnings.warn(warning_msg, UserWarning, stacklevel=4)
control_group_used = 'never_treated'
if not has_never_treated:
raise NoNeverTreatedError(
f"Cannot estimate {aggregate} effect: no never-treated units in data.\n"
f"Reason: {aggregate} effect requires NT units as a unified reference baseline.\n"
f" - Cohort effect: Different cohorts' transformations use different pre-treatment periods, "
f"only NT units can provide a consistent reference.\n"
f" - Overall effect: NT units are needed to compute weighted transformations across all cohorts.\n"
f"Suggestion: Use aggregate='none' to estimate (g,r)-specific effects, which can use not-yet-treated control group."
)
if n_never_treated < 2:
warnings.warn(
f"Number of never-treated units is too few (N={n_never_treated}), "
f"inference results may be unreliable. Recommended N_NT >= 2.",
UserWarning,
stacklevel=4
)
return control_group_used, warning_msg
def _lwdid_staggered(
data: pd.DataFrame,
y: str,
ivar: str,
tvar: str | list[str],
gvar: str,
rolling: str,
control_group: str,
estimator: str,
aggregate: str,
ps_controls: list[str] | None,
trim_threshold: float,
return_diagnostics: bool,
n_neighbors: int,
caliper: float | None,
with_replacement: bool,
match_order: str,
vce: str | None,
controls: list[str] | None,
cluster_var: str | None,
alpha: float,
ri: bool,
rireps: int,
seed: int | None,
ri_method: str,
graph: bool,
gid: str | int | None,
graph_options: dict | None,
season_var: str | None = None,
Q: int = 4,
include_pretreatment: bool = False,
pretreatment_test: bool = True,
pretreatment_alpha: float = 0.05,
exclude_pre_periods: int = 0,
**kwargs
) -> LWDIDResults:
"""
Estimate treatment effects under staggered adoption design.
Internal dispatcher that applies cohort-specific transformations, estimates
cohort-time ATT effects, and aggregates results according to the specified
aggregation level. This function handles all staggered adoption logic when
the ``gvar`` parameter is provided to ``lwdid()``.
Parameters
----------
data : pd.DataFrame
Panel data with unit, time, treatment cohort, and outcome variables.
y : str
Column name of the outcome variable.
ivar : str
Column name of the unit identifier.
tvar : str or list of str
Column name(s) of the time variable.
gvar : str
Column name indicating the first treatment period for each unit.
rolling : {'demean', 'detrend', 'demeanq', 'detrendq'}
Transformation method for removing pre-treatment patterns.
control_group : {'never_treated', 'not_yet_treated', 'all_others'}
Control group composition strategy.
estimator : {'ra', 'ipw', 'ipwra', 'psm'}
Treatment effect estimation method.
aggregate : {'none', 'cohort', 'overall'}
Aggregation level for the effect estimates.
ps_controls : list of str or None
Variables for the propensity score model.
trim_threshold : float
Propensity score trimming bound in (0, 0.5).
return_diagnostics : bool
Whether to include estimation diagnostics in results.
n_neighbors : int
Number of nearest neighbors for PSM matching.
caliper : float or None
Maximum propensity score distance for PSM matches.
with_replacement : bool
Whether PSM allows control unit reuse.
match_order : str
Order for processing treated units in PSM without replacement.
vce : str or None
Variance-covariance estimator type.
controls : list of str or None
Control variables for the outcome model.
cluster_var : str or None
Variable name for cluster-robust standard errors.
alpha : float
Significance level for confidence intervals.
ri : bool
Whether to perform randomization inference.
rireps : int
Number of randomization inference replications.
seed : int or None
Random seed for reproducibility.
ri_method : {'bootstrap', 'permutation'}
Randomization inference resampling method.
graph : bool
Whether to generate visualization plots.
gid : str, int, or None
Specific unit identifier to highlight in plots.
graph_options : dict or None
Additional plotting configuration options.
**kwargs
Additional keyword arguments (reserved for future use).
Returns
-------
LWDIDResults
Results object containing ATT estimates, standard errors, confidence
intervals, cohort-time specific effects, and aggregated effects.
"""
from .staggered import (
transformations as stag_trans,
estimation as stag_est,
aggregation as stag_agg
)
if graph:
warnings.warn(
"Parameter 'graph=True' is not yet supported in staggered mode.\n"
"To visualize results, use the `plot_event_study()` method on the returned "
"LWDIDResults object: `results.plot_event_study()`",
UserWarning,
stacklevel=4
)
# Parameter validation.
if ivar is None:
raise ValueError("Staggered mode requires 'ivar' parameter (unit identifier column).")
if tvar is None:
raise ValueError("Staggered mode requires 'tvar' parameter (time variable column).")
# Validate tvar list/tuple format for quarterly data support.
if isinstance(tvar, (list, tuple)):
if len(tvar) == 0:
raise ValueError(
"Parameter 'tvar' cannot be an empty list/tuple.\n"
"For annual data: tvar='year_column'\n"
"For quarterly data: tvar=['year_column', 'quarter_column']"
)
if len(tvar) > 2:
raise ValueError(
f"Parameter 'tvar' as a list must have 1-2 elements "
f"[year_column, quarter_column], got {len(tvar)} elements: {tvar}"
)
if rolling is None:
raise ValueError(
"Staggered mode requires 'rolling' parameter.\n"
"Valid values: 'demean', 'detrend', 'demeanq', 'detrendq'."
)
if not isinstance(rolling, str):
raise TypeError(
f"Parameter 'rolling' must be a string, got {type(rolling).__name__}. "
f"Valid values: 'demean', 'detrend', 'demeanq', 'detrendq'."
)
rolling_lower = rolling.lower()
VALID_ROLLING_STAGGERED = ('demean', 'detrend', 'demeanq', 'detrendq')
if rolling_lower not in VALID_ROLLING_STAGGERED:
raise ValueError(
f"Invalid rolling='{rolling}' for staggered mode.\n"
f"Valid values: {VALID_ROLLING_STAGGERED}"
)
# Validate aggregate parameter type before string operations.
if aggregate is not None and not isinstance(aggregate, str):
raise TypeError(
f"Parameter 'aggregate' must be a string or None, "
f"got {type(aggregate).__name__}.\n"
f"Valid values: 'none', 'cohort', 'overall'"
)
VALID_AGGREGATE = ('none', 'cohort', 'overall')
aggregate_lower = aggregate.lower() if aggregate else 'none'
if aggregate_lower not in VALID_AGGREGATE:
raise ValueError(
f"Invalid aggregate='{aggregate}'.\n"
f"Valid values: {VALID_AGGREGATE}\n"
f" - 'none': Return (g,r)-specific effects only\n"
f" - 'cohort': Aggregate to cohort-specific effects (τ_g)\n"
f" - 'overall': Aggregate to overall weighted effect (τ_ω)"
)
# Validate control_group parameter type before string operations.
if control_group is not None and not isinstance(control_group, str):
raise TypeError(
f"Parameter 'control_group' must be a string or None, "
f"got {type(control_group).__name__}.\n"
f"Valid values: 'never_treated', 'not_yet_treated', 'all_others'"
)
VALID_CONTROL_GROUPS = ('never_treated', 'not_yet_treated', 'all_others')
control_group_lower = control_group.lower() if control_group else 'not_yet_treated'
if control_group_lower not in VALID_CONTROL_GROUPS:
raise ValueError(
f"Invalid control_group='{control_group}'.\n"
f"Valid values: {VALID_CONTROL_GROUPS}\n"
f" - 'never_treated': Use only never-treated units as control\n"
f" - 'not_yet_treated': Use never-treated + not-yet-treated units as control\n"
f" - 'all_others': Use all non-cohort units as control (includes already-treated)"
)
# Validate estimator type before string operations.
# Default to 'ra' when estimator is None for consistency.
if estimator is not None and not isinstance(estimator, str):
raise TypeError(
f"estimator must be a string, got {type(estimator).__name__}.\n"
f"Valid values: ('ra', 'ipw', 'ipwra', 'psm').\n"
f"Example: lwdid(..., estimator='ipwra')"
)
estimator_lower = estimator.lower() if estimator else 'ra'
VALID_ESTIMATORS = ('ra', 'ipw', 'ipwra', 'psm')
if estimator_lower not in VALID_ESTIMATORS:
raise ValueError(
f"Invalid estimator='{estimator}'.\n"
f"Valid values: {VALID_ESTIMATORS}"
)
# Validate control variables based on estimator type.
# IPWRA: requires controls (outcome model), ps_controls optional (defaults to controls).
# IPW/PSM: requires controls OR ps_controls (propensity score model only).
if estimator_lower == 'ipwra' and not controls:
raise ValueError(
f"estimator='ipwra' requires 'controls' parameter for outcome model.\n"
f"IPWRA (doubly robust) uses controls in both outcome regression and "
f"propensity score model.\n"
f" - controls: Variables for outcome model E[Y|X,D=0]\n"
f" - ps_controls: Variables for propensity score P(D=1|X), defaults to controls"
)
elif estimator_lower in ('ipw', 'psm') and not controls and not ps_controls:
raise ValueError(
f"estimator='{estimator}' requires 'controls' or 'ps_controls' parameter.\n"
f"IPW and PSM estimators need control variables for propensity score model.\n"
f" - Use 'controls' to specify variables (also used as ps_controls by default)\n"
f" - Or use 'ps_controls' to specify propensity score model variables directly"
)
# Validate alpha parameter type for confidence interval calculation.
if not isinstance(alpha, (int, float, np.number)):
raise TypeError(
f"Parameter 'alpha' must be numeric, got {type(alpha).__name__}.\n"
f"Example: alpha=0.05 for 95% confidence interval."
)
if hasattr(alpha, '__float__') and np.isnan(float(alpha)):
raise ValueError("Parameter 'alpha' cannot be NaN.")
if not (0 < alpha < 1):
raise ValueError(
f"Invalid alpha={alpha}.\n"
f"alpha must be in (0, 1) for valid confidence interval calculation.\n"
f" - alpha=0.05 gives 95% CI (default)\n"
f" - alpha=0.10 gives 90% CI\n"
f" - alpha=0.01 gives 99% CI"
)
if vce is not None and vce.lower() == 'cluster' and cluster_var is None:
raise InvalidParameterError(
"vce='cluster' requires cluster_var parameter.\n"
"Specify the column name for cluster-robust standard errors."
)
if estimator_lower in ('ipw', 'ipwra', 'psm'):
# Propensity score trimming requires threshold in (0, 0.5) exclusive.
# At threshold=0.5, scores are clipped to [0.5, 0.5], excluding all observations.
if not (0 < trim_threshold < 0.5):
raise ValueError(
f"Invalid trim_threshold={trim_threshold}.\n"
f"trim_threshold must be in (0, 0.5) for valid propensity score trimming.\n"
f" - trim_threshold=0.01 trims PS to [0.01, 0.99] (default)\n"
f" - trim_threshold=0.05 trims PS to [0.05, 0.95]\n"
f" - trim_threshold=0.5 is invalid as it would trim all observations"
)
if ri:
# Accept both Python int and numpy integer types for rireps.
if not isinstance(rireps, (int, np.integer)) or rireps < 1:
raise ValueError(
f"Invalid rireps={rireps}.\n"
f"rireps must be a positive integer >= 1 when ri=True.\n"
f"Recommended: rireps >= 500 for reliable p-values."
)
# Validate ri_method type before calling .lower()
if ri_method is not None and not isinstance(ri_method, str):
raise TypeError(
f"Parameter 'ri_method' must be a string or None, "
f"got {type(ri_method).__name__}.\n"
f"Valid values: 'bootstrap', 'permutation'"
)
VALID_RI_METHODS = ('bootstrap', 'permutation')
ri_method_lower = ri_method.lower() if ri_method else 'bootstrap'
if ri_method_lower not in VALID_RI_METHODS:
raise ValueError(
f"Invalid ri_method='{ri_method}'.\n"
f"Valid values: {VALID_RI_METHODS}\n"
f" - 'bootstrap': Bootstrap resampling (with replacement)\n"
f" - 'permutation': Fisher's exact permutation test (without replacement)"
)
# Normalize ri_method to lowercase for downstream use.
ri_method = ri_method_lower
# Validate seed type to ensure reproducibility.
# Only integer seeds are accepted; floats and other types are rejected here
# instead of being handed to the random number generator.
if seed is not None:
if not isinstance(seed, (int, np.integer)):
raise TypeError(
f"Parameter 'seed' must be an integer or None, "
f"got {type(seed).__name__}.\n"
f"Example: lwdid(..., ri=True, seed=42)"
)
if seed < 0:
raise ValueError(
f"Parameter 'seed' must be non-negative, got {seed}.\n"
f"Use a non-negative integer for reproducible results."
)
# Data validation and preparation.
validation_result = validate_staggered_data(
data=data,
gvar=gvar,
ivar=ivar,
tvar=tvar,
y=y,
controls=controls
)
cohorts = validation_result['cohorts']
has_never_treated = validation_result['n_never_treated'] > 0
n_never_treated = validation_result['n_never_treated']
T_max = validation_result['T_max']
T_min = validation_result['T_min']
cohort_sizes = validation_result['cohort_sizes']
for warning in validation_result.get('warnings', []):
warnings.warn(warning, UserWarning, stacklevel=3)
control_group_used, switch_warning = _validate_control_group_for_aggregate(
aggregate=aggregate_lower,
control_group=control_group_lower,
has_never_treated=has_never_treated,
n_never_treated=n_never_treated
)
# Apply transformation.
tvar_str = tvar if isinstance(tvar, str) else tvar[0]
# Select appropriate transformation function based on rolling method.
if rolling_lower == 'demean':
transform_func = stag_trans.transform_staggered_demean
data_transformed = transform_func(
data=data,
y=y,
ivar=ivar,
tvar=tvar_str,
gvar=gvar,
exclude_pre_periods=exclude_pre_periods,
)
elif rolling_lower == 'detrend':
transform_func = stag_trans.transform_staggered_detrend
data_transformed = transform_func(
data=data,
y=y,
ivar=ivar,
tvar=tvar_str,
gvar=gvar,
exclude_pre_periods=exclude_pre_periods,
)
elif rolling_lower == 'demeanq':
data_transformed = stag_trans.transform_staggered_demeanq(
data=data,
y=y,
ivar=ivar,
tvar=tvar_str,
gvar=gvar,
season_var=season_var,
Q=Q,
exclude_pre_periods=exclude_pre_periods,
)
elif rolling_lower == 'detrendq':
data_transformed = stag_trans.transform_staggered_detrendq(
data=data,
y=y,
ivar=ivar,
tvar=tvar_str,
gvar=gvar,
season_var=season_var,
Q=Q,
exclude_pre_periods=exclude_pre_periods,
)
# Estimate cohort-time effects.
ps_controls_final = ps_controls if ps_controls is not None else controls
cohort_time_effects = stag_est.estimate_cohort_time_effects(
data_transformed=data_transformed,
gvar=gvar,
ivar=ivar,
tvar=tvar_str,
controls=controls,
vce=vce,
cluster_var=cluster_var,
control_strategy=control_group_used,
estimator=estimator_lower,
transform_type=rolling_lower,
alpha=alpha,
propensity_controls=ps_controls_final,
trim_threshold=trim_threshold,
return_diagnostics=return_diagnostics,
n_neighbors=n_neighbors,
caliper=caliper,
with_replacement=with_replacement,
match_order=match_order,
)
att_by_cohort_time = pd.DataFrame([
{
'cohort': e.cohort,
'period': e.period,
'event_time': e.event_time,
'att': e.att,
'se': e.se,
'ci_lower': e.ci_lower,
'ci_upper': e.ci_upper,
't_stat': e.t_stat,
'pvalue': e.pvalue,
'n_treated': e.n_treated,
'n_control': e.n_control,
'n_total': e.n_total,
'df_resid': e.df_resid,
'df_inference': e.df_inference,
}
for e in cohort_time_effects
])
# Initialize aggregation variables for consistent scope across all code paths.
att_by_cohort = None
att_overall = None
se_overall = None
cohort_weights = {}
t_stat_overall = None
pvalue_overall = None
ci_overall = (None, None)
overall_effect = None
cohort_effects = [] # Populated when aggregate in ('cohort', 'overall').
if aggregate_lower in ('cohort', 'overall'):
cohort_effects = stag_agg.aggregate_to_cohort(
data_transformed=data_transformed,
gvar=gvar,
ivar=ivar,
tvar=tvar_str,
cohorts=cohorts,
T_max=T_max,
transform_type='demean' if rolling_lower in ('demean', 'demeanq') else 'detrend',
vce=vce,
cluster_var=cluster_var,
alpha=alpha,
)
att_by_cohort = pd.DataFrame([
{
'cohort': c.cohort,
'att': c.att,
'se': c.se,
'ci_lower': c.ci_lower,
'ci_upper': c.ci_upper,
't_stat': c.t_stat,
'pvalue': c.pvalue,
'n_periods': c.n_periods,
'n_units': c.n_units
}
for c in cohort_effects
])
if aggregate_lower == 'overall':
overall_effect = stag_agg.aggregate_to_overall(
data_transformed=data_transformed,
gvar=gvar,
ivar=ivar,
tvar=tvar_str,
transform_type='demean' if rolling_lower == 'demean' else 'detrend',
vce=vce,
cluster_var=cluster_var,
alpha=alpha,
)
att_overall = overall_effect.att
se_overall = overall_effect.se
cohort_weights = overall_effect.cohort_weights
t_stat_overall = overall_effect.t_stat
pvalue_overall = overall_effect.pvalue
ci_overall = (overall_effect.ci_lower, overall_effect.ci_upper)
# Build results object.
n_treated = int(sum(cohort_sizes.values()))
# Compute n_control based on actual control group strategy used.
# For 'never_treated': control group consists only of never-treated units.
# For 'not_yet_treated': control group varies by (g,r) and includes NT + NYT units.
if control_group_used == 'never_treated':
n_control = n_never_treated
else: # 'not_yet_treated'
# Extract maximum n_control from cohort-time effects.
# This reflects the largest control group actually used in estimation.
if len(cohort_time_effects) > 0:
n_control = max(e.n_control for e in cohort_time_effects)
else:
# Fallback: use never-treated count as control group size.
n_control = n_never_treated
# Compute df_resid and df_inference based on aggregation level.
# The degrees of freedom should reflect the underlying regression.
# Fallback df calculation accounts for number of control variables.
n_controls = len(controls) if controls else 0
estimator_for_df = estimator_lower if estimator_lower else 'ra'
if estimator_for_df == 'ra':
# RA: intercept + D + K controls + K interactions = 2 + 2K
df_fallback = n_treated + n_control - 2 - 2 * n_controls
elif estimator_for_df == 'ipwra':
# IPWRA: intercept + D + K controls = 2 + K
df_fallback = n_treated + n_control - 2 - n_controls
else:
# IPW/PSM: df = n - 2 (controls only affect PS model, not outcome).
df_fallback = n_treated + n_control - 2
df_fallback = max(1, df_fallback) # Ensure valid t-distribution quantiles.
if aggregate_lower == 'overall' and overall_effect is not None:
# For overall aggregation: use df from the overall regression
df_resid_val = overall_effect.df_resid
df_inference_val = overall_effect.df_inference
elif aggregate_lower == 'cohort' and cohort_effects:
# For cohort aggregation: use median df across cohort regressions.
# Keep only finite df values before taking the median, so cohorts whose estimation failed are ignored.
valid_df_resid = [c.df_resid for c in cohort_effects if np.isfinite(c.df_resid)]
valid_df_inference = [c.df_inference for c in cohort_effects if np.isfinite(c.df_inference)]
if valid_df_resid:
df_resid_val = int(np.median(valid_df_resid))
else:
# Use controls-aware fallback when no valid cohort effects exist.
df_resid_val = df_fallback
if valid_df_inference:
df_inference_val = int(np.median(valid_df_inference))
else:
# Use controls-aware fallback when no valid cohort effects exist.
df_inference_val = df_fallback
elif len(cohort_time_effects) > 0:
# For none or fallback: use median df across cohort-time regressions.
# Keep only finite df values before taking the median, so effects whose estimation failed are ignored.
valid_df_resid = [e.df_resid for e in cohort_time_effects if np.isfinite(e.df_resid)]
valid_df_inference = [e.df_inference for e in cohort_time_effects if np.isfinite(e.df_inference)]
if valid_df_resid:
df_resid_val = int(np.median(valid_df_resid))
else:
# Use controls-aware fallback when no valid effects exist.
df_resid_val = df_fallback
if valid_df_inference:
df_inference_val = int(np.median(valid_df_inference))
else:
# Use controls-aware fallback when no valid effects exist.
df_inference_val = df_fallback
else:
# Fallback to controls-aware formula when no effects available.
df_resid_val = df_fallback
df_inference_val = df_fallback
# Compute aggregated inference statistics for cohort-level aggregation.
# Weighted average ATT and SE across cohorts uses n_units as weights.
se_cohort_agg = None
att_cohort_agg = None
t_stat_cohort_agg = None
pvalue_cohort_agg = None
ci_cohort_agg = (None, None)
if aggregate_lower == 'cohort' and cohort_effects:
# Filter valid cohort effects with finite ATT and SE values.
valid_cohorts = [c for c in cohort_effects if np.isfinite(c.att) and np.isfinite(c.se)]
if valid_cohorts:
# Compute n_units-weighted average ATT across cohorts.
total_units = sum(c.n_units for c in valid_cohorts)
if total_units > 0:
weights = np.array([c.n_units / total_units for c in valid_cohorts])
atts = np.array([c.att for c in valid_cohorts])
ses = np.array([c.se for c in valid_cohorts])
# Weighted average ATT: τ̂ = Σ(w_g × τ̂_g)
att_cohort_agg = float(np.sum(weights * atts))
# SE via delta method assuming independent cohort estimates:
# Var(τ̂) = Σ(w_g² × Var(τ̂_g)) = Σ(w_g² × SE_g²)
# SE(τ̂) = sqrt(Σ(w_g² × SE_g²))
var_agg = np.sum(weights**2 * ses**2)
se_cohort_agg = float(np.sqrt(var_agg))
# Compute t-statistic and p-value using aggregated df.
if se_cohort_agg > 0 and df_inference_val > 0:
t_stat_cohort_agg = att_cohort_agg / se_cohort_agg
pvalue_cohort_agg = 2 * scipy.stats.t.sf(abs(t_stat_cohort_agg), df_inference_val)
# Confidence interval.
t_crit = scipy.stats.t.ppf(1 - alpha / 2, df_inference_val)
ci_cohort_agg = (
att_cohort_agg - t_crit * se_cohort_agg,
att_cohort_agg + t_crit * se_cohort_agg
)
else:
# All cohort-level ATT estimates are invalid (NaN or infinite).
n_total_cohorts = len(cohort_effects)
warnings.warn(
f"All {n_total_cohorts} cohort-level ATT estimates are invalid (NaN or infinite). "
f"Cohort-aggregated statistics (att_cohort_agg, se_cohort_agg, etc.) remain None. "
f"Possible causes: insufficient sample size, numerical instability, or data quality issues.",
UserWarning,
stacklevel=3
)
# Compute fallback ATT from cohort-time effects if no higher-level aggregation.
# This is used when aggregate='none' or when aggregation fails.
att_cohort_time_fallback = None
if att_overall is None and att_cohort_agg is None:
if len(att_by_cohort_time) > 0 and att_by_cohort_time['att'].notna().any():
valid_df = att_by_cohort_time.loc[att_by_cohort_time['att'].notna()]
weights_sum = valid_df['n_treated'].sum()
# ATT = E[Y(1) - Y(0) | D=1] requires D=1 observations.
# If total n_treated is zero across all cohort-time effects, ATT is not identifiable.
if weights_sum <= 0:
raise ValueError(
"Cannot compute weighted average ATT: total n_treated across all "
"cohort-time effects is zero. ATT estimation requires at least one "
"treated unit with valid outcome. This may occur due to:\n"
" 1. Propensity score trimming excluded all treated units\n"
" 2. Missing outcome values for all treated units\n"
" 3. Data quality issues in treatment indicator or outcome variable\n"
"Check data quality or adjust trim_threshold parameter."
)
att_cohort_time_fallback = float(np.average(valid_df['att'], weights=valid_df['n_treated']))
# =========================================================================
# Pre-treatment Dynamics Estimation (when include_pretreatment=True)
# =========================================================================
att_pre_treatment_df = None
parallel_trends_result = None
if include_pretreatment:
from .staggered.transformations_pre import (
transform_staggered_demean_pre,
transform_staggered_detrend_pre,
)
from .staggered.estimation_pre import (
estimate_pre_treatment_effects,
pre_treatment_effects_to_dataframe,
)
from .staggered.parallel_trends import run_parallel_trends_test
# Apply pre-treatment transformation
try:
if rolling_lower in ('demean', 'demeanq'):
data_pre_transformed = transform_staggered_demean_pre(
data=data_transformed,
y=y,
ivar=ivar,
tvar=tvar_str,
gvar=gvar,
never_treated_values=[0, np.inf],
)
pre_transform_type = 'demean'
else: # detrend, detrendq
data_pre_transformed = transform_staggered_detrend_pre(
data=data_transformed,
y=y,
ivar=ivar,
tvar=tvar_str,
gvar=gvar,
never_treated_values=[0, np.inf],
)
pre_transform_type = 'detrend'
# Estimate pre-treatment effects
pre_treatment_effects = estimate_pre_treatment_effects(
data_transformed=data_pre_transformed,
gvar=gvar,
ivar=ivar,
tvar=tvar_str,
controls=controls,
vce=vce,
cluster_var=cluster_var,
control_strategy=control_group_used,
never_treated_values=[0, np.inf],
alpha=pretreatment_alpha,
estimator=estimator_lower,
transform_type=pre_transform_type,
propensity_controls=ps_controls_final,
trim_threshold=trim_threshold,
)
# Convert to DataFrame
att_pre_treatment_df = pre_treatment_effects_to_dataframe(pre_treatment_effects)
# Run parallel trends test if requested
if pretreatment_test and len(pre_treatment_effects) > 0:
parallel_trends_result = run_parallel_trends_test(
pre_treatment_effects=pre_treatment_effects,
alpha=pretreatment_alpha,
test_type='f',
min_pre_periods=2,
)
# Update data_transformed with pre-treatment columns for visualization
data_transformed = data_pre_transformed
except Exception as e:
warnings.warn(
f"Pre-treatment dynamics estimation failed: {e}. "
f"Continuing without pre-treatment effects.",
UserWarning,
stacklevel=3
)
results_dict = {
'is_staggered': True,
'cohorts': cohorts,
'cohort_sizes': cohort_sizes,
'att_by_cohort_time': att_by_cohort_time,
'att_by_cohort': att_by_cohort,
'att_overall': att_overall,
'se_overall': se_overall,
'cohort_weights': cohort_weights,
'ci_overall_lower': ci_overall[0],
'ci_overall_upper': ci_overall[1],
't_stat_overall': t_stat_overall,
'pvalue_overall': pvalue_overall,
'n_treated': n_treated,
'n_control': n_control,
'nobs': n_treated + n_control,
'control_group': control_group,
'control_group_used': control_group_used,
'aggregate': aggregate,
'estimator': estimator,
'rolling': rolling,
'n_never_treated': n_never_treated,
'alpha': alpha,
'never_treated_values': [0, np.inf],
# Compute ATT with fallback hierarchy:
# 1. att_overall (from aggregate='overall')
# 2. att_cohort_agg (weighted average of cohort effects for aggregate='cohort')
# 3. att_cohort_time_fallback (n_treated-weighted average across cohort-time effects)
# Return None if no valid estimates exist (all NaN) to maintain type consistency
# (None = "no estimate" vs np.nan = "undefined/failed estimate").
'att': (
att_overall if att_overall is not None
else att_cohort_agg if att_cohort_agg is not None
else att_cohort_time_fallback
),
# SE with fallback: se_overall (overall aggregation) -> se_cohort_agg (cohort aggregation)
'se_att': (
se_overall if se_overall is not None
else se_cohort_agg if se_cohort_agg is not None
else np.nan
),
't_stat': (
t_stat_overall if t_stat_overall is not None
else t_stat_cohort_agg if t_stat_cohort_agg is not None
else np.nan
),
'pvalue': (
pvalue_overall if pvalue_overall is not None
else pvalue_cohort_agg if pvalue_cohort_agg is not None
else np.nan
),
'ci_lower': (
ci_overall[0] if ci_overall[0] is not None
else ci_cohort_agg[0] if ci_cohort_agg[0] is not None
else np.nan
),
'ci_upper': (
ci_overall[1] if ci_overall[1] is not None
else ci_cohort_agg[1] if ci_cohort_agg[1] is not None
else np.nan
),
'df_resid': df_resid_val,
'df_inference': df_inference_val,
'vce_type': vce if vce else 'ols',
'params': None,
'bse': None,
'vcov': None,
'resid': None,
# Pre-treatment dynamics
'att_pre_treatment': att_pre_treatment_df,
'parallel_trends_test': parallel_trends_result,
'include_pretreatment': include_pretreatment,
}
metadata = {
'is_staggered': True,
'rolling': rolling,
'control_group': control_group,
'control_group_used': control_group_used,
'aggregate': aggregate,
'estimator': estimator,
'cohorts': cohorts,
'T_max': T_max,
'T_min': T_min,
'has_never_treated': has_never_treated,
'n_never_treated': n_never_treated,
'n_cohorts': len(cohorts),
'vce': vce,
'depvar': y,
'K': 0,
'tpost1': cohorts[0] if cohorts else 0,
'N_treated': n_treated,
'N_control': n_control,
'alpha': alpha,
'ivar': ivar,
'gvar': gvar,
'tvar': tvar,
}
results = LWDIDResults(
results_dict, metadata, att_by_cohort_time,
cohort_time_effects=cohort_time_effects,
)
results.data = data_transformed
# Randomization inference
if ri:
from .staggered.randomization import randomization_inference_staggered
# Handle riseed parameter with type validation.
# Invalid input raises explicit error rather than silently using random seed.
if seed is None and 'riseed' in kwargs:
riseed_val = kwargs['riseed']
if riseed_val is not None:
if isinstance(riseed_val, (int, np.integer)):
seed = int(riseed_val)
elif isinstance(riseed_val, str):
try:
seed = int(riseed_val)
except ValueError:
raise TypeError(
f"riseed must be an integer or integer-convertible string, "
f"got '{riseed_val}'. Use seed parameter for explicit control."
)
else:
raise TypeError(
f"riseed must be an integer, got {type(riseed_val).__name__}. "
f"Use seed parameter for explicit control."
)
# Determine target statistic based on aggregation level.
# For each aggregation type, select the first valid (non-NaN) ATT estimate
# to ensure randomization inference uses a meaningful observed statistic.
if aggregate_lower == 'overall' and att_overall is not None:
ri_target = 'overall'
ri_observed = att_overall
target_cohort_ri = None
target_period_ri = None
elif aggregate_lower == 'cohort' and att_by_cohort is not None and len(att_by_cohort) > 0:
# Select the first cohort with a valid (non-NaN) ATT estimate
valid_cohorts = att_by_cohort[att_by_cohort['att'].notna()]
if len(valid_cohorts) > 0:
ri_target = 'cohort'
ri_observed = valid_cohorts.iloc[0]['att']
target_cohort_ri = int(valid_cohorts.iloc[0]['cohort'])
target_period_ri = None
else:
warnings.warn(
"No valid cohort ATT estimates available for randomization inference. "
"All cohort-level ATT values are NaN.",
UserWarning,
stacklevel=3
)
# Set RI attributes to NaN instead of returning incomplete results.
results.ri_pvalue = np.nan
results.rireps = rireps
results.ri_seed = seed if seed is not None else _generate_ri_seed()
results.ri_method = ri_method
results.ri_valid = 0
results.ri_failed = -1
results.ri_error = "No valid cohort ATT estimates available"
results.ri_target = 'cohort'
return results
else:
ri_target = 'cohort_time'
if len(cohort_time_effects) > 0:
# Select the first cohort-time effect with a valid (non-NaN) ATT
valid_effects = [e for e in cohort_time_effects if pd.notna(e.att)]
if len(valid_effects) > 0:
first_effect = valid_effects[0]
ri_observed = first_effect.att
target_cohort_ri = first_effect.cohort
target_period_ri = first_effect.period
else:
warnings.warn(
"No valid cohort-time ATT estimates available for randomization inference. "
"All cohort-time ATT values are NaN.",
UserWarning,
stacklevel=3
)
# Set RI attributes to NaN instead of returning incomplete results.
results.ri_pvalue = np.nan
results.rireps = rireps
results.ri_seed = seed if seed is not None else _generate_ri_seed()
results.ri_method = ri_method
results.ri_valid = 0
results.ri_failed = -1
results.ri_error = "No valid cohort-time ATT estimates available"
results.ri_target = 'cohort_time'
return results
else:
warnings.warn(
"No available effect estimates, skipping randomization inference.",
UserWarning,
stacklevel=3
)
# Set RI attributes to NaN instead of returning incomplete results.
results.ri_pvalue = np.nan
results.rireps = rireps
results.ri_seed = seed if seed is not None else _generate_ri_seed()
results.ri_method = ri_method
results.ri_valid = 0
results.ri_failed = -1
results.ri_error = "No available effect estimates"
results.ri_target = 'cohort_time'
return results
actual_seed = seed if seed is not None else _generate_ri_seed()
try:
ri_result = randomization_inference_staggered(
data=data,
gvar=gvar,
ivar=ivar,
tvar=tvar_str,
y=y,
cohorts=cohorts,
observed_att=ri_observed,
target=ri_target,
target_cohort=target_cohort_ri,
target_period=target_period_ri,
ri_method=ri_method,
rireps=rireps,
seed=actual_seed,
rolling=rolling,
controls=controls,
vce=vce,
cluster_var=cluster_var,
n_never_treated=n_never_treated,
)
results.ri_pvalue = ri_result.p_value
results.rireps = rireps
results.ri_seed = actual_seed
results.ri_method = ri_result.ri_method
results.ri_valid = ri_result.ri_valid
results.ri_failed = ri_result.ri_failed
results.ri_target = ri_target
except (RandomizationError, ValueError, np.linalg.LinAlgError) as e:
# RandomizationError: RI-specific failures (insufficient data, invalid params)
# ValueError: Data issues during resampling
# LinAlgError: Singular matrix in permuted regressions
warnings.warn(
f"Randomization inference failed: {type(e).__name__}: {e}",
UserWarning,
stacklevel=3
)
results.ri_pvalue = np.nan
results.ri_seed = actual_seed
results.rireps = rireps
results.ri_method = ri_method
results.ri_valid = 0
results.ri_failed = -1
results.ri_error = str(e)
results.ri_target = ri_target
return results