"""
Unit-specific panel data transformations for difference-in-differences.
This module implements unit-specific outcome transformations that remove
pre-treatment heterogeneity from panel data. Transformation parameters are
estimated using only pre-treatment observations, then applied out-of-sample
to all periods including post-treatment.
The transformations convert panel difference-in-differences estimation into
cross-sectional treatment effects problems. Under no anticipation and parallel
trends assumptions, standard treatment effect estimators (regression adjustment,
inverse probability weighting, doubly robust, matching) can be applied to the
transformed outcomes.
Available Transformations
-------------------------
demean
Removes unit-specific pre-treatment mean:
.. math::
\\dot{Y}_{it} = Y_{it} - \\bar{Y}_{i,pre}
where :math:`\\bar{Y}_{i,pre} = T_0^{-1} \\sum_{s<g} Y_{is}`.
Requires at least 1 pre-treatment period per unit.
detrend
Removes unit-specific linear time trend:
.. math::
\\dot{Y}_{it} = Y_{it} - \\hat{\\alpha}_i - \\hat{\\beta}_i t
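where :math:`(\\hat{\\alpha}_i, \\hat{\\beta}_i)` are OLS estimates from
pre-treatment data. Requires at least 2 pre-treatment periods per unit;
``detrend_unit`` additionally returns NaN (with a warning) for any unit with
fewer than 3 pre-treatment observations that have both a valid outcome and
time index, so that the fit retains at least one residual degree of freedom.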
demeanq
Removes unit-specific mean with quarterly seasonal fixed effects:
.. math::
\\dot{Y}_{it} = Y_{it} - \\hat{\\mu}_i - \\sum_{q=2}^{4} \\hat{\\gamma}_q D_q
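where :math:`D_q` are quarter dummies. The smallest quarter observed in a
unit's pre-treatment data is the reference category (the sum above is written
for the case where that is quarter 1). Requires :math:`n_{pre} \\geq Q + 1`
valid pre-treatment observations per unit, where :math:`Q` here denotes the
number of distinct quarters observed in that unit's pre-treatment data.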
detrendq
Removes unit-specific linear trend with quarterly seasonal effects:
.. math::
\\dot{Y}_{it} = Y_{it} - \\hat{\\alpha}_i - \\hat{\\beta}_i t
- \\sum_{q=2}^{4} \\hat{\\gamma}_q D_q
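Requires :math:`n_{pre} \\geq Q + 2` valid pre-treatment observations per
unit, with :math:`Q` defined as for demeanq (one more than demeanq because of
the slope parameter).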
Notes
-----
The transformations eliminate unit-specific level differences or trends that
may be correlated with treatment assignment. By removing these pre-treatment
patterns, the parallel trends assumption becomes an assumption about the
transformed outcomes rather than the original levels.
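Time centering is applied in the detrending methods to improve the numerical
stability of OLS estimation. Subtracting a constant from the time regressor
reduces the condition number of the design matrix without changing the fitted
values or residuals: because the model contains an intercept, the columns
:math:`\\{1, t\\}` and :math:`\\{1, t - c\\}` span the same space, so only the
intercept estimate changes.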
"""
from __future__ import annotations
import warnings
import numpy as np
import pandas as pd
import statsmodels.api as sm
from .exceptions import InsufficientPrePeriodsError
from .validation import (
validate_quarter_diversity,
validate_quarter_coverage,
validate_season_diversity,
validate_season_coverage,
)
# Lower bound on the population variance (ddof=0) of a unit's pre-treatment
# time index. When the variance falls strictly below this value the time
# regressor is effectively constant, so a slope cannot be estimated and the
# detrending helpers return NaN (with a warning) instead of fitting.
VARIANCE_THRESHOLD: float = 1.0e-10
def _compute_max_pre_tindex(
    data: pd.DataFrame, post: str, tindex: str, method: str
) -> int:
    """
    Return the last pre-treatment time index.

    Computes :math:`K = \\max\\{t : \\text{post}_t = 0\\}` over every row whose
    post indicator equals 0. Missing time index values are skipped by the
    maximum; the result is converted with ``int()``.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data containing post indicator and time index columns.
    post : str
        Column name of binary post-treatment indicator (0=pre, 1=post).
    tindex : str
        Column name of integer-valued time index.
    method : str
        Transformation method name, used only in error messages.

    Returns
    -------
    int
        Maximum pre-treatment period index K.

    Raises
    ------
    InsufficientPrePeriodsError
        If there are no rows with ``post == 0``, or if every such row has a
        missing time index.
    """
    is_pre_row = data[post] == 0
    pre_times = data.loc[is_pre_row, tindex]

    # Guard clauses: nothing to take a maximum over.
    if pre_times.empty:
        raise InsufficientPrePeriodsError(
            "No pre-treatment observations found (post==0). "
            f"rolling('{method}') requires at least 1 pre-treatment period."
        )
    if pre_times.isna().all():
        raise InsufficientPrePeriodsError(
            "All pre-treatment time index values are NaN. "
            f"rolling('{method}') requires valid time index values."
        )

    latest_pre_time = pre_times.max()
    return int(latest_pre_time)
def _validate_seasonal_transform_requirements(
    data: pd.DataFrame,
    ivar: str,
    tindex: str,
    post: str,
    season_var: str,
    y: str,
    transform_type: str,
    min_global_pre_periods: int,
    Q: int = 4,
) -> int:
    """
    Validate data requirements for seasonal transformations.

    Verifies that data satisfy the requirements for demeanq or detrendq
    transformations: valid seasonal codes, sufficient global pre-treatment
    periods, adequate per-unit pre-period observations for model estimation,
    and (via ``validate_season_coverage``) seasonal coverage in the
    pre-treatment period.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit, time, post indicator, and seasonal columns.
    ivar : str
        Column name for unit identifier.
    tindex : str
        Column name for time index.
    post : str
        Column name for post-treatment indicator (0=pre, 1=post).
    season_var : str
        Column name for seasonal variable (values in {1, 2, ..., Q}).
    y : str
        Column name for outcome variable.
    transform_type : str
        Transformation type: 'demeanq' or 'detrendq'. Any value other than
        'demeanq' is treated as detrendq (one extra trend parameter).
    min_global_pre_periods : int
        Minimum required unique pre-treatment periods, also used as the
        minimum number of pre-period rows per unit.
    Q : int, default 4
        Number of seasonal periods per cycle. Common values:
        - 4: Quarterly data (default)
        - 12: Monthly data
        - 52: Weekly data

    Returns
    -------
    int
        Maximum pre-treatment time index K.

    Raises
    ------
    InsufficientPrePeriodsError
        If there are no usable pre-treatment observations, if the number of
        unique pre-treatment periods is below ``min_global_pre_periods``, or
        if any unit has too few pre-period rows for reliable estimation.
    ValueError
        If ``season_var`` contains values that are not integers in
        {1, 2, ..., Q}. This includes non-numeric values and +/-inf.

    Errors raised by ``validate_season_coverage`` propagate unchanged.

    Notes
    -----
    Parameter count differs by transformation type:
    - demeanq: :math:`Y \\sim 1 + \\text{season}` requires :math:`k = S` parameters
    - detrendq: :math:`Y \\sim 1 + t + \\text{season}` requires :math:`k = S + 1`
    where :math:`S \\leq Q` is the number of distinct seasons among the
    unit's pre-period rows with non-missing model inputs.
    Reliable estimation requires :math:`n \\geq k + 1` to ensure at least
    one residual degree of freedom.
    """
    pre_data = data[data[post] == 0]
    n_pre_periods = pre_data[tindex].nunique()
    K = _compute_max_pre_tindex(data, post, tindex, transform_type)

    # Validate season values are in the expected range {1, 2, ..., Q}.
    # Non-standard values would create incorrect dummy variables in the model.
    valid_seasons = set(range(1, Q + 1))
    invalid_values = set()  # non-integer or non-numeric codes
    out_of_range = set()  # integer codes outside 1..Q
    for season in data[season_var].dropna().unique():
        # Classify each value on its own. Then one bad value (e.g. 'Q1', or
        # inf, for which int() raises OverflowError) neither marks every
        # other value invalid nor escapes as a non-ValueError exception.
        try:
            season_int = int(season)
        except (ValueError, TypeError, OverflowError):
            invalid_values.add(season)
            continue
        # Integer-valued floats (1.0) are accepted; 1.5 or the string '1' are not.
        if season != season_int:
            invalid_values.add(season)
        elif season_int not in valid_seasons:
            out_of_range.add(season_int)

    if invalid_values or out_of_range:
        all_invalid = invalid_values | {float(s) for s in out_of_range}
        try:
            shown_invalid = sorted(all_invalid)
        except TypeError:
            # Mixed types (e.g. str and float) cannot be ordered directly.
            shown_invalid = sorted(all_invalid, key=str)
        freq_label = {4: 'quarters', 12: 'months', 52: 'weeks'}.get(Q, f'seasons (1-{Q})')
        raise ValueError(
            f"Seasonal column '{season_var}' contains invalid values: {shown_invalid}. "
            f"Expected integer values in {{1, 2, ..., {Q}}} representing {freq_label}.\n\n"
            f"If your seasonal values use a different encoding (e.g., 0-{Q-1}), "
            f"please recode them to 1-{Q} before calling lwdid()."
        )

    if n_pre_periods < min_global_pre_periods:
        raise InsufficientPrePeriodsError(
            f"rolling('{transform_type}') requires at least {min_global_pre_periods} "
            f"pre-treatment period(s). Found: {n_pre_periods} unique pre-treatment "
            f"period(s) (max tindex={K})."
        )

    # Model parameter offset: demeanq has k=S, detrendq has k=S+1 (adds trend).
    param_offset = 0 if transform_type == 'demeanq' else 1
    for unit_id in data[ivar].unique():
        # Same rows, in the same order, as filtering the unit first and then
        # post == 0, but scanning only the pre-treatment subset.
        unit_pre = pre_data[pre_data[ivar] == unit_id]
        unit_pre_count = len(unit_pre)
        if unit_pre_count < min_global_pre_periods:
            count_text = 'no' if unit_pre_count == 0 else f'only {unit_pre_count}'
            raise InsufficientPrePeriodsError(
                f"Unit {unit_id} has {count_text} "
                f"pre-period observation(s). rolling('{transform_type}') requires at least "
                f"{min_global_pre_periods} pre-treatment period(s) per unit."
            )

        # Rows with complete model inputs; used to count the distinct seasons
        # (and hence the number of parameters) the unit-level fit will estimate.
        if transform_type == 'demeanq':
            valid_mask = unit_pre[y].notna() & unit_pre[season_var].notna()
        else:
            valid_mask = unit_pre[y].notna() & unit_pre[tindex].notna() & unit_pre[season_var].notna()
        n_valid = valid_mask.sum()
        n_unique_seasons = unit_pre.loc[valid_mask, season_var].nunique() if n_valid > 0 else 0
        n_params = n_unique_seasons + param_offset
        min_required = n_params + 1  # Require at least df = 1

        # NOTE(review): this compares the raw pre-period row count, not
        # n_valid. A unit whose missing values leave too few usable rows
        # passes here. demeanq_unit / detrendq_unit then return NaN for it
        # with a warning. Confirm this leniency is intended.
        if unit_pre_count < min_required:
            unit_label = {4: 'quarter', 12: 'month', 52: 'week'}.get(Q, 'season')
            if transform_type == 'demeanq':
                model_desc = (
                    f"y ~ 1 + i.{unit_label} with k = {n_unique_seasons} parameters "
                    f"(1 constant + {n_unique_seasons-1} {unit_label} dummies)"
                )
            else:
                model_desc = (
                    f"y ~ 1 + tindex + i.{unit_label} with k = 1 + 1 + "
                    f"({n_unique_seasons}-1) = {n_params} parameters"
                )
            raise InsufficientPrePeriodsError(
                f"Unit {unit_id} has {unit_pre_count} pre-period observation(s) "
                f"with {n_unique_seasons} distinct {unit_label}(s). "
                f"rolling('{transform_type}') requires at least {min_required} observations "
                f"to ensure df = n - k ≥ 1 for reliable statistical inference. "
                f"The {transform_type} method estimates a model {model_desc}."
            )

    validate_season_coverage(data, ivar, season_var, post, Q)
    return K
# Backward-compatible entry point kept for callers that predate the generic
# seasonal validator; it pins the cycle length to four quarters.
def _validate_quarterly_transform_requirements(
    data: pd.DataFrame,
    ivar: str,
    tindex: str,
    post: str,
    quarter: str,
    y: str,
    transform_type: str,
    min_global_pre_periods: int,
) -> int:
    """
    Validate quarterly data for demeanq/detrendq (legacy wrapper).

    Forwards every argument to ``_validate_seasonal_transform_requirements``,
    passing ``quarter`` as ``season_var`` and fixing ``Q=4``.

    Returns
    -------
    int
        The value returned by the seasonal validator (the maximum
        pre-treatment time index K).
    """
    return _validate_seasonal_transform_requirements(
        data=data,
        ivar=ivar,
        tindex=tindex,
        post=post,
        season_var=quarter,
        y=y,
        transform_type=transform_type,
        min_global_pre_periods=min_global_pre_periods,
        Q=4,
    )
[docs]
def detrend_unit(
    unit_data: pd.DataFrame, y: str, tindex: str, post: str
) -> tuple[np.ndarray, np.ndarray]:
    """
    Remove unit-specific linear time trend for a single unit.

    Estimates :math:`Y_{it} = \\alpha + \\beta t + \\varepsilon` via OLS using
    pre-treatment observations only, then computes out-of-sample residuals
    for all periods:

    .. math::

        \\dot{Y}_{it} = Y_{it} - \\hat{\\alpha} - \\hat{\\beta} t

    This transformation removes unit-specific linear trends that may violate
    the parallel trends assumption in levels.

    Parameters
    ----------
    unit_data : pd.DataFrame
        Data for a single unit containing all time periods.
    y : str
        Column name of outcome variable.
    tindex : str
        Column name of time index.
    post : str
        Column name of binary post-treatment indicator (0=pre, 1=post).

    Returns
    -------
    yhat_all : ndarray
        Fitted values :math:`\\hat{\\alpha} + \\hat{\\beta} t` for all rows of
        ``unit_data``, in row order. All-NaN if estimation is not possible.
    ydot : ndarray
        Detrended outcomes ``unit_data[y] - yhat_all``. All-NaN if
        estimation is not possible.

    Warns
    -----
    UserWarning
        When estimation is skipped and NaN arrays are returned. This happens if:
        - the pre-treatment time index has (near-)zero variance;
        - fewer than 3 pre-treatment rows have both a valid outcome and a
          valid time index;
        - the time index does not vary across those rows (rank-deficient
          design);
        - OLS yields non-finite coefficients.

    Notes
    -----
    Only pre-treatment rows with non-missing outcome and time index enter
    the fit; these are exactly the rows ``OLS(missing='drop')`` would keep.
    Time is centered at the pre-treatment mean. Because the model has an
    intercept, this shift changes the coefficients but not the fitted values,
    while reducing the condition number of :math:`X'X`.

    The design matrix is built explicitly rather than with
    ``sm.add_constant``. Its default ``has_constant='skip'`` omits the
    intercept whenever the input already contains a constant column.

    See Also
    --------
    _detrend_transform : Apply detrending to all units in panel data.
    """
    unit_pre = unit_data[unit_data[post] == 0].copy()
    n_obs = len(unit_data)

    # Slope estimation requires variation in time; identical values make X'X
    # singular. ddof=0 (population variance) is used because this is a
    # numerical-stability check, not statistical inference.
    t_variance = unit_pre[tindex].var(ddof=0)
    if t_variance < VARIANCE_THRESHOLD:
        warnings.warn(
            "Degenerate time variance detected: all pre-treatment time values are "
            "identical, making OLS slope estimation impossible. Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # OLS requires n > k for positive residual degrees of freedom.
    valid_mask = unit_pre[y].notna() & unit_pre[tindex].notna()
    n_valid = valid_mask.sum()
    n_params = 2  # intercept + slope
    if n_valid <= n_params:
        warnings.warn(
            f"Insufficient valid pre-treatment observations for OLS detrending: "
            f"found {n_valid} valid observations, require at least {n_params + 1} "
            f"(more than {n_params} parameters to ensure df >= 1). Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # Fit on the usable rows only, so the rank check below inspects exactly
    # the matrix the estimator sees.
    valid_pre = unit_pre[valid_mask]
    # Centering reduces the condition number from O(t_max^2) to O(1).
    t_mean = unit_pre[tindex].mean()
    X_pre = np.column_stack([
        np.ones(len(valid_pre)),
        (valid_pre[tindex] - t_mean).values,
    ])
    y_pre = valid_pre[y].values

    # statsmodels' default pinv solver does not fail on a singular design;
    # it silently returns a minimum-norm solution. Detect that case
    # explicitly (possible when all usable rows share one time value).
    if np.linalg.matrix_rank(X_pre) < X_pre.shape[1]:
        warnings.warn(
            "Rank-deficient pre-treatment design for OLS detrending: the time "
            "index does not vary across observations with a valid outcome, so "
            "the slope is not identified. Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    model = sm.OLS(y_pre, X_pre, missing='drop').fit()

    # Non-finite coefficients indicate a numerical failure (presumably from
    # infinite values, which missing='drop' does not remove).
    if not np.isfinite(model.params).all():
        warnings.warn(
            "OLS detrending produced invalid coefficients (NaN or Inf). "
            "This may indicate insufficient time variation or constant "
            "outcome values in pre-treatment data. Returning NaN for "
            "detrended values.",
            UserWarning, stacklevel=3
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # Out-of-sample prediction uses the same centering constant.
    X_all = np.column_stack([
        np.ones(n_obs),
        (unit_data[tindex] - t_mean).values,
    ])
    yhat_all = model.predict(X_all)
    ydot = unit_data[y].values - yhat_all
    return yhat_all, ydot
[docs]
def demeanq_unit(
    unit_data: pd.DataFrame,
    y: str,
    season_var: str,
    post: str,
    Q: int = 4,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Remove unit-specific mean with seasonal fixed effects.

    Estimates a seasonal mean model using pre-treatment observations:

    .. math::

        Y_{it} = \\mu + \\sum_{q=2}^{Q} \\gamma_q D_q + \\varepsilon_{it}

    where :math:`D_q` are seasonal dummies. The smallest season observed in
    the usable pre-treatment rows serves as the reference category.

    Parameters
    ----------
    unit_data : pd.DataFrame
        Data for a single unit containing all time periods.
    y : str
        Column name of outcome variable.
    season_var : str
        Column name of seasonal indicator variable. Values should be integers
        from 1 to Q representing seasonal periods (e.g., quarters 1-4, months
        1-12, or weeks 1-52).
    post : str
        Column name of binary post-treatment indicator (0=pre, 1=post).
    Q : int, default 4
        Number of seasonal periods per cycle (4 quarterly, 12 monthly,
        52 weekly). Not used in the computation: dummy categories are taken
        from the seasons present in the data.

    Returns
    -------
    yhat_all : ndarray
        Fitted values :math:`\\hat{\\mu} + \\sum_q \\hat{\\gamma}_q D_q` for
        all rows of ``unit_data``. All-NaN if estimation is not possible.
    ydot : ndarray
        Seasonally-adjusted demeaned outcomes ``unit_data[y] - yhat_all``.
        All-NaN if estimation is not possible.

    Warns
    -----
    UserWarning
        When NaN arrays are returned. This happens if there are not more
        usable pre-treatment rows than distinct observed seasons, or if OLS
        yields non-finite coefficients.

    Notes
    -----
    Only pre-treatment rows with a non-missing outcome and season enter the
    fit. Building the design from all pre-treatment rows would give rows
    with a missing season all-zero dummies rather than NaN.
    ``missing='drop'`` would then keep those rows as if they belonged to the
    reference season.

    Using observed seasons rather than all Q seasons as categories prevents
    rank-deficient design matrices when some seasons are absent from the
    pre-treatment data. In prediction, some rows get all-zero dummies: those
    whose season is missing, and those whose season was not observed in the
    usable pre-treatment rows. They are therefore predicted at the
    reference-season level.

    The minimum required number of usable pre-treatment observations is one
    more than the number of distinct observed seasons, which ensures at
    least one residual degree of freedom.

    See Also
    --------
    detrendq_unit : Combines seasonal adjustment with linear trend removal.
    """
    unit_pre = unit_data[unit_data[post] == 0].copy()
    n_obs = len(unit_data)

    # OLS requires n > k for positive residual degrees of freedom.
    valid_mask = unit_pre[y].notna() & unit_pre[season_var].notna()
    n_valid = valid_mask.sum()
    n_seasons = unit_pre.loc[valid_mask, season_var].nunique() if n_valid > 0 else 0
    n_params = n_seasons  # intercept + (n_seasons - 1) dummies
    if n_valid <= n_params:
        warnings.warn(
            f"Insufficient valid pre-treatment observations for OLS seasonal demeaning: "
            f"found {n_valid} valid observations with {n_seasons} distinct season(s), "
            f"require at least {n_params + 1} (more than {n_params} parameters). "
            f"Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # Estimate on the usable rows only. A missing season must exclude the
    # row, not silently assign it to the reference season (see Notes).
    valid_pre = unit_pre[valid_mask]
    # Restrict dummy coding to observed seasons to avoid rank deficiency.
    observed_seasons_pre = sorted(valid_pre[season_var].unique())
    s_dummies_pre = pd.get_dummies(
        pd.Categorical(valid_pre[season_var], categories=observed_seasons_pre),
        drop_first=True, prefix='s', dtype=float,
    )
    X_pre = np.column_stack([
        np.ones(len(valid_pre)),
        s_dummies_pre.to_numpy(dtype=float),
    ])
    y_pre = valid_pre[y].values
    model = sm.OLS(y_pre, X_pre, missing='drop').fit()

    # Non-finite coefficients indicate a numerical failure (presumably from
    # infinite outcome values, which missing='drop' does not remove).
    if not np.isfinite(model.params).all():
        warnings.warn(
            "OLS seasonal demeaning produced invalid coefficients (NaN or Inf). "
            "This may indicate constant outcome values or collinearity in "
            "pre-treatment data. Returning NaN for transformed values.",
            UserWarning, stacklevel=3
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # The prediction design must have the same dummy columns, in the same
    # order, as the estimation design. Fixed Categorical categories should
    # already guarantee this; reindex makes it explicit, filling any absent
    # column with zeros and dropping any extra one.
    s_dummies_all = pd.get_dummies(
        pd.Categorical(unit_data[season_var], categories=observed_seasons_pre),
        drop_first=True, prefix='s', dtype=float,
    ).reindex(columns=s_dummies_pre.columns, fill_value=0.0)
    X_all = np.column_stack([
        np.ones(n_obs),
        s_dummies_all.to_numpy(dtype=float),
    ])
    yhat_all = model.predict(X_all)
    ydot = unit_data[y].values - yhat_all
    return yhat_all, ydot
[docs]
def detrendq_unit(
    unit_data: pd.DataFrame,
    y: str,
    tindex: str,
    season_var: str,
    post: str,
    Q: int = 4,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Remove unit-specific linear trend with seasonal fixed effects.

    Estimates a combined trend and seasonal model using pre-treatment data:

    .. math::

        Y_{it} = \\alpha + \\beta t + \\sum_{q=2}^{Q} \\gamma_q D_q + \\varepsilon_{it}

    The smallest season observed in the usable pre-treatment rows serves as
    the reference category. Time is centered at its pre-treatment mean for
    numerical stability.

    Parameters
    ----------
    unit_data : pd.DataFrame
        Data for a single unit containing all time periods.
    y : str
        Column name of outcome variable.
    tindex : str
        Column name of time index.
    season_var : str
        Column name of seasonal indicator variable. Values should be integers
        from 1 to Q representing seasonal periods (e.g., quarters 1-4, months
        1-12, or weeks 1-52).
    post : str
        Column name of binary post-treatment indicator (0=pre, 1=post).
    Q : int, default 4
        Number of seasonal periods per cycle (4 quarterly, 12 monthly,
        52 weekly). Not used in the computation: dummy categories are taken
        from the seasons present in the data.

    Returns
    -------
    yhat_all : ndarray
        Fitted values for all rows of ``unit_data``. All-NaN if estimation
        is not possible.
    ydot : ndarray
        Seasonally-adjusted detrended outcomes ``unit_data[y] - yhat_all``.
        All-NaN if estimation is not possible.

    Warns
    -----
    UserWarning
        When NaN arrays are returned. This happens if:
        - the pre-treatment time index has (near-)zero variance;
        - there are too few usable pre-treatment rows (at least S + 2 are
          needed for S distinct observed seasons);
        - the time index is collinear with the intercept and season dummies
          (rank-deficient design);
        - OLS yields non-finite coefficients.

    Notes
    -----
    This transformation combines trend removal and seasonal adjustment,
    accounting for both unit-specific growth patterns and seasonal cycles.
    Time centering reduces the condition number of the design matrix without
    affecting predicted values.

    Only pre-treatment rows with non-missing outcome, time index and season
    enter the fit. Rows with a missing season would otherwise get all-zero
    dummies (not NaN) and be treated as the reference season by
    ``missing='drop'``. In prediction, some rows get all-zero dummies: those
    whose season is missing, and those whose season was not observed in the
    usable pre-treatment rows. They are therefore predicted at the
    reference-season level.

    The minimum number of usable pre-treatment observations is S + 2, where
    S is the number of distinct observed seasons. The model has S + 1
    parameters (intercept + slope + S-1 seasonal dummies), so this leaves
    at least one residual degree of freedom.

    See Also
    --------
    demeanq_unit : Seasonal adjustment without trend removal.
    detrend_unit : Linear trend removal without seasonal adjustment.
    """
    unit_pre = unit_data[unit_data[post] == 0].copy()
    n_obs = len(unit_data)

    # Slope estimation requires variation in time; identical values make X'X
    # singular. ddof=0 (population variance) is used because this is a
    # numerical-stability check, not statistical inference.
    t_variance = unit_pre[tindex].var(ddof=0)
    if t_variance < VARIANCE_THRESHOLD:
        warnings.warn(
            "Degenerate time variance detected: all pre-treatment time values are "
            "identical, making OLS slope estimation impossible. Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # OLS requires n > k for positive residual degrees of freedom.
    valid_mask = unit_pre[y].notna() & unit_pre[tindex].notna() & unit_pre[season_var].notna()
    n_valid = valid_mask.sum()
    n_seasons = unit_pre.loc[valid_mask, season_var].nunique() if n_valid > 0 else 0
    n_params = 1 + n_seasons  # intercept + slope + (n_seasons - 1) dummies
    if n_valid <= n_params:
        warnings.warn(
            f"Insufficient valid pre-treatment observations for OLS seasonal detrending: "
            f"found {n_valid} valid observations with {n_seasons} distinct season(s), "
            f"require at least {n_params + 1} (more than {n_params} parameters to ensure "
            f"df >= 1). Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # Center time at the pre-treatment mean to improve numerical stability.
    # This reduces the condition number of X'X from O(t_max^2) to O(1).
    t_mean = unit_pre[tindex].mean()

    # Estimate on the usable rows only. A missing season must exclude the
    # row, not silently assign it to the reference season (see Notes).
    valid_pre = unit_pre[valid_mask]
    # Restrict dummy coding to observed seasons to avoid rank deficiency;
    # unobserved seasons would create all-zero columns in the design matrix.
    observed_seasons_pre = sorted(valid_pre[season_var].unique())
    s_dummies_pre = pd.get_dummies(
        pd.Categorical(valid_pre[season_var], categories=observed_seasons_pre),
        drop_first=True, prefix='s', dtype=float,
    )
    X_pre = np.column_stack([
        np.ones(len(valid_pre)),
        (valid_pre[tindex] - t_mean).values,
        s_dummies_pre.to_numpy(dtype=float),
    ])
    y_pre = valid_pre[y].values

    # statsmodels' default pinv solver does not fail on a singular design;
    # it silently returns a minimum-norm solution. Detect collinearity
    # explicitly (e.g. time constant within each observed season).
    if np.linalg.matrix_rank(X_pre) < X_pre.shape[1]:
        warnings.warn(
            "Rank-deficient pre-treatment design for OLS seasonal detrending: "
            "the time index is collinear with the intercept and season dummies "
            "among valid observations, so the coefficients are not identified. "
            "Returning NaN.",
            UserWarning, stacklevel=2
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    model = sm.OLS(y_pre, X_pre, missing='drop').fit()

    # Non-finite coefficients indicate a numerical failure (presumably from
    # infinite values, which missing='drop' does not remove).
    if not np.isfinite(model.params).all():
        warnings.warn(
            "OLS seasonal detrending produced invalid coefficients (NaN or Inf). "
            "This may indicate insufficient time variation, constant outcome values, "
            "or collinearity in pre-treatment data. Returning NaN for transformed "
            "values.",
            UserWarning, stacklevel=3
        )
        return np.full(n_obs, np.nan), np.full(n_obs, np.nan)

    # The prediction design must match the estimation design: same centering
    # constant, same dummy columns in the same order. reindex fills any
    # absent dummy column with zeros and drops any extra one.
    s_dummies_all = pd.get_dummies(
        pd.Categorical(unit_data[season_var], categories=observed_seasons_pre),
        drop_first=True, prefix='s', dtype=float,
    ).reindex(columns=s_dummies_pre.columns, fill_value=0.0)
    X_all = np.column_stack([
        np.ones(n_obs),
        (unit_data[tindex] - t_mean).values,
        s_dummies_all.to_numpy(dtype=float),
    ])
    yhat_all = model.predict(X_all)
    ydot = unit_data[y].values - yhat_all
    return yhat_all, ydot