Source code for lwdid.sensitivity

"""
Sensitivity analysis for difference-in-differences estimation.

This module provides tools to assess the robustness of ATT estimates under
varying methodological choices and potential assumption violations. Three
types of sensitivity analysis are supported: pre-treatment period selection
(testing stability across different baseline period configurations),
no-anticipation assumption testing (evaluating robustness by excluding
periods immediately before treatment), and comprehensive analysis combining
multiple robustness checks including transformation method and estimator
comparisons.

Notes
-----
Results are classified into robustness levels based on the sensitivity ratio,
defined as the range of ATT estimates across specifications divided by the
absolute value of the baseline estimate. Thresholds for classification are:
highly robust (< 10%), moderately robust (10-25%), sensitive (25-50%), and
highly sensitive (>= 50%).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import warnings

import numpy as np
import pandas as pd
from scipy import stats


# =============================================================================
# Enumerations
# =============================================================================

class RobustnessLevel(Enum):
    """
    Stability category assigned to a set of ATT estimates.

    The category is chosen from the sensitivity ratio, i.e. the spread of
    ATT estimates across specifications relative to the magnitude of the
    baseline estimate.

    Attributes
    ----------
    HIGHLY_ROBUST : str
        Sensitivity ratio below 10%. Estimates are very stable.
    MODERATELY_ROBUST : str
        Sensitivity ratio between 10% and 25%. Estimates show minor variation.
    SENSITIVE : str
        Sensitivity ratio between 25% and 50%. Estimates vary noticeably.
    HIGHLY_SENSITIVE : str
        Sensitivity ratio at or above 50%. Estimates are unstable.
    """

    # Ordered from most to least stable.
    HIGHLY_ROBUST = "highly_robust"
    MODERATELY_ROBUST = "moderately_robust"
    SENSITIVE = "sensitive"
    HIGHLY_SENSITIVE = "highly_sensitive"
class AnticipationDetectionMethod(Enum):
    """
    How (or whether) a potential anticipation effect was identified.

    An anticipation effect arises when units change behavior before treatment
    formally starts, which violates the no-anticipation assumption.

    Attributes
    ----------
    TREND_BREAK : str
        Detected via structural break in pre-treatment trend.
    COEFFICIENT_CHANGE : str
        Detected via significant change in ATT when excluding periods.
    PLACEBO_TEST : str
        Detected via significant placebo effects in pre-treatment periods.
    NONE_DETECTED : str
        No anticipation effects identified by any method.
    INSUFFICIENT_DATA : str
        Insufficient pre-treatment periods to perform detection.
    """

    # Positive detections.
    TREND_BREAK = "trend_break"
    COEFFICIENT_CHANGE = "coefficient_change"
    PLACEBO_TEST = "placebo_test"

    # Non-detections.
    NONE_DETECTED = "none_detected"
    INSUFFICIENT_DATA = "insufficient_data"
# =============================================================================
# Data Classes
# =============================================================================
@dataclass
class SpecificationResult:
    """
    Estimation output for one specification of the sensitivity analysis.

    Each instance is a single point in the analysis: the ATT estimate and its
    inference statistics for one configuration of pre-treatment periods.

    Attributes
    ----------
    specification_id : int
        Unique identifier for this specification.
    n_pre_periods : int
        Number of pre-treatment periods used.
    start_period : int
        First pre-treatment period included.
    end_period : int
        Last pre-treatment period included.
    excluded_periods : int
        Number of periods excluded before treatment.
    att : float
        Average treatment effect on the treated.
    se : float
        Standard error of ATT.
    t_stat : float
        t-statistic for H0: ATT=0.
    pvalue : float
        Two-sided p-value.
    ci_lower : float
        Lower bound of confidence interval.
    ci_upper : float
        Upper bound of confidence interval.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    df : int
        Degrees of freedom for inference.
    converged : bool
        Whether estimation converged successfully.
    spec_warnings : list[str]
        Warning messages from estimation.
    """

    specification_id: int
    n_pre_periods: int
    start_period: int
    end_period: int
    excluded_periods: int
    att: float
    se: float
    t_stat: float
    pvalue: float
    ci_lower: float
    ci_upper: float
    n_treated: int
    n_control: int
    df: int
    converged: bool = True
    spec_warnings: list[str] = field(default_factory=list)

    @property
    def is_significant_05(self) -> bool:
        """True when the p-value is below 0.05."""
        return self.pvalue < 0.05

    @property
    def is_significant_10(self) -> bool:
        """True when the p-value is below 0.10."""
        return self.pvalue < 0.10

    def to_dict(self) -> dict:
        """
        Flatten this result into a dictionary usable as a DataFrame row.

        Returns
        -------
        dict
            Identifier (under ``'spec_id'``), period configuration, estimate
            and inference statistics, sample sizes, the 5% significance flag
            and the convergence flag. ``spec_warnings`` is not included.
        """
        # Attributes copied under their own name, in output order.
        copied_as_is = (
            'n_pre_periods',
            'start_period',
            'end_period',
            'excluded_periods',
            'att',
            'se',
            't_stat',
            'pvalue',
            'ci_lower',
            'ci_upper',
            'n_treated',
            'n_control',
            'df',
        )
        row = {'spec_id': self.specification_id}
        row.update((name, getattr(self, name)) for name in copied_as_is)
        row['significant_05'] = self.is_significant_05
        row['converged'] = self.converged
        return row
@dataclass
class AnticipationEstimate:
    """
    ATT estimate obtained under one anticipation-exclusion setting.

    Attributes
    ----------
    excluded_periods : int
        Number of periods excluded before treatment.
    att : float
        Average treatment effect on the treated.
    se : float
        Standard error of ATT.
    t_stat : float
        t-statistic for H0: ATT=0.
    pvalue : float
        Two-sided p-value.
    ci_lower : float
        Lower bound of confidence interval.
    ci_upper : float
        Upper bound of confidence interval.
    n_pre_periods_used : int
        Number of pre-treatment periods actually used.
    """

    excluded_periods: int
    att: float
    se: float
    t_stat: float
    pvalue: float
    ci_lower: float
    ci_upper: float
    n_pre_periods_used: int

    @property
    def is_significant(self) -> bool:
        """True when the p-value is below 0.05."""
        return self.pvalue < 0.05

    def to_dict(self) -> dict:
        """
        Flatten this estimate into a dictionary usable as a DataFrame row.

        Returns
        -------
        dict
            Every stored attribute under its own name, followed by the 5%
            significance flag under ``'significant'``.
        """
        stored = (
            'excluded_periods',
            'att',
            'se',
            't_stat',
            'pvalue',
            'ci_lower',
            'ci_upper',
            'n_pre_periods_used',
        )
        row = {name: getattr(self, name) for name in stored}
        row['significant'] = self.is_significant
        return row
@dataclass
class PrePeriodRobustnessResult:
    """
    Outcome of the pre-treatment period robustness analysis.

    Records how the ATT estimate moves as the number of pre-treatment periods
    changes, so that a reader can judge whether the finding depends on that
    methodological choice.

    Attributes
    ----------
    specifications : list[SpecificationResult]
        ATT estimates for each pre-period configuration.
    baseline_spec : SpecificationResult
        Estimate using all available pre-treatment periods.
    att_range : tuple[float, float]
        (min ATT, max ATT) across all specifications.
    att_mean : float
        Mean ATT across specifications.
    att_std : float
        Standard deviation of ATT across specifications.
    sensitivity_ratio : float
        Ratio of range to baseline: (max - min) / abs(baseline).
    robustness_level : RobustnessLevel
        Categorical assessment of robustness.
    is_robust : bool
        Whether estimates are stable (ratio < threshold).
    robustness_threshold : float
        Threshold used for robustness determination.
    all_same_sign : bool
        Whether all estimates have the same sign.
    all_significant : bool
        Whether all estimates are significant at 5%.
    n_significant : int
        Number of significant specifications.
    n_sign_changes : int
        Number of specifications with sign different from baseline.
    rolling_method : str
        Transformation method used.
    estimator : str
        Estimation method used.
    n_specifications : int
        Total number of specifications tested.
    pre_period_range_tested : tuple[int, int]
        Range of pre-periods tested (min, max).
    recommendation : str
        Main recommendation based on analysis.
    detailed_recommendations : list[str]
        Detailed recommendations.
    result_warnings : list[str]
        Warning messages.
    figure : Any | None
        Matplotlib figure if plot was generated.
    """

    specifications: list[SpecificationResult]
    baseline_spec: SpecificationResult
    att_range: tuple[float, float]
    att_mean: float
    att_std: float
    sensitivity_ratio: float
    robustness_level: RobustnessLevel
    is_robust: bool
    robustness_threshold: float
    all_same_sign: bool
    all_significant: bool
    n_significant: int
    n_sign_changes: int
    rolling_method: str
    estimator: str
    n_specifications: int
    pre_period_range_tested: tuple[int, int]
    recommendation: str
    detailed_recommendations: list[str] = field(default_factory=list)
    result_warnings: list[str] = field(default_factory=list)
    figure: Any | None = None

    def to_dataframe(self) -> pd.DataFrame:
        """
        Tabulate every specification result.

        Returns
        -------
        pd.DataFrame
            One row per specification, built from each specification's
            ``to_dict()`` output.
        """
        rows = [spec.to_dict() for spec in self.specifications]
        return pd.DataFrame(rows)

    def get_specification(self, n_pre: int) -> SpecificationResult | None:
        """
        Look up the specification that used ``n_pre`` pre-treatment periods.

        Parameters
        ----------
        n_pre : int
            Number of pre-treatment periods to look up.

        Returns
        -------
        SpecificationResult or None
            The first stored specification whose ``n_pre_periods`` equals
            ``n_pre``; None when there is no match.
        """
        matches = (s for s in self.specifications if s.n_pre_periods == n_pre)
        return next(matches, None)

    def summary(self) -> str:
        """
        Render the analysis as a multi-line text report.

        Returns
        -------
        str
            Report with configuration, baseline estimate, sensitivity
            metrics, robustness assessment, a per-specification table
            (sorted by pre-period count), recommendations and warnings.
        """
        def stars(p: float) -> str:
            # Conventional significance markers at the 1%, 5% and 10% levels.
            if p < 0.01:
                return "***"
            if p < 0.05:
                return "**"
            if p < 0.1:
                return "*"
            return ""

        heavy_rule = "=" * 75
        table_rule = "-" * 70
        base = self.baseline_spec

        level_label = self.robustness_level.value.replace('_', ' ').title()
        robust_flag = 'YES ✓' if self.is_robust else 'NO ⚠️'
        sign_flag = 'YES ✓' if self.all_same_sign else 'NO ⚠️'
        if self.all_significant:
            signif_flag = 'YES ✓'
        else:
            signif_flag = f'NO ({self.n_significant}/{self.n_specifications})'

        report = [
            heavy_rule,
            "PRE-TREATMENT PERIOD ROBUSTNESS ANALYSIS",
            heavy_rule,
            "",
            "CONFIGURATION:",
            f" Transformation: {self.rolling_method}",
            f" Estimator: {self.estimator}",
            f" Pre-period range tested: {self.pre_period_range_tested[0]} - {self.pre_period_range_tested[1]}",
            f" Number of specifications: {self.n_specifications}",
            "",
            "BASELINE ESTIMATE (all pre-periods):",
            f" ATT = {base.att:.4f} (SE = {base.se:.4f})",
            f" t-stat = {base.t_stat:.3f}, p-value = {base.pvalue:.4f}",
            f" 95% CI: [{base.ci_lower:.4f}, {base.ci_upper:.4f}]",
            "",
            "SENSITIVITY ANALYSIS:",
            f" ATT Range: [{self.att_range[0]:.4f}, {self.att_range[1]:.4f}]",
            f" ATT Mean: {self.att_mean:.4f}",
            f" ATT Std Dev: {self.att_std:.4f}",
            f" Sensitivity Ratio: {self.sensitivity_ratio:.1%}",
            "",
            "ROBUSTNESS ASSESSMENT:",
            f" Level: {level_label}",
            f" Is Robust (ratio < {self.robustness_threshold:.0%}): {robust_flag}",
            f" All Same Sign: {sign_flag}",
            f" All Significant: {signif_flag}",
            "",
            # Specification table header
            "SPECIFICATION DETAILS:",
            table_rule,
            f"{'N_Pre':>8} {'ATT':>12} {'SE':>10} {'P-value':>10} {'Sig':>6}",
            table_rule,
        ]

        ordered = sorted(self.specifications, key=lambda item: item.n_pre_periods)
        for row in ordered:
            marker = stars(row.pvalue)
            is_baseline = row.specification_id == base.specification_id
            suffix = " (baseline)" if is_baseline else ""
            report.append(
                f"{row.n_pre_periods:>8} {row.att:>12.4f} {row.se:>10.4f} "
                f"{row.pvalue:>10.4f} {marker:>6}{suffix}"
            )

        report += ["", "─" * 75, "RECOMMENDATION:", f" {self.recommendation}"]

        if self.detailed_recommendations:
            report += ["", "DETAILED RECOMMENDATIONS:"]
            report += [
                f" {number}. {text}"
                for number, text in enumerate(self.detailed_recommendations, 1)
            ]

        if self.result_warnings:
            report += ["", "WARNINGS:"]
            report += [f" ⚠ {message}" for message in self.result_warnings]

        report.append(heavy_rule)
        return "\n".join(report)

    def plot(
        self,
        show_ci: bool = True,
        show_baseline: bool = True,
        figsize: tuple[float, float] = (10, 6),
        ax: Any = None,
    ) -> Any:
        """
        Draw ATT estimates against the number of pre-treatment periods.

        Parameters
        ----------
        show_ci : bool, default True
            Whether to draw error bars (1.96 times the standard error).
        show_baseline : bool, default True
            Whether to show baseline reference line.
        figsize : tuple, default (10, 6)
            Figure size in inches. Used only when a new figure is created.
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.

        Returns
        -------
        matplotlib.figure.Figure
            The figure drawn on; also stored in ``self.figure``.
        """
        import matplotlib.pyplot as plt

        if ax is not None:
            fig = ax.get_figure()
        else:
            fig, ax = plt.subplots(figsize=figsize)

        table = self.to_dataframe().sort_values('n_pre_periods')
        x_vals = table['n_pre_periods']
        y_vals = table['att']
        series_style = {
            'label': 'ATT Estimate',
            'color': 'steelblue',
            'markersize': 8,
        }

        # Estimates, with or without error bars
        if show_ci:
            ax.errorbar(
                x_vals,
                y_vals,
                yerr=1.96 * table['se'],
                fmt='o-',
                capsize=4,
                **series_style,
            )
        else:
            ax.plot(x_vals, y_vals, 'o-', **series_style)

        baseline_att = self.baseline_spec.att

        # Baseline reference line
        if show_baseline:
            ax.axhline(
                baseline_att,
                color='red',
                linestyle='--',
                alpha=0.7,
                label=f'Baseline ATT = {baseline_att:.3f}',
            )

        # Robustness band (±25% of baseline)
        # NOTE(review): the nesting of this section was not recoverable from
        # the flattened source; it is treated as independent of
        # ``show_baseline`` like the other commented sections — confirm.
        half_width = 0.25 * abs(baseline_att)
        if half_width > 0:
            ax.axhspan(
                baseline_att - half_width,
                baseline_att + half_width,
                alpha=0.1,
                color='green',
                label='±25% Robustness Band',
            )

        # Zero reference line
        ax.axhline(0, color='gray', linestyle='-', alpha=0.3)

        # Labels and title
        level_label = self.robustness_level.value.replace("_", " ").title()
        ax.set_xlabel('Number of Pre-treatment Periods', fontsize=12)
        ax.set_ylabel('ATT Estimate', fontsize=12)
        ax.set_title(
            f'Pre-treatment Period Robustness Analysis\n'
            f'Sensitivity Ratio: {self.sensitivity_ratio:.1%} '
            f'({level_label})',
            fontsize=14,
        )
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        self.figure = fig
        return fig
@dataclass
class NoAnticipationSensitivityResult:
    """
    Result of no-anticipation sensitivity analysis.

    Tests robustness of ATT estimates to potential anticipation effects by
    excluding periods immediately before treatment.

    Attributes
    ----------
    estimates : list[AnticipationEstimate]
        ATT estimates for each exclusion configuration.
    baseline_estimate : AnticipationEstimate
        Estimate with no exclusion (excluded_periods=0).
    anticipation_detected : bool
        Whether anticipation effects are detected.
    recommended_exclusion : int
        Recommended number of periods to exclude.
    detection_method : AnticipationDetectionMethod
        Method used to detect anticipation.
    recommendation : str
        Interpretation and recommendations.
    result_warnings : list[str]
        Warning messages.
    figure : Any | None
        Matplotlib figure if plot was generated.
    """

    estimates: list[AnticipationEstimate]
    baseline_estimate: AnticipationEstimate
    anticipation_detected: bool
    recommended_exclusion: int
    detection_method: AnticipationDetectionMethod
    recommendation: str
    result_warnings: list[str] = field(default_factory=list)
    figure: Any | None = None

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert all anticipation estimates to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            One row per exclusion level, built from each estimate's
            ``to_dict()`` output.
        """
        return pd.DataFrame([e.to_dict() for e in self.estimates])

    def summary(self) -> str:
        """
        Generate a human-readable summary of the anticipation analysis.

        Returns
        -------
        str
            Formatted text report containing estimates by exclusion level,
            detection results, recommendation and warnings. When
            ``estimates`` is empty the tested range is reported as
            ``0 - 0`` and the table has no rows.
        """
        # default=0 keeps the report printable when no estimate is available;
        # a bare max() over an empty sequence raises ValueError.
        max_excluded = max(
            (e.excluded_periods for e in self.estimates), default=0
        )

        lines = [
            "=" * 70,
            "NO-ANTICIPATION SENSITIVITY ANALYSIS",
            "=" * 70,
            "",
            f"Exclusion range tested: 0 - {max_excluded}",
            "",
            "Estimates by Exclusion:",
            "-" * 60,
            f"{'Excluded':>10} {'ATT':>12} {'SE':>10} {'P-value':>10} {'Sig':>6}",
            "-" * 60,
        ]

        for e in self.estimates:
            # Conventional markers at the 1%, 5% and 10% levels.
            if e.pvalue < 0.01:
                sig = "***"
            elif e.pvalue < 0.05:
                sig = "**"
            elif e.pvalue < 0.1:
                sig = "*"
            else:
                sig = ""
            lines.append(
                f"{e.excluded_periods:>10} {e.att:>12.4f} {e.se:>10.4f} "
                f"{e.pvalue:>10.4f} {sig:>6}"
            )

        lines.extend([
            "",
            f"Anticipation Detected: {'YES ⚠️' if self.anticipation_detected else 'NO ✓'}",
            f"Detection Method: {self.detection_method.value}",
        ])

        # The recommended exclusion is only meaningful after a detection.
        if self.anticipation_detected:
            lines.append(
                f"Recommended Exclusion: {self.recommended_exclusion} period(s)"
            )

        lines.extend([
            "",
            "─" * 70,
            f"RECOMMENDATION: {self.recommendation}",
            "─" * 70,
        ])

        if self.result_warnings:
            lines.extend(["", "WARNINGS:"])
            lines.extend(f" ⚠ {w}" for w in self.result_warnings)

        lines.append("=" * 70)
        return "\n".join(lines)

    def plot(
        self,
        show_ci: bool = True,
        figsize: tuple[float, float] = (10, 6),
        ax: Any = None,
    ) -> Any:
        """
        Generate anticipation sensitivity plot.

        Parameters
        ----------
        show_ci : bool, default True
            Whether to draw error bars (1.96 times the standard error).
        figsize : tuple, default (10, 6)
            Figure size in inches. Used only when a new figure is created.
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.

        Returns
        -------
        matplotlib.figure.Figure
            The figure drawn on; also stored in ``self.figure``.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        df = self.to_dataframe().sort_values('excluded_periods')

        # Plot estimates
        if show_ci:
            ax.errorbar(
                df['excluded_periods'],
                df['att'],
                yerr=1.96 * df['se'],
                fmt='o-',
                capsize=4,
                color='steelblue',
                markersize=8,
                label='ATT Estimate',
            )
        else:
            ax.plot(
                df['excluded_periods'],
                df['att'],
                'o-',
                color='steelblue',
                markersize=8,
                label='ATT Estimate',
            )

        # Highlight recommended exclusion
        if self.anticipation_detected and self.recommended_exclusion > 0:
            rec_est = next(
                (
                    e for e in self.estimates
                    if e.excluded_periods == self.recommended_exclusion
                ),
                None,
            )
            if rec_est:
                ax.scatter(
                    [self.recommended_exclusion],
                    [rec_est.att],
                    s=200,
                    facecolors='none',
                    edgecolors='red',
                    linewidths=2,
                    label=f'Recommended (k={self.recommended_exclusion})',
                    zorder=5,
                )

        # Zero line
        ax.axhline(0, color='gray', linestyle='-', alpha=0.3)

        # Labels
        ax.set_xlabel('Number of Excluded Periods Before Treatment', fontsize=12)
        ax.set_ylabel('ATT Estimate', fontsize=12)
        ax.set_title(
            f'No-Anticipation Sensitivity Analysis\n'
            f'Anticipation Detected: {"Yes" if self.anticipation_detected else "No"}',
            fontsize=14,
        )
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)
        # One tick per tested exclusion level.
        ax.set_xticks(df['excluded_periods'].values)

        plt.tight_layout()
        self.figure = fig
        return fig
@dataclass
class ComprehensiveSensitivityResult:
    """
    Combined results from comprehensive sensitivity analysis.

    Attributes
    ----------
    pre_period_result : PrePeriodRobustnessResult | None
        Results from pre-period robustness analysis.
    anticipation_result : NoAnticipationSensitivityResult | None
        Results from no-anticipation sensitivity analysis.
    transformation_comparison : dict | None
        Comparison of demean vs detrend results. ``summary`` reads the keys
        ``'demean_att'``, ``'detrend_att'`` and ``'difference'``.
    estimator_comparison : dict | None
        Comparison across different estimators, keyed by estimator name.
        The key ``'range'`` is skipped by ``summary``.
    overall_assessment : str
        Overall robustness assessment.
    recommendations : list[str]
        List of recommendations.
    """

    pre_period_result: PrePeriodRobustnessResult | None = None
    anticipation_result: NoAnticipationSensitivityResult | None = None
    transformation_comparison: dict | None = None
    estimator_comparison: dict | None = None
    overall_assessment: str = ""
    recommendations: list[str] = field(default_factory=list)

    @staticmethod
    def _format_att(value: Any) -> str:
        """Format a number with 4 decimals; fall back to text for non-numbers."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            # None raises TypeError and str raises ValueError for the 'f'
            # format code; report those as text instead of crashing.
            return "N/A" if value is None else str(value)

    def summary(self) -> str:
        """
        Generate a comprehensive summary of all sensitivity analyses.

        Sections are emitted only for the analyses that are present. Missing
        or non-numeric comparison values are printed as ``N/A`` (or their
        text form) rather than raising a formatting error.

        Returns
        -------
        str
            Formatted text report containing results from pre-period
            robustness, anticipation sensitivity, transformation comparison,
            estimator comparison, overall assessment, and recommendations.
        """
        fmt = self._format_att

        lines = [
            "=" * 70,
            "COMPREHENSIVE SENSITIVITY ANALYSIS",
            "=" * 70,
            "",
        ]

        if self.pre_period_result:
            lines.extend([
                "1. Pre-treatment Period Robustness:",
                f" Robust: {'YES' if self.pre_period_result.is_robust else 'NO'}",
                f" Sensitivity Ratio: {self.pre_period_result.sensitivity_ratio:.2%}",
                "",
            ])

        if self.anticipation_result:
            lines.extend([
                "2. No-Anticipation Sensitivity:",
                f" Anticipation Detected: {'YES' if self.anticipation_result.anticipation_detected else 'NO'}",
                "",
            ])

        if self.transformation_comparison:
            comparison = self.transformation_comparison
            lines.extend([
                "3. Transformation Comparison (demean vs detrend):",
                f" Demean ATT: {fmt(comparison.get('demean_att'))}",
                f" Detrend ATT: {fmt(comparison.get('detrend_att'))}",
                f" Difference: {fmt(comparison.get('difference'))}",
                "",
            ])

        if self.estimator_comparison:
            lines.append("4. Estimator Comparison:")
            for est, att in self.estimator_comparison.items():
                # 'range' is a summary statistic, not an estimator.
                if est != 'range':
                    lines.append(f" {est.upper()}: {fmt(att)}")
            lines.append("")

        lines.extend([
            "─" * 70,
            f"OVERALL ASSESSMENT: {self.overall_assessment}",
            "",
            "RECOMMENDATIONS:",
        ])
        lines.extend(
            f" {i}. {rec}" for i, rec in enumerate(self.recommendations, 1)
        )

        lines.append("=" * 70)
        return "\n".join(lines)

    def plot_all(self, figsize: tuple[float, float] = (14, 10)) -> Any:
        """
        Generate combined visualization of all sensitivity analyses.

        Parameters
        ----------
        figsize : tuple of float, default (14, 10)
            Figure size in inches (width, height).

        Returns
        -------
        matplotlib.figure.Figure or None
            Combined figure with one subplot per available analysis, laid
            out in a single row, or None (after a warning) if no results are
            available to plot.
        """
        import matplotlib.pyplot as plt

        # Use the same "is not None" test for counting and for plotting so
        # the number of axes always matches the number of panels drawn.
        panels = [
            result
            for result in (self.pre_period_result, self.anticipation_result)
            if result is not None
        ]

        if not panels:
            warnings.warn("No results to plot")
            return None

        # squeeze=False always yields a 2-D array, so a single panel needs
        # no special-casing.
        fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
        for panel, axis in zip(panels, axes[0]):
            panel.plot(ax=axis)

        plt.tight_layout()
        return fig
# =============================================================================
# Helper Functions
# =============================================================================

def _validate_robustness_inputs(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None,
    d: str | None,
    post: str | None,
    rolling: str,
) -> None:
    """
    Validate inputs for robustness analysis.

    Performs three validation checks: presence of required columns in the
    DataFrame, validity of the transformation method specification, and
    consistency of the design mode parameters.

    Parameters
    ----------
    data : pd.DataFrame
        Input panel data.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str or None
        Cohort variable for staggered designs.
    d : str or None
        Treatment indicator for common timing.
    post : str or None
        Post-treatment indicator for common timing.
    rolling : str
        Transformation method (compared case-insensitively).

    Raises
    ------
    ValueError
        If required columns are missing, rolling method is invalid, or
        design mode parameters are inconsistent (must specify either gvar
        for staggered designs or both d and post for common timing).
    """
    # Check required columns; optional column names are only required when
    # they were actually supplied.
    required = [y, ivar, tvar]
    if gvar is not None:
        required.append(gvar)
    if d is not None:
        required.append(d)
    if post is not None:
        required.append(post)

    missing = [col for col in required if col not in data.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Check rolling method
    valid_rolling = {'demean', 'detrend', 'demeanq', 'detrendq'}
    if rolling.lower() not in valid_rolling:
        raise ValueError(f"rolling must be one of {valid_rolling}, got '{rolling}'")

    # Check mode consistency
    is_staggered = gvar is not None
    is_common = d is not None and post is not None

    if not is_staggered and not is_common:
        raise ValueError(
            "Must specify either gvar (staggered) or both d and post (common timing)"
        )


def _auto_detect_pre_period_range(
    data: pd.DataFrame,
    ivar: str,
    tvar: str,
    gvar: str | None,
    d: str | None,
    post: str | None,
    rolling: str,
) -> tuple[int, int]:
    """
    Automatically detect valid pre-treatment period range.

    Returns (min_pre, max_pre) where:
    - min_pre: Minimum required for the transformation method
    - max_pre: Maximum available in the data (never below min_pre)

    Parameters
    ----------
    data : pd.DataFrame
        Input panel data.
    ivar : str
        Unit identifier column name. Not read by this function.
    tvar : str
        Time variable column name.
    gvar : str or None
        Cohort variable for staggered designs.
    d : str or None
        Treatment indicator for common timing. Not read by this function.
    post : str or None
        Post-treatment indicator for common timing. Read only when
        ``gvar`` is None.
    rolling : str
        Transformation method.

    Returns
    -------
    tuple[int, int]
        (min_pre_periods, max_pre_periods)
    """
    # Minimum requirements by method
    min_required = {
        'demean': 1,
        'detrend': 2,
        'demeanq': 1,
        'detrendq': 2,
    }
    # Unrecognized methods fall back to a minimum of 1.
    min_pre = min_required.get(rolling.lower(), 1)

    # Detect maximum available
    if gvar is not None:
        # Staggered: find minimum pre-periods across cohorts
        cohorts = data[gvar].dropna().unique()
        # Only positive, finite cohort values count as treated cohorts.
        cohorts = [c for c in cohorts if c > 0 and np.isfinite(c)]

        if not cohorts:
            return (min_pre, min_pre)

        max_pre_by_cohort = []
        min_time = data[tvar].min()

        for cohort in cohorts:
            # Distance from the first observed time to the cohort value.
            max_pre_by_cohort.append(int(cohort - min_time))

        max_pre = min(max_pre_by_cohort) if max_pre_by_cohort else min_pre
    else:
        # Common timing: count pre-treatment periods
        pre_data = data[data[post] == 0]
        max_pre = pre_data[tvar].nunique()

    # Ensure valid range
    max_pre = max(max_pre, min_pre)

    return (min_pre, max_pre)


def _get_max_pre_periods(
    data: pd.DataFrame,
    ivar: str,
    tvar: str,
    gvar: str | None,
    post: str | None,
) -> int:
    """
    Determine the maximum number of pre-treatment periods available.

    For staggered designs, returns the minimum across all cohorts to ensure
    all cohorts have sufficient pre-treatment data. For common timing,
    counts the number of unique pre-treatment time periods.

    Parameters
    ----------
    data : pd.DataFrame
        Input panel data.
    ivar : str
        Unit identifier column name. Not read by this function.
    tvar : str
        Time variable column name.
    gvar : str or None
        Cohort variable for staggered designs.
    post : str or None
        Post-treatment indicator for common timing. Read only when
        ``gvar`` is None.

    Returns
    -------
    int
        Maximum number of pre-treatment periods available; 0 for a
        staggered design with no positive, finite cohort value.
    """
    if gvar is not None:
        cohorts = data[gvar].dropna().unique()
        cohorts = [c for c in cohorts if c > 0 and np.isfinite(c)]

        if not cohorts:
            return 0

        min_time = data[tvar].min()
        max_pre_by_cohort = [int(c - min_time) for c in cohorts]
        return min(max_pre_by_cohort) if max_pre_by_cohort else 0
    else:
        pre_data = data[data[post] == 0]
        return pre_data[tvar].nunique()


def _filter_to_n_pre_periods(
    data: pd.DataFrame,
    ivar: str,
    tvar: str,
    gvar: str | None,
    d: str | None,
    post: str | None,
    n_pre_periods: int,
    exclude_periods: int,
) -> pd.DataFrame:
    """
    Filter data to use only specified number of pre-treatment periods.

    For staggered designs, this is done cohort-by-cohort. For common timing,
    this filters the entire dataset.

    Parameters
    ----------
    data : pd.DataFrame
        Input panel data. A copy is taken; the caller's frame is not
        modified.
    ivar : str
        Unit identifier column name. Not read by this function.
    tvar : str
        Time variable column name.
    gvar : str or None
        Cohort variable for staggered designs.
    d : str or None
        Treatment indicator for common timing. Not read by this function.
    post : str or None
        Post-treatment indicator for common timing. Read only when
        ``gvar`` is None.
    n_pre_periods : int
        Number of pre-treatment periods to keep.
    exclude_periods : int
        Number of periods to exclude before treatment.

    Returns
    -------
    pd.DataFrame
        Filtered data with specified pre-treatment periods, with a fresh
        integer index.
    """
    data = data.copy()

    if gvar is not None:
        # Staggered: filter by cohort
        filtered_dfs = []
        cohorts = data[gvar].dropna().unique()
        cohorts = [c for c in cohorts if c > 0 and np.isfinite(c)]

        for cohort in cohorts:
            cohort_mask = data[gvar] == cohort
            cohort_data = data[cohort_mask].copy()

            # Determine pre-treatment period range for this cohort
            treatment_period = cohort
            pre_end = treatment_period - 1 - exclude_periods
            pre_start = pre_end - n_pre_periods + 1

            # Keep only specified pre-periods and all post-periods
            time_mask = (
                (cohort_data[tvar] >= pre_start) &
                (cohort_data[tvar] <= pre_end)
            ) | (cohort_data[tvar] >= treatment_period)

            filtered_dfs.append(cohort_data[time_mask])

        # Also include never-treated units (all their periods)
        never_treated_mask = (
            data[gvar].isna() |
            (data[gvar] == 0) |
            (data[gvar] == np.inf)
        )
        if never_treated_mask.any():
            # For never-treated, keep periods that align with treated cohorts
            never_data = data[never_treated_mask].copy()
            if cohorts:
                min_cohort = min(cohorts)
                pre_end = min_cohort - 1 - exclude_periods
                pre_start = pre_end - n_pre_periods + 1
                # Only a lower bound is applied here: every never-treated
                # row at or after pre_start is kept, including rows inside
                # the excluded window of the earliest cohort.
                time_mask = (never_data[tvar] >= pre_start)
                filtered_dfs.append(never_data[time_mask])
            else:
                filtered_dfs.append(never_data)

        if filtered_dfs:
            return pd.concat(filtered_dfs, ignore_index=True)
        # Return empty DataFrame preserving schema for downstream compatibility
        return data.iloc[0:0]
    else:
        # Common timing: simpler filtering
        pre_data = data[data[post] == 0]
        post_data = data[data[post] != 0]

        # Get all pre-treatment times and select the last n_pre_periods
        pre_times = sorted(pre_data[tvar].unique())

        # Drop the periods closest to treatment before selecting.
        if exclude_periods > 0:
            pre_times = pre_times[:-exclude_periods]

        if len(pre_times) < n_pre_periods:
            # Not enough periods, use all available
            selected_pre_times = pre_times
        else:
            selected_pre_times = pre_times[-n_pre_periods:]

        # Filter pre-treatment data
        filtered_pre = pre_data[pre_data[tvar].isin(selected_pre_times)]

        return pd.concat([filtered_pre, post_data], ignore_index=True)


def _filter_excluding_periods(
    data: pd.DataFrame,
    ivar: str,
    tvar: str,
    gvar: str | None,
    post: str | None,
    exclude_periods: int,
) -> pd.DataFrame:
    """
    Filter data excluding specified periods before treatment.

    Removes the specified number of periods immediately preceding treatment
    from the pre-treatment baseline. For common timing designs, re-encodes
    the time variable to maintain continuity after filtering.

    Parameters
    ----------
    data : pd.DataFrame
        Input panel data. A copy is taken; the caller's frame is not
        modified.
    ivar : str
        Unit identifier column name. Not read by this function.
    tvar : str
        Time variable column name.
    gvar : str or None
        Cohort variable for staggered designs.
    post : str or None
        Post-treatment indicator for common timing. Read only when
        ``gvar`` is None.
    exclude_periods : int
        Number of periods to exclude before treatment.

    Returns
    -------
    pd.DataFrame
        Filtered data with excluded periods removed. Time is re-encoded to
        consecutive integers starting at 1 only on the common-timing path.
        An unmodified copy is returned when ``exclude_periods`` is 0, or
        (with a warning) when it is not smaller than the number of
        pre-treatment periods in a common-timing design.

    Notes
    -----
    Time re-encoding for common timing designs prevents discontinuity errors
    that would otherwise occur when the excluded periods create gaps in the
    time sequence. For staggered designs, re-encoding is not performed due
    to the complexity of cohort-specific time structures.
    """
    if exclude_periods == 0:
        return data.copy()

    data = data.copy()

    if gvar is not None:
        # Staggered: exclude by cohort
        filtered_dfs = []
        cohorts = data[gvar].dropna().unique()
        cohorts = [c for c in cohorts if c > 0 and np.isfinite(c)]

        for cohort in cohorts:
            cohort_mask = data[gvar] == cohort
            cohort_data = data[cohort_mask].copy()

            # Exclude periods [g - exclude_periods, g - 1]
            excluded_times = list(range(int(cohort - exclude_periods), int(cohort)))
            time_mask = ~cohort_data[tvar].isin(excluded_times)
            filtered_dfs.append(cohort_data[time_mask])

        # Include never-treated units (no exclusion needed)
        never_treated_mask = (
            data[gvar].isna() |
            (data[gvar] == 0) |
            (data[gvar] == np.inf)
        )
        if never_treated_mask.any():
            filtered_dfs.append(data[never_treated_mask])

        if filtered_dfs:
            result = pd.concat(filtered_dfs, ignore_index=True)
            # Re-encode time to be continuous for staggered
            # Note: This is complex for staggered, so we skip re-encoding
            # and let the caller handle potential discontinuity
            return result
        return data.iloc[0:0]
    else:
        # Common timing
        pre_data = data[data[post] == 0]
        post_data = data[data[post] != 0]

        # Get pre-treatment times and exclude the last `exclude_periods`
        pre_times = sorted(pre_data[tvar].unique())

        if exclude_periods >= len(pre_times):
            warnings.warn(
                f"Cannot exclude {exclude_periods} periods: only {len(pre_times)} "
                f"pre-treatment periods available."
            )
            return data.copy()

        excluded_times = pre_times[-exclude_periods:]
        filtered_pre = pre_data[~pre_data[tvar].isin(excluded_times)]

        # Combine filtered pre and post data
        result = pd.concat([filtered_pre, post_data], ignore_index=True)

        # Re-encode time variable to be continuous
        # This avoids TimeDiscontinuityError in lwdid()
        remaining_times = sorted(result[tvar].unique())
        time_mapping = {old_t: new_t for new_t, old_t in enumerate(remaining_times, start=1)}
        result[tvar] = result[tvar].map(time_mapping)

        return result


def _determine_robustness_level(sensitivity_ratio: float) -> RobustnessLevel:
    """
    Determine robustness level based on sensitivity ratio.

    Parameters
    ----------
    sensitivity_ratio : float
        Ratio of ATT range to baseline ATT.

    Returns
    -------
    RobustnessLevel
        Categorical robustness assessment. Cut-offs are 0.10, 0.25 and
        0.50, each bound exclusive on the upper side.
    """
    # A NaN ratio fails every ``<`` comparison and therefore lands in the
    # final branch (HIGHLY_SENSITIVE).
    if sensitivity_ratio < 0.10:
        return RobustnessLevel.HIGHLY_ROBUST
    elif sensitivity_ratio < 0.25:
        return RobustnessLevel.MODERATELY_ROBUST
    elif sensitivity_ratio < 0.50:
        return RobustnessLevel.SENSITIVE
    else:
        return RobustnessLevel.HIGHLY_SENSITIVE


def _compute_sensitivity_ratio(
    atts: list[float],
    baseline_att: float,
) -> float:
    """
    Compute sensitivity ratio measuring estimate variability.

    Parameters
    ----------
    atts : list[float]
        List of ATT estimates across specifications.
    baseline_att : float
        Baseline ATT estimate used for normalization.

    Returns
    -------
    float
        Sensitivity ratio, or infinity if baseline is near zero but range
        is positive, or zero if both baseline and range are near zero.
        Also zero when ``atts`` is empty.

    Notes
    -----
    The sensitivity ratio is defined as the range of ATT estimates divided
    by the absolute value of the baseline estimate:

    .. math::

        \\text{ratio} = \\frac{\\max(ATT) - \\min(ATT)}{|ATT_{baseline}|}

    A ratio of 0.25 indicates the estimate range spans 25% of the baseline
    magnitude. Lower ratios indicate greater stability across
    specifications.
    """
    if not atts:
        return 0.0

    att_range = max(atts) - min(atts)

    # 1e-10 is the "near zero" tolerance for both baseline and range.
    if abs(baseline_att) > 1e-10:
        return att_range / abs(baseline_att)
    else:
        return float('inf') if att_range > 1e-10 else 0.0


def _generate_robustness_recommendations(
    specifications: list[SpecificationResult],
    baseline_spec: SpecificationResult,
    sensitivity_ratio: float,
    is_robust: bool,
    all_same_sign: bool,
    all_significant: bool,
    rolling: str,
) -> tuple[str, list[str], list[str]]:
    """
    Generate recommendations based on robustness analysis.

    Parameters
    ----------
    specifications : list[SpecificationResult]
        All specification results.
    baseline_spec : SpecificationResult
        Baseline specification result. Not read by this function.
    sensitivity_ratio : float
        Computed sensitivity ratio.
    is_robust : bool
        Whether results are robust.
    all_same_sign : bool
        Whether all estimates have same sign.
    all_significant : bool
        Whether all estimates are significant.
    rolling : str
        Transformation method used.

    Returns
    -------
    tuple[str, list[str], list[str]]
        (main_recommendation, detailed_recommendations, warnings)
    """
    recommendations = []
    result_warnings = []

    # Main recommendation
    if is_robust and all_same_sign and all_significant:
        main_rec = (
            "Results are robust to pre-treatment period selection. "
            "The ATT estimate is stable across specifications."
        )
    elif is_robust and all_same_sign:
        main_rec = (
            "Results are moderately robust. Sign is consistent but "
            "significance varies across specifications."
        )
        recommendations.append(
            "Consider reporting the range of estimates for transparency."
        )
    elif not all_same_sign:
        main_rec = (
            "CAUTION: Results are sensitive to pre-treatment period selection. "
            "Sign changes detected across specifications."
        )
        result_warnings.append("Sign change detected - interpret results with caution.")
        recommendations.append(
            "Investigate why estimates change sign with different pre-periods."
        )
        recommendations.append(
            "Consider using detrend method if trends may be heterogeneous."
        )
    else:
        # Reached when signs agree but is_robust is False.
        main_rec = (
            f"Results show moderate sensitivity (ratio = {sensitivity_ratio:.1%}). "
            "Consider additional robustness checks."
        )

    # Method-specific recommendations
    if rolling.lower() == 'demean' and sensitivity_ratio > 0.25:
        recommendations.append(
            "High sensitivity with demean suggests potential heterogeneous trends. "
            "Consider using rolling='detrend' instead."
        )

    if rolling.lower() == 'detrend' and sensitivity_ratio > 0.50:
        recommendations.append(
            "High sensitivity even with detrend suggests potential model "
            "misspecification or data quality issues."
        )

    # Monotonic pattern may indicate time-varying confounding
    converged_specs = [s for s in specifications if s.converged]
    atts = [s.att for s in sorted(converged_specs, key=lambda x: x.n_pre_periods)]

    if len(atts) >= 3:
        # Consistent increase or decrease across specifications is suspicious
        diffs = np.diff(atts)
        if len(diffs) > 0 and (all(d > 0 for d in diffs) or all(d < 0 for d in diffs)):
            result_warnings.append(
                "ATT estimates show monotonic trend with pre-period count. "
                "This may indicate time-varying confounding."
            )

    return main_rec, recommendations, result_warnings


def _run_single_specification(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None,
    d: str | None,
    post: str | None,
    rolling: str,
    estimator: str,
    controls: list[str] | None,
    vce: str | None,
    cluster_var: str | None,
    n_pre_periods: int,
    exclude_periods: int,
    alpha: float,
    spec_id: int,
) -> SpecificationResult:
    """
    Run estimation for a single specification.

    Filters data to use only the specified number of pre-treatment periods
    and runs lwdid estimation.

    Parameters
    ----------
    data : pd.DataFrame
        Input panel data.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str or None
        Cohort variable for staggered designs.
    d : str or None
        Treatment indicator for common timing.
    post : str or None
        Post-treatment indicator for common timing.
rolling : str Transformation method. estimator : str Estimation method. controls : list[str] or None Control variables. vce : str or None Variance estimator type. cluster_var : str or None Cluster variable. n_pre_periods : int Number of pre-treatment periods to use. exclude_periods : int Number of periods to exclude before treatment. alpha : float Significance level. spec_id : int Specification identifier. Returns ------- SpecificationResult Result for this specification. """ from .core import lwdid try: # Filter data to specified pre-period range filtered_data = _filter_to_n_pre_periods( data=data, ivar=ivar, tvar=tvar, gvar=gvar, d=d, post=post, n_pre_periods=n_pre_periods, exclude_periods=exclude_periods, ) if len(filtered_data) == 0: raise ValueError("No data remaining after filtering") # Determine start and end periods if gvar is not None: # Staggered: use minimum cohort cohorts = filtered_data[gvar].dropna().unique() cohorts = [c for c in cohorts if c > 0 and np.isfinite(c)] if cohorts: min_cohort = min(cohorts) end_period = int(min_cohort - 1 - exclude_periods) start_period = end_period - n_pre_periods + 1 else: start_period = end_period = 0 else: pre_mask = filtered_data[post] == 0 pre_times = filtered_data.loc[pre_mask, tvar].unique() if len(pre_times) > 0: start_period = int(min(pre_times)) end_period = int(max(pre_times)) else: start_period = end_period = 0 # Run estimation result = lwdid( data=filtered_data, y=y, d=d, ivar=ivar, tvar=tvar, post=post, gvar=gvar, rolling=rolling, estimator=estimator, controls=controls, vce=vce, cluster_var=cluster_var, alpha=alpha, ) return SpecificationResult( specification_id=spec_id, n_pre_periods=n_pre_periods, start_period=start_period, end_period=end_period, excluded_periods=exclude_periods, att=result.att, se=result.se_att, t_stat=result.t_stat, pvalue=result.pvalue, ci_lower=result.ci_lower, ci_upper=result.ci_upper, n_treated=result.n_treated, n_control=result.n_control, df=result.df_inference, converged=True, 
spec_warnings=[], ) except Exception as e: warnings.warn(f"Specification {spec_id} (n_pre={n_pre_periods}) failed: {e}") return SpecificationResult( specification_id=spec_id, n_pre_periods=n_pre_periods, start_period=0, end_period=0, excluded_periods=exclude_periods, att=np.nan, se=np.nan, t_stat=np.nan, pvalue=np.nan, ci_lower=np.nan, ci_upper=np.nan, n_treated=0, n_control=0, df=0, converged=False, spec_warnings=[str(e)], ) def _detect_anticipation_effects( estimates: list[AnticipationEstimate], baseline: AnticipationEstimate, threshold: float, ) -> tuple[bool, int, AnticipationDetectionMethod]: """ Detect anticipation effects from sensitivity analysis results. Applies two detection methods sequentially to identify potential violations of the no-anticipation assumption. Parameters ---------- estimates : list[AnticipationEstimate] Estimates for each exclusion level, ordered by exclusion count. baseline : AnticipationEstimate Baseline estimate with no period exclusion. threshold : float Detection threshold for relative change in ATT. Returns ------- detected : bool Whether anticipation effects were detected. recommended_exclusion : int Recommended number of periods to exclude if detected. method : AnticipationDetectionMethod Detection method that identified the effect. Notes ----- Two detection methods are applied: 1. Coefficient change method: Flags anticipation if ATT increases in magnitude by more than the threshold when excluding periods. This pattern suggests pre-treatment periods were biasing estimates toward zero. 2. Trend break method: Flags anticipation if ATT magnitude increases monotonically with exclusion count, then stabilizes. The recommended exclusion is set where the rate of increase drops by at least 50%. 
""" valid_estimates = [e for e in estimates if not np.isnan(e.att)] if len(valid_estimates) < 2: return False, 0, AnticipationDetectionMethod.INSUFFICIENT_DATA # Method 1: Check for significant coefficient change baseline_att = baseline.att for est in valid_estimates[1:]: # Skip baseline if abs(baseline_att) > 1e-10: relative_change = abs(est.att - baseline_att) / abs(baseline_att) if relative_change > threshold: # Check if change is in expected direction # (anticipation typically biases toward zero) if abs(est.att) > abs(baseline_att): return True, est.excluded_periods, AnticipationDetectionMethod.COEFFICIENT_CHANGE # Method 2: Check for monotonic pattern suggesting anticipation atts = [e.att for e in valid_estimates] if len(atts) >= 3: # If ATT magnitude increases monotonically with exclusion, # suggests anticipation was biasing estimates toward zero abs_atts = [abs(a) for a in atts] if all(abs_atts[i] <= abs_atts[i+1] for i in range(len(abs_atts)-1)): # Find where the increase stabilizes for i in range(1, len(abs_atts)): if i < len(abs_atts) - 1: current_increase = abs_atts[i] - abs_atts[i-1] next_increase = abs_atts[i+1] - abs_atts[i] if current_increase > 0 and next_increase < current_increase * 0.5: return True, i, AnticipationDetectionMethod.TREND_BREAK return False, 0, AnticipationDetectionMethod.NONE_DETECTED # ============================================================================= # Main Public Functions # =============================================================================
def robustness_pre_periods(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None = None,
    d: str | None = None,
    post: str | None = None,
    rolling: str = 'demean',
    estimator: str = 'ra',
    controls: list[str] | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
    pre_period_range: tuple[int, int] | None = None,
    step: int = 1,
    exclude_periods_before_treatment: int = 0,
    robustness_threshold: float = 0.25,
    alpha: float = 0.05,
    verbose: bool = True,
) -> PrePeriodRobustnessResult:
    """
    Assess robustness of ATT estimates to pre-treatment period selection.

    Re-estimates the ATT once for every pre-treatment period count in
    ``range(min_pre, max_pre + 1, step)`` and summarizes how much the
    estimates move relative to the baseline specification (the converged
    specification with the largest number of pre-treatment periods).

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str, optional
        Cohort variable for staggered designs.
    d : str, optional
        Treatment indicator for common timing.
    post : str, optional
        Post-treatment indicator for common timing.
    rolling : {'demean', 'detrend'}, default 'demean'
        Transformation method.
    estimator : {'ra', 'ipw', 'ipwra', 'psm'}, default 'ra'
        Estimation method.
    controls : list of str, optional
        Control variable column names.
    vce : str, optional
        Variance estimator type.
    cluster_var : str, optional
        Cluster variable for clustered SE.
    pre_period_range : tuple of (int, int), optional
        Range of pre-treatment periods to test (min_periods, max_periods).
        If None, automatically determined from data.
    step : int, default 1
        Positive step size for varying pre-treatment periods. With a step
        larger than 1 the grid may stop short of ``max_periods``.
    exclude_periods_before_treatment : int, default 0
        Number of periods to exclude immediately before treatment.
        Useful for testing robustness to no-anticipation violations.
    robustness_threshold : float, default 0.25
        Results are considered robust if
        ``sensitivity_ratio < robustness_threshold``.
    alpha : float, default 0.05
        Significance level passed to each estimation.
    verbose : bool, default True
        Whether to print progress and summary.

    Returns
    -------
    PrePeriodRobustnessResult
        Holds every specification result (including failed ones), the
        baseline specification, ATT range/mean/std over the converged
        specifications, the sensitivity ratio, the robustness level, sign
        and significance stability counts, and recommendations.

    Raises
    ------
    ValueError
        If ``step`` is smaller than 1, if the pre-period range has
        ``min > max``, or if every specification fails to converge.

    Notes
    -----
    In many applications, the policy intervention may be based on past
    outcomes. This analysis helps determine whether sufficient
    pre-treatment periods are being used to adequately control for
    selection into treatment.

    Robustness levels based on sensitivity ratio:

    - < 10%: Highly robust
    - 10-25%: Moderately robust
    - 25-50%: Sensitive
    - >= 50%: Highly sensitive

    Significance stability is counted with each specification's
    ``is_significant_05`` flag, independently of ``alpha``.

    See Also
    --------
    lwdid : Main estimation function.
    sensitivity_no_anticipation : Test robustness to anticipation effects.
    sensitivity_analysis : Comprehensive sensitivity analysis.
    """
    # 1. Validate inputs
    _validate_robustness_inputs(data, y, ivar, tvar, gvar, d, post, rolling)
    if step < 1:
        # A non-positive step yields an empty (or reversed) grid; fail with a
        # message that names the real cause instead of a convergence error.
        raise ValueError(f"step must be a positive integer, got {step}.")

    # 2. Determine pre-period range
    if pre_period_range is None:
        pre_period_range = _auto_detect_pre_period_range(
            data, ivar, tvar, gvar, d, post, rolling
        )

    min_pre, max_pre = pre_period_range
    if min_pre > max_pre:
        raise ValueError(
            f"Invalid pre_period_range {pre_period_range!r}: "
            "minimum exceeds maximum, so there is no specification to run."
        )

    if verbose:
        print("Pre-treatment period robustness analysis")
        print(f"Testing pre-period range: {min_pre} to {max_pre}")
        print("-" * 50)

    # 3. Generate specification list (non-empty thanks to the checks above)
    n_pre_values = list(range(min_pre, max_pre + 1, step))

    if len(n_pre_values) < 2:
        warnings.warn(
            f"Only {len(n_pre_values)} specification(s) possible. "
            "Consider expanding pre_period_range or reducing step."
        )

    # 4. Run estimations for each specification
    specifications = []
    for i, n_pre in enumerate(n_pre_values):
        if verbose:
            print(f"Running specification {i+1}/{len(n_pre_values)}: n_pre={n_pre}")

        spec_result = _run_single_specification(
            data=data,
            y=y,
            ivar=ivar,
            tvar=tvar,
            gvar=gvar,
            d=d,
            post=post,
            rolling=rolling,
            estimator=estimator,
            controls=controls,
            vce=vce,
            cluster_var=cluster_var,
            n_pre_periods=n_pre,
            exclude_periods=exclude_periods_before_treatment,
            alpha=alpha,
            spec_id=i,
        )
        specifications.append(spec_result)

    # 5. Identify baseline (maximum pre-periods among converged specs)
    converged_specs = [s for s in specifications if s.converged]

    if not converged_specs:
        raise ValueError("All specifications failed to converge")

    baseline_spec = max(converged_specs, key=lambda x: x.n_pre_periods)

    # 6. Compute sensitivity metrics over converged specifications only
    atts = [s.att for s in converged_specs]
    att_range = (min(atts), max(atts))
    att_mean = float(np.mean(atts))
    # Sample std needs at least two observations
    att_std = float(np.std(atts, ddof=1)) if len(atts) > 1 else 0.0

    # Sensitivity ratio: range / |baseline|
    sensitivity_ratio = _compute_sensitivity_ratio(atts, baseline_spec.att)

    # 7. Assess robustness
    robustness_level = _determine_robustness_level(sensitivity_ratio)
    is_robust = sensitivity_ratio < robustness_threshold

    # 8. Check sign and significance stability relative to the baseline
    baseline_sign = np.sign(baseline_spec.att)
    n_sign_changes = sum(
        1 for s in converged_specs if np.sign(s.att) != baseline_sign
    )
    all_same_sign = all(np.sign(s.att) == baseline_sign for s in converged_specs)
    n_significant = sum(1 for s in converged_specs if s.is_significant_05)
    all_significant = n_significant == len(converged_specs)

    # 9. Generate recommendations
    recommendation, detailed_recs, result_warnings = _generate_robustness_recommendations(
        specifications=specifications,
        baseline_spec=baseline_spec,
        sensitivity_ratio=sensitivity_ratio,
        is_robust=is_robust,
        all_same_sign=all_same_sign,
        all_significant=all_significant,
        rolling=rolling,
    )

    # 10. Create result object
    result = PrePeriodRobustnessResult(
        specifications=specifications,
        baseline_spec=baseline_spec,
        att_range=att_range,
        att_mean=att_mean,
        att_std=att_std,
        sensitivity_ratio=sensitivity_ratio,
        robustness_level=robustness_level,
        is_robust=is_robust,
        robustness_threshold=robustness_threshold,
        all_same_sign=all_same_sign,
        all_significant=all_significant,
        n_significant=n_significant,
        n_sign_changes=n_sign_changes,
        rolling_method=rolling,
        estimator=estimator,
        n_specifications=len(specifications),
        pre_period_range_tested=pre_period_range,
        recommendation=recommendation,
        detailed_recommendations=detailed_recs,
        result_warnings=result_warnings,
    )

    # 11. Print summary if verbose
    if verbose:
        print()
        print(result.summary())

    return result
def sensitivity_no_anticipation(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None = None,
    d: str | None = None,
    post: str | None = None,
    rolling: str = 'demean',
    estimator: str = 'ra',
    controls: list[str] | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
    max_anticipation: int = 3,
    detection_threshold: float = 0.10,
    alpha: float = 0.05,
    verbose: bool = True,
) -> NoAnticipationSensitivityResult:
    """
    Test robustness of ATT estimates to potential anticipation effects.

    When the no-anticipation assumption may be violated (e.g., policy
    announced before implementation), units may adjust behavior before
    formal treatment. This function re-estimates the ATT while excluding
    0, 1, ..., ``max_anticipation`` periods immediately before treatment
    and checks whether the estimates shift.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str, optional
        Cohort variable for staggered designs.
    d : str, optional
        Treatment indicator for common timing.
    post : str, optional
        Post-treatment indicator for common timing.
    rolling : {'demean', 'detrend'}, default 'demean'
        Transformation method.
    estimator : {'ra', 'ipw', 'ipwra', 'psm'}, default 'ra'
        Estimation method.
    controls : list of str, optional
        Control variable column names.
    vce : str, optional
        Variance estimator type.
    cluster_var : str, optional
        Cluster variable for clustered SE.
    max_anticipation : int, default 3
        Maximum number of periods to test for anticipation effects.
        Must be non-negative. It is reduced automatically so that at least
        1 pre-period (2 for 'detrend'/'detrendq') remains after exclusion.
    detection_threshold : float, default 0.10
        Threshold for detecting anticipation effects. If relative change
        in ATT exceeds this threshold, anticipation is detected.
    alpha : float, default 0.05
        Significance level for confidence intervals.
    verbose : bool, default True
        Whether to print progress and summary.

    Returns
    -------
    NoAnticipationSensitivityResult
        Holds one estimate per exclusion level (NaN-filled when that level
        failed), the baseline estimate (no exclusion), whether anticipation
        was detected, the recommended exclusion, the detection method, a
        recommendation string, and the collected failure messages.

    Raises
    ------
    ValueError
        If ``max_anticipation`` is negative.

    Notes
    -----
    The no-anticipation assumption requires that, prior to the first
    intervention period for a given treatment cohort, the potential outcomes
    are the same (on average) as in the never treated state.

    If policy is announced k periods before implementation, units may adjust
    behavior during periods {g-k, ..., g-1}. By excluding these periods from
    the pre-treatment baseline, we can test whether estimates are robust to
    such anticipation effects.

    A failure at a single exclusion level does not abort the analysis: it is
    reported via ``warnings.warn``, recorded in ``result_warnings``, and the
    level is kept with NaN values.

    See Also
    --------
    robustness_pre_periods : General pre-period robustness check.
    sensitivity_analysis : Comprehensive sensitivity analysis.
    """
    from .core import lwdid

    # Validate inputs
    _validate_robustness_inputs(data, y, ivar, tvar, gvar, d, post, rolling)
    if max_anticipation < 0:
        raise ValueError(
            f"max_anticipation must be non-negative, got {max_anticipation}."
        )

    # Determine maximum feasible exclusion: detrending needs two remaining
    # pre-periods, demeaning needs one.
    min_required = 2 if rolling.lower() in ('detrend', 'detrendq') else 1
    max_available = _get_max_pre_periods(data, ivar, tvar, gvar, post)
    max_feasible_exclusion = max(0, max_available - min_required)

    requested_anticipation = max_anticipation
    max_anticipation = min(requested_anticipation, max_feasible_exclusion)

    # Warn only when the DATA prevent any exclusion; a caller who explicitly
    # asks for max_anticipation=0 has not run out of pre-periods.
    if max_feasible_exclusion < 1 and requested_anticipation >= 1:
        warnings.warn(
            "Insufficient pre-treatment periods for anticipation analysis. "
            f"Need at least {min_required + 1} pre-periods, have {max_available}."
        )

    if verbose:
        print("No-anticipation sensitivity analysis")
        print(f"Testing exclusion range: 0 to {max_anticipation}")
        print("-" * 50)

    # Run estimations for each exclusion level
    estimates = []
    result_warnings = []

    for exclude in range(max_anticipation + 1):
        if verbose:
            print(f"Testing exclusion = {exclude} periods...")

        try:
            # Filter data
            filtered_data = _filter_excluding_periods(
                data, ivar, tvar, gvar, post, exclude
            )

            if len(filtered_data) == 0:
                raise ValueError("No data remaining after filtering")

            # Run estimation
            result = lwdid(
                data=filtered_data,
                y=y,
                d=d,
                ivar=ivar,
                tvar=tvar,
                post=post,
                gvar=gvar,
                rolling=rolling,
                estimator=estimator,
                controls=controls,
                vce=vce,
                cluster_var=cluster_var,
                alpha=alpha,
            )

            # Pre-periods left after dropping the excluded ones
            n_pre_used = max_available - exclude

            estimates.append(AnticipationEstimate(
                excluded_periods=exclude,
                att=result.att,
                se=result.se_att,
                t_stat=result.t_stat,
                pvalue=result.pvalue,
                ci_lower=result.ci_lower,
                ci_upper=result.ci_upper,
                n_pre_periods_used=n_pre_used,
            ))

        except Exception as e:
            # Best effort: keep the level as a NaN placeholder so positions in
            # `estimates` still line up with exclusion counts.
            warnings.warn(f"Exclusion {exclude} failed: {e}")
            result_warnings.append(f"Exclusion {exclude} failed: {e}")
            estimates.append(AnticipationEstimate(
                excluded_periods=exclude,
                att=np.nan,
                se=np.nan,
                t_stat=np.nan,
                pvalue=np.nan,
                ci_lower=np.nan,
                ci_upper=np.nan,
                n_pre_periods_used=0,
            ))

    # Identify baseline (no exclusion). The loop always runs for exclude=0
    # because max_anticipation is guaranteed non-negative here.
    baseline = estimates[0]

    # Detect anticipation effects
    anticipation_detected, recommended_exclusion, detection_method = \
        _detect_anticipation_effects(estimates, baseline, detection_threshold)

    # Generate recommendation
    if anticipation_detected:
        recommendation = (
            f"Anticipation effects detected. Consider excluding "
            f"{recommended_exclusion} period(s) before treatment. "
            f"Use lwdid(..., exclude_pre_periods={recommended_exclusion})."
        )
    else:
        recommendation = (
            "No significant anticipation effects detected. "
            "The no-anticipation assumption appears reasonable."
        )

    result = NoAnticipationSensitivityResult(
        estimates=estimates,
        baseline_estimate=baseline,
        anticipation_detected=anticipation_detected,
        recommended_exclusion=recommended_exclusion,
        detection_method=detection_method,
        recommendation=recommendation,
        result_warnings=result_warnings,
    )

    if verbose:
        print()
        print(result.summary())

    return result
def sensitivity_analysis(
    data: pd.DataFrame,
    y: str,
    ivar: str,
    tvar: str,
    gvar: str | None = None,
    d: str | None = None,
    post: str | None = None,
    rolling: str = 'demean',
    estimator: str = 'ra',
    controls: list[str] | None = None,
    vce: str | None = None,
    cluster_var: str | None = None,
    analyses: list[str] | None = None,
    alpha: float = 0.05,
    verbose: bool = True,
) -> ComprehensiveSensitivityResult:
    """
    Perform comprehensive sensitivity analysis for DiD estimation.

    Combines multiple robustness checks into a single analysis, providing
    an overall assessment of estimate reliability across different
    methodological choices.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    y : str
        Outcome variable column name.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name.
    gvar : str, optional
        Cohort variable for staggered designs.
    d : str, optional
        Treatment indicator for common timing.
    post : str, optional
        Post-treatment indicator for common timing.
    rolling : {'demean', 'detrend'}, default 'demean'
        Primary transformation method.
    estimator : {'ra', 'ipw', 'ipwra', 'psm'}, default 'ra'
        Primary estimation method.
    controls : list of str, optional
        Control variable column names. The estimator comparison is skipped
        when this is None.
    vce : str, optional
        Variance estimator type.
    cluster_var : str, optional
        Cluster variable for clustered SE.
    analyses : list of str, optional
        Which analyses to run. Default: all.
        Options: 'pre_periods', 'anticipation', 'transformation', 'estimator'
    alpha : float, default 0.05
        Significance level.
    verbose : bool, default True
        Whether to print progress and summary.

    Returns
    -------
    ComprehensiveSensitivityResult
        Combined results from all sensitivity analyses. A component that
        was not requested, was skipped, or failed is left as None.

    Notes
    -----
    Four types of sensitivity analysis are available:

    1. **Pre-periods**: Tests stability across different numbers of
       pre-treatment periods used in the transformation.

    2. **Anticipation**: Tests robustness to potential anticipation effects
       by excluding periods immediately before treatment.

    3. **Transformation**: Compares demean and detrend methods to assess
       whether heterogeneous trends may be present. Requires at least two
       pre-treatment periods.

    4. **Estimator**: Compares RA, IPW, and IPWRA estimators to check
       robustness to propensity score or outcome model misspecification.

    Each analysis is best effort: a failure is reported through
    ``warnings.warn`` and the remaining analyses still run. The
    transformation and estimator checks are flagged when the relative
    spread of the ATT (computed with ``_compute_sensitivity_ratio``)
    exceeds 25%; every flagged check contributes both a recommendation and
    an entry in the overall assessment, so the two always agree.

    See Also
    --------
    robustness_pre_periods : Pre-period robustness check.
    sensitivity_no_anticipation : Anticipation sensitivity check.
    """
    from .core import lwdid

    # Default: run all analyses
    if analyses is None:
        analyses = ['pre_periods', 'anticipation', 'transformation', 'estimator']

    # Validate inputs
    _validate_robustness_inputs(data, y, ivar, tvar, gvar, d, post, rolling)

    if verbose:
        print("=" * 70)
        print("COMPREHENSIVE SENSITIVITY ANALYSIS")
        print("=" * 70)
        print()

    pre_period_result = None
    anticipation_result = None
    transformation_comparison = None
    estimator_comparison = None
    recommendations = []
    # Flags shared by the recommendation text and the overall assessment so
    # that both are always derived from the same decision.
    transformation_sensitive = False
    estimator_sensitive = False

    # 1. Pre-treatment period robustness
    if 'pre_periods' in analyses:
        if verbose:
            print("1. Running pre-treatment period robustness analysis...")
            print()

        try:
            pre_period_result = robustness_pre_periods(
                data=data,
                y=y,
                ivar=ivar,
                tvar=tvar,
                gvar=gvar,
                d=d,
                post=post,
                rolling=rolling,
                estimator=estimator,
                controls=controls,
                vce=vce,
                cluster_var=cluster_var,
                alpha=alpha,
                verbose=False,
            )

            if not pre_period_result.is_robust:
                recommendations.append(
                    f"Pre-period sensitivity detected (ratio={pre_period_result.sensitivity_ratio:.1%}). "
                    "Consider using detrend method or investigating data quality."
                )
        except Exception as e:
            warnings.warn(f"Pre-period analysis failed: {e}")

    # 2. No-anticipation sensitivity
    if 'anticipation' in analyses:
        if verbose:
            print("2. Running no-anticipation sensitivity analysis...")
            print()

        try:
            anticipation_result = sensitivity_no_anticipation(
                data=data,
                y=y,
                ivar=ivar,
                tvar=tvar,
                gvar=gvar,
                d=d,
                post=post,
                rolling=rolling,
                estimator=estimator,
                controls=controls,
                vce=vce,
                cluster_var=cluster_var,
                alpha=alpha,
                verbose=False,
            )

            if anticipation_result.anticipation_detected:
                recommendations.append(
                    f"Anticipation effects detected. Consider excluding "
                    f"{anticipation_result.recommended_exclusion} period(s) before treatment."
                )
        except Exception as e:
            warnings.warn(f"Anticipation analysis failed: {e}")

    # 3. Transformation comparison
    if 'transformation' in analyses:
        if verbose:
            print("3. Comparing transformation methods (demean vs detrend)...")
            print()

        try:
            # Check if detrend is feasible
            min_pre_detrend = 2
            max_pre = _get_max_pre_periods(data, ivar, tvar, gvar, post)

            if max_pre >= min_pre_detrend:
                result_demean = lwdid(
                    data=data, y=y, d=d, ivar=ivar, tvar=tvar, post=post,
                    gvar=gvar, rolling='demean', estimator=estimator,
                    controls=controls, vce=vce, cluster_var=cluster_var, alpha=alpha,
                )
                result_detrend = lwdid(
                    data=data, y=y, d=d, ivar=ivar, tvar=tvar, post=post,
                    gvar=gvar, rolling='detrend', estimator=estimator,
                    controls=controls, vce=vce, cluster_var=cluster_var, alpha=alpha,
                )

                transformation_comparison = {
                    'demean_att': result_demean.att,
                    'demean_se': result_demean.se_att,
                    'detrend_att': result_detrend.att,
                    'detrend_se': result_detrend.se_att,
                    'difference': abs(result_demean.att - result_detrend.att),
                }

                # Relative difference with demean as the reference; a
                # near-zero reference with a real difference counts as inf.
                rel_diff = _compute_sensitivity_ratio(
                    [result_demean.att, result_detrend.att], result_demean.att
                )
                if rel_diff > 0.25:
                    transformation_sensitive = True
                    recommendations.append(
                        f"Substantial difference between demean and detrend ({rel_diff:.1%}). "
                        "This suggests heterogeneous trends may be present."
                    )
            else:
                warnings.warn(
                    f"Insufficient pre-periods for detrend comparison "
                    f"(need {min_pre_detrend}, have {max_pre})"
                )
        except Exception as e:
            warnings.warn(f"Transformation comparison failed: {e}")

    # 4. Estimator comparison (needs covariates, hence the controls check)
    if 'estimator' in analyses and controls is not None:
        if verbose:
            print("4. Comparing estimators (RA, IPW, IPWRA)...")
            print()

        try:
            estimator_comparison = {}

            for est in ['ra', 'ipw', 'ipwra']:
                try:
                    result_est = lwdid(
                        data=data, y=y, d=d, ivar=ivar, tvar=tvar, post=post,
                        gvar=gvar, rolling=rolling, estimator=est,
                        controls=controls, vce=vce, cluster_var=cluster_var, alpha=alpha,
                    )
                    estimator_comparison[est] = result_est.att
                except Exception as e:
                    # Keep comparing the remaining estimators, but make the
                    # failure visible instead of dropping it silently.
                    warnings.warn(f"Estimator '{est}' failed: {e}")

            if len(estimator_comparison) >= 2:
                atts = list(estimator_comparison.values())
                # Prefer RA as the reference; fall back to the first success.
                baseline_att = estimator_comparison.get('ra', atts[0])
                estimator_comparison['range'] = max(atts) - min(atts)

                rel_range = _compute_sensitivity_ratio(atts, baseline_att)
                if rel_range > 0.25:
                    estimator_sensitive = True
                    recommendations.append(
                        f"Substantial variation across estimators ({rel_range:.1%}). "
                        "Consider which estimator assumptions are most appropriate."
                    )
        except Exception as e:
            warnings.warn(f"Estimator comparison failed: {e}")

    # Generate overall assessment
    issues = []

    if pre_period_result is not None and not pre_period_result.is_robust:
        issues.append("pre-period sensitivity")

    if anticipation_result is not None and anticipation_result.anticipation_detected:
        issues.append("anticipation effects")

    if transformation_sensitive:
        issues.append("transformation sensitivity")

    if estimator_sensitive:
        issues.append("estimator sensitivity")

    if not issues:
        overall_assessment = "Results appear robust across multiple sensitivity checks."
    elif len(issues) == 1:
        overall_assessment = f"Caution: {issues[0]} detected. See recommendations."
    else:
        overall_assessment = f"Multiple concerns: {', '.join(issues)}. Interpret with caution."

    if not recommendations:
        recommendations.append("No major robustness concerns identified.")

    result = ComprehensiveSensitivityResult(
        pre_period_result=pre_period_result,
        anticipation_result=anticipation_result,
        transformation_comparison=transformation_comparison,
        estimator_comparison=estimator_comparison,
        overall_assessment=overall_assessment,
        recommendations=recommendations,
    )

    if verbose:
        print()
        print(result.summary())

    return result
# ============================================================================= # Convenience function for plotting # =============================================================================
def plot_sensitivity(
    result: PrePeriodRobustnessResult | NoAnticipationSensitivityResult,
    show_ci: bool = True,
    show_baseline: bool = True,
    highlight_significant: bool = True,
    figsize: tuple[float, float] = (10, 6),
    ax: Any = None,
) -> Any:
    """
    Plot a sensitivity analysis result by delegating to its ``plot`` method.

    Parameters
    ----------
    result : PrePeriodRobustnessResult or NoAnticipationSensitivityResult
        Result object from sensitivity analysis.
    show_ci : bool, default True
        Whether to show confidence intervals.
    show_baseline : bool, default True
        Whether to show baseline reference line. Forwarded only for
        ``PrePeriodRobustnessResult``.
    highlight_significant : bool, default True
        Whether to highlight significant estimates. Accepted for interface
        compatibility; this function does not forward it to ``plot``.
    figsize : tuple, default (10, 6)
        Figure size in inches.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on.

    Returns
    -------
    matplotlib.figure.Figure
        Whatever ``result.plot(...)`` returns.

    Raises
    ------
    TypeError
        If ``result`` is neither of the two supported result types.
    """
    supported = (PrePeriodRobustnessResult, NoAnticipationSensitivityResult)
    if not isinstance(result, supported):
        raise TypeError(f"Unsupported result type: {type(result)}")

    # Options understood by both result types
    plot_options = {'show_ci': show_ci, 'figsize': figsize, 'ax': ax}

    # Only the pre-period result receives the baseline toggle
    if isinstance(result, PrePeriodRobustnessResult):
        plot_options['show_baseline'] = show_baseline

    return result.plot(**plot_options)