"""
Control group selection for staggered difference-in-differences.
This module provides functions for identifying and validating control
groups in staggered adoption designs. Three control group strategies
are supported:
- **Never-Treated (NT)**: Units that never receive treatment throughout
the observation period. These form a stable comparison group whose
composition does not change across calendar time periods.
- **Not-Yet-Treated (NYT)**: Units scheduled for future treatment but
not yet treated at the current period. Under conditional parallel
trends and no anticipation, these units provide valid controls before
their own treatment begins, expanding the control pool for improved
estimation efficiency.
- **All Others**: All units not in the focal treatment cohort, including
already-treated units. This strategy may introduce forbidden
comparisons and is generally not recommended for identification, but
is provided for replication and diagnostic purposes.
For estimating the ATT of cohort g in period r, valid controls are units
that are still untreated in period r. These are the never-treated units
(by default, those whose gvar is missing, 0, or infinity) together with
units whose first treatment period is strictly greater than r (gvar > r).
Notes
-----
The strict inequality criterion (gvar > r rather than gvar >= r) is
fundamental to correct identification. Units beginning treatment in
period r (i.e., gvar == r) belong to the treatment group for that
period, not the control group. This ensures that comparisons are made
only against units that remain untreated throughout period r.
"""
from __future__ import annotations
from enum import Enum
import numpy as np
import pandas as pd
[docs]
class ControlGroupStrategy(Enum):
    """
    Rule deciding which units may act as controls for a (cohort, period) cell.

    Every member's value is a lowercase string. A strategy can therefore be
    recovered from configuration text, e.g. ``ControlGroupStrategy('auto')``.
    The chosen rule determines both the comparison group and the
    identifying assumptions the resulting estimate rests on.

    Attributes
    ----------
    NEVER_TREATED : str
        Only units that are never treated during the observation window.
        The comparison group is identical in every period, which is what
        aggregated (cohort-level or overall) estimation expects.
    NOT_YET_TREATED : str
        Never-treated units plus units whose first treatment period lies
        strictly after the period being estimated. Later cohorts act as
        temporary controls, enlarging the pool under conditional parallel
        trends and no anticipation.
    ALL_OTHERS : str
        Every unit outside the focal cohort, including units already
        treated in earlier cohorts. This can create forbidden comparisons,
        so it is meant for replication and diagnostics, not identification.
    AUTO : str
        Data-driven choice: use the not-yet-treated pool and fall back to
        never-treated units only when that pool is empty.

    See Also
    --------
    get_valid_control_units : Apply strategy to select control units.
    count_control_units_by_strategy : Compare control counts across strategies.
    """

    # Fixed comparison group across all calendar periods.
    NEVER_TREATED = 'never_treated'
    # Never-treated units plus later cohorts that are still untreated.
    NOT_YET_TREATED = 'not_yet_treated'
    # Everything except the focal cohort (may include already-treated units).
    ALL_OTHERS = 'all_others'
    # Let the data decide between the strategies above.
    AUTO = 'auto'
[docs]
def identify_never_treated_units(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    never_treated_values: list | None = None,
) -> pd.Series:
    """
    Flag, per unit, whether the unit is never treated.

    The panel is collapsed to one treatment-timing value per unit (the
    first non-missing ``gvar`` value of each ``ivar`` group). A unit is
    flagged when that value is missing or equals one of the never-treated
    sentinel values.

    Parameters
    ----------
    data : pd.DataFrame
        Panel dataset containing unit and treatment timing information.
    gvar : str
        Column name indicating first treatment period for each unit.
    ivar : str
        Column name containing unit identifiers.
    never_treated_values : list, optional
        Sentinel values in ``gvar`` that mean "never treated". ``None``
        selects ``[0, np.inf]``. An empty list disables sentinel matching,
        but missing values (NaN) are always flagged.

    Returns
    -------
    pd.Series
        Boolean Series indexed by unit ID; True marks a never-treated unit.

    Raises
    ------
    ValueError
        If ``data`` has no rows.
    KeyError
        If ``gvar`` or ``ivar`` is not a column of ``data``.

    See Also
    --------
    has_never_treated_units : Check presence of never-treated units.
    get_valid_control_units : Select control units for estimation.

    Notes
    -----
    A unit counts as never treated when its timing value is:

    1. Missing (NaN), i.e. no recorded treatment date.
    2. Zero, a common coding convention for never-treated units.
    3. Infinity, i.e. treatment beyond the observation window.

    Cases 2 and 3 apply only with the default sentinel list.
    """
    if not len(data):
        raise ValueError("Input data is empty")
    for column in (gvar, ivar):
        if column not in data.columns:
            raise KeyError(f"Column '{column}' not found in data")

    # Panel rows repeat per unit; keep a single timing value per unit.
    gvar_by_unit = data.groupby(ivar)[gvar].first()
    sentinels = [0, np.inf] if never_treated_values is None else never_treated_values

    # Missing timing always means never treated, independent of sentinels.
    is_never = gvar_by_unit.isna()
    if len(sentinels) > 0:
        is_never = is_never | gvar_by_unit.isin(sentinels)
    return is_never
[docs]
def has_never_treated_units(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    never_treated_values: list | None = None,
) -> bool:
    """
    Check whether the data contains any never-treated units.

    A convenience function for quickly determining if a never-treated
    control group is available for estimation. This is particularly
    useful for deciding whether aggregated effects can be estimated.

    Parameters
    ----------
    data : pd.DataFrame
        Panel dataset containing unit and treatment timing information.
    gvar : str
        Column name indicating first treatment period for each unit.
    ivar : str
        Column name containing unit identifiers.
    never_treated_values : list, optional
        Values in gvar indicating never-treated status. Defaults to
        [0, np.inf]. Units with NaN in gvar always count as never-treated.

    Returns
    -------
    bool
        True if at least one never-treated unit exists. This is a plain
        Python ``bool``, not a NumPy boolean scalar.

    Raises
    ------
    ValueError
        If the input data is empty.
    KeyError
        If gvar or ivar column is not found in the data.

    See Also
    --------
    identify_never_treated_units : Get full mask of never-treated units.
    validate_control_group : Validate control group for aggregation.
    """
    nt_mask = identify_never_treated_units(data, gvar, ivar, never_treated_values)
    # Series.any() yields numpy.bool_. Coerce it so the declared ``bool``
    # return type holds for identity checks (``is True``) and JSON encoding.
    return bool(nt_mask.any())
[docs]
def get_valid_control_units(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    cohort: int | float,
    period: int | float,
    strategy: ControlGroupStrategy = ControlGroupStrategy.NOT_YET_TREATED,
    never_treated_values: list | None = None,
    is_pre_treatment: bool = False,
) -> pd.Series:
    """
    Determine valid control units for a specific cohort-period pair.

    For estimating the ATT of cohort g in period r, identifies which
    units can serve as valid controls. The result depends on the selected
    strategy and on the strict inequality criterion (gvar > r).

    Parameters
    ----------
    data : pd.DataFrame
        Panel dataset containing unit and treatment timing information.
    gvar : str
        Column name indicating first treatment period for each unit.
    ivar : str
        Column name containing unit identifiers.
    cohort : int or float
        Treatment cohort (first treatment period g) of the treated group.
    period : int or float
        Calendar time period r for which to identify controls.
        For post-treatment: must satisfy period >= cohort.
        For pre-treatment: must satisfy period < cohort.
    strategy : ControlGroupStrategy or str, default NOT_YET_TREATED
        Strategy for selecting control units. The enum's string value
        (e.g. ``'never_treated'``) is accepted as well.
    never_treated_values : list, optional
        Values in gvar indicating never-treated status. Defaults to
        [0, np.inf]. Units with NaN in gvar always count as never-treated.
    is_pre_treatment : bool, default False
        If True, selects control units for pre-treatment period estimation
        (parallel trends testing). For pre-treatment periods t < g, the
        control group includes all units not yet treated at period t.

    Returns
    -------
    pd.Series
        Boolean Series indexed by unit ID where True indicates valid control.

    Raises
    ------
    ValueError
        If data is empty, period constraints are violated, or ``strategy``
        is not a recognized control group strategy.
    KeyError
        If required columns are not found.
    TypeError
        If gvar column is not numeric.

    See Also
    --------
    get_all_control_masks : Batch computation for multiple cohort-period pairs.
    get_all_control_masks_pre : Batch computation for pre-treatment periods.
    validate_control_group : Validate control group size requirements.

    Notes
    -----
    The strict inequality criterion (gvar > period) is fundamental:

    - Units with gvar == period begin treatment in period r and
      thus belong to the treatment group, not the control group.
    - This ensures valid controls have not yet been exposed to treatment.

    For post-treatment estimation (period >= cohort), the treatment cohort
    is automatically excluded. Cohort units have gvar == cohort <= period,
    so they fail the gvar > period criterion.

    For pre-treatment estimation (period < cohort), the treatment cohort
    is included as controls, because period < cohort implies
    gvar (== cohort) > period. At pre-treatment periods these units are
    not yet treated and serve as comparisons for parallel trends
    assessment.

    Under AUTO, the never-treated fallback only applies when the
    not-yet-treated pool is empty. The never-treated units are a subset
    of that pool, so the fallback is then empty too. AUTO therefore
    always yields the same mask as NOT_YET_TREATED.
    """
    # -------------------------------------------------------------------------
    # Input Validation
    # -------------------------------------------------------------------------
    if len(data) == 0:
        raise ValueError("Input data is empty")
    if gvar not in data.columns:
        raise KeyError(f"Column '{gvar}' not found in data")
    if ivar not in data.columns:
        raise KeyError(f"Column '{ivar}' not found in data")

    # Require numeric gvar for comparison operations
    if not pd.api.types.is_numeric_dtype(data[gvar]):
        raise TypeError(
            f"gvar column '{gvar}' must be numeric, got {data[gvar].dtype}. "
            f"String values like 'never' or '2005' are not supported."
        )

    # Convert to float for consistent comparison across int/float inputs
    cohort_f = float(cohort)
    period_f = float(period)

    # Validate period constraints based on pre/post treatment context
    if is_pre_treatment:
        if period_f >= cohort_f:
            raise ValueError(
                f"For pre-treatment estimation, period ({period}) must be < cohort ({cohort}). "
                f"Pre-treatment effects are only defined for periods t < g."
            )
    elif period_f < cohort_f:
        raise ValueError(
            f"period ({period}) must be >= cohort ({cohort}). "
            f"Treatment effects are only defined for periods r >= g."
        )

    # Normalize the strategy. Previously any non-member (including valid
    # string values such as 'never_treated') silently fell through to AUTO.
    try:
        strategy = ControlGroupStrategy(strategy)
    except ValueError:
        valid = ", ".join(repr(member.value) for member in ControlGroupStrategy)
        raise ValueError(
            f"Unknown control group strategy {strategy!r}; expected one of {valid}."
        ) from None

    # -------------------------------------------------------------------------
    # Identify Never-Treated Units
    # -------------------------------------------------------------------------
    # One timing value per unit (panel rows repeat per unit)
    unit_gvar = data.groupby(ivar)[gvar].first()
    nt_values = [0, np.inf] if never_treated_values is None else never_treated_values

    # NaN always means never-treated; sentinels add to that
    never_treated_mask = unit_gvar.isna()
    if len(nt_values) > 0:
        never_treated_mask = never_treated_mask | unit_gvar.isin(nt_values)

    # -------------------------------------------------------------------------
    # Build Control Mask by Strategy
    # -------------------------------------------------------------------------
    if strategy is ControlGroupStrategy.NEVER_TREATED:
        return never_treated_mask
    if strategy is ControlGroupStrategy.ALL_OTHERS:
        # Everything except the focal cohort; may include already-treated units
        return unit_gvar != cohort_f

    # NOT_YET_TREATED and AUTO (see Notes): never-treated plus units first
    # treated strictly after the current period.
    return never_treated_mask | (unit_gvar > period_f)
[docs]
def get_all_control_masks(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    cohorts: list[int | float],
    T_max: int | float,
    T_min: int | float | None = None,
    strategy: ControlGroupStrategy = ControlGroupStrategy.NOT_YET_TREATED,
    never_treated_values: list | None = None,
) -> dict[tuple[int | float, int | float], pd.Series]:
    """
    Compute control group masks for all cohort-period combinations.

    Generates control masks for multiple cohort-period pairs while
    computing the unit-level treatment timing and the never-treated mask
    only once. This avoids redundant groupby operations when estimating
    effects across many (cohort, period) combinations.

    Parameters
    ----------
    data : pd.DataFrame
        Panel dataset containing unit and treatment timing information.
    gvar : str
        Column name indicating first treatment period for each unit.
    ivar : str
        Column name containing unit identifiers.
    cohorts : list of int or float
        Treatment cohorts for which to generate control masks. Non-finite
        entries (NaN, +/-inf, e.g. never-treated sentinels) are skipped.
    T_max : int or float
        Maximum time period to consider (inclusive). Must be finite.
    T_min : int or float, optional
        Minimum time period. Reserved for future extension (unused).
    strategy : ControlGroupStrategy or str, default NOT_YET_TREATED
        Strategy for selecting control units. The enum's string value is
        accepted as well.
    never_treated_values : list, optional
        Values in gvar indicating never-treated status. Defaults to
        [0, np.inf].

    Returns
    -------
    dict
        Dictionary mapping (cohort, period) tuples to boolean Series
        indexed by unit ID. True indicates valid control status. The
        cohort is the value passed in ``cohorts``; the period is a float.

    Raises
    ------
    ValueError
        If the input data is empty, ``T_max`` is not finite, or
        ``strategy`` is not recognized.
    KeyError
        If gvar or ivar column is not found in the data.

    See Also
    --------
    get_valid_control_units : Single cohort-period control mask.
    get_all_control_masks_pre : Batch computation for pre-treatment periods.

    Notes
    -----
    For each cohort g, masks are generated for post-treatment periods
    {g, g+1, ..., T_max}. The never-treated mask is computed once and
    reused across all cohort-period pairs. Not-yet-treated masks vary by
    period because of the strict inequality criterion (gvar > r). AUTO
    yields the same masks as NOT_YET_TREATED: its never-treated fallback
    only applies when the not-yet-treated pool is empty, and the
    never-treated units are a subset of that pool.
    """
    if len(data) == 0:
        raise ValueError("Input data is empty")
    if gvar not in data.columns:
        raise KeyError(f"Column '{gvar}' not found in data")
    if ivar not in data.columns:
        raise KeyError(f"Column '{ivar}' not found in data")

    T_max_f = float(T_max)
    # An infinite upper bound would make the period loop below never end.
    if not np.isfinite(T_max_f):
        raise ValueError(f"T_max must be finite, got {T_max!r}.")

    # Accept string values and reject unknown strategies instead of
    # silently treating them as AUTO.
    try:
        strategy = ControlGroupStrategy(strategy)
    except ValueError:
        valid = ", ".join(repr(member.value) for member in ControlGroupStrategy)
        raise ValueError(
            f"Unknown control group strategy {strategy!r}; expected one of {valid}."
        ) from None

    # -------------------------------------------------------------------------
    # Pre-compute Shared Data Structures
    # -------------------------------------------------------------------------
    # Extract unit-level gvar once to avoid repeated groupby operations
    unit_gvar = data.groupby(ivar)[gvar].first()

    # Build never-treated mask (constant across all periods)
    nt_values = [0, np.inf] if never_treated_values is None else never_treated_values
    never_treated_mask = unit_gvar.isna()
    if len(nt_values) > 0:
        never_treated_mask = never_treated_mask | unit_gvar.isin(nt_values)

    # -------------------------------------------------------------------------
    # Generate Masks for Each (cohort, period) Pair
    # -------------------------------------------------------------------------
    results = {}
    for g in cohorts:
        g_f = float(g)
        # Non-finite cohorts have no post-treatment periods; -inf would
        # otherwise loop forever.
        if not np.isfinite(g_f):
            continue
        # Iterate over post-treatment periods: g, g+1, ..., T_max
        r = g_f
        while r <= T_max_f:
            if strategy is ControlGroupStrategy.NEVER_TREATED:
                # Copy so every entry is an independent Series
                control_mask = never_treated_mask.copy()
            elif strategy is ControlGroupStrategy.ALL_OTHERS:
                # All non-cohort units (may include already-treated)
                control_mask = unit_gvar != g_f
            else:  # NOT_YET_TREATED or AUTO
                control_mask = never_treated_mask | (unit_gvar > r)
            results[(g, r)] = control_mask
            r += 1
    return results
[docs]
def get_all_control_masks_pre(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    cohorts: list[int | float],
    T_min: int | float,
    strategy: ControlGroupStrategy = ControlGroupStrategy.NOT_YET_TREATED,
    never_treated_values: list | None = None,
) -> dict[tuple[int | float, int | float], pd.Series]:
    """
    Compute control group masks for all pre-treatment cohort-period combinations.

    Generates control masks for pre-treatment periods while computing the
    unit-level treatment timing and the never-treated mask only once.
    Used for parallel trends testing and event study visualization, where
    pre-treatment effects should be approximately zero under the
    identifying assumptions.

    Parameters
    ----------
    data : pd.DataFrame
        Panel dataset containing unit and treatment timing information.
    gvar : str
        Column name indicating first treatment period for each unit.
    ivar : str
        Column name containing unit identifiers.
    cohorts : list of int or float
        Treatment cohorts for which to generate control masks. Non-finite
        entries (NaN, +/-inf, e.g. never-treated sentinels) are skipped.
    T_min : int or float
        Minimum time period in the data (inclusive). Must be finite.
    strategy : ControlGroupStrategy or str, default NOT_YET_TREATED
        Strategy for selecting control units. The enum's string value is
        accepted as well.
    never_treated_values : list, optional
        Values in gvar indicating never-treated status. Defaults to
        [0, np.inf].

    Returns
    -------
    dict
        Dictionary mapping (cohort, period) tuples to boolean Series
        indexed by unit ID. True indicates valid control status. The
        cohort is the value passed in ``cohorts``; the period is a float.

    Raises
    ------
    ValueError
        If the input data is empty, ``T_min`` is not finite, or
        ``strategy`` is not recognized.
    KeyError
        If gvar or ivar column is not found in the data.

    See Also
    --------
    get_valid_control_units : Single cohort-period control mask.
    get_all_control_masks : Batch computation for post-treatment periods.

    Notes
    -----
    For each cohort g, masks are generated for pre-treatment periods
    {T_min, T_min+1, ..., g-1}. At a pre-treatment period t < g, the
    focal treatment cohort (gvar == g) is included as controls because
    these units are not yet treated. AUTO yields the same masks as
    NOT_YET_TREATED: its never-treated fallback only applies when the
    not-yet-treated pool is empty, and the never-treated units are a
    subset of that pool.
    """
    if len(data) == 0:
        raise ValueError("Input data is empty")
    if gvar not in data.columns:
        raise KeyError(f"Column '{gvar}' not found in data")
    if ivar not in data.columns:
        raise KeyError(f"Column '{ivar}' not found in data")

    T_min_f = float(T_min)
    # T_min = -inf stays -inf after "+= 1" and would loop forever.
    if not np.isfinite(T_min_f):
        raise ValueError(f"T_min must be finite, got {T_min!r}.")

    # Accept string values and reject unknown strategies instead of
    # silently treating them as AUTO.
    try:
        strategy = ControlGroupStrategy(strategy)
    except ValueError:
        valid = ", ".join(repr(member.value) for member in ControlGroupStrategy)
        raise ValueError(
            f"Unknown control group strategy {strategy!r}; expected one of {valid}."
        ) from None

    # -------------------------------------------------------------------------
    # Pre-compute Shared Data Structures
    # -------------------------------------------------------------------------
    # Extract unit-level gvar once to avoid repeated groupby operations
    unit_gvar = data.groupby(ivar)[gvar].first()

    # Build never-treated mask (constant across all periods)
    nt_values = [0, np.inf] if never_treated_values is None else never_treated_values
    never_treated_mask = unit_gvar.isna()
    if len(nt_values) > 0:
        never_treated_mask = never_treated_mask | unit_gvar.isin(nt_values)

    # -------------------------------------------------------------------------
    # Generate Masks for Each (cohort, period) Pair
    # -------------------------------------------------------------------------
    results = {}
    for g in cohorts:
        g_f = float(g)
        # An infinite cohort (the default never-treated sentinel) previously
        # made "while t < g" run forever; such cohorts have no pre-periods.
        if not np.isfinite(g_f):
            continue
        # Iterate over pre-treatment periods: T_min, T_min+1, ..., g-1
        t = T_min_f
        while t < g_f:
            if strategy is ControlGroupStrategy.NEVER_TREATED:
                # Copy so every entry is an independent Series
                control_mask = never_treated_mask.copy()
            elif strategy is ControlGroupStrategy.ALL_OTHERS:
                # All non-cohort units (may include already-treated)
                control_mask = unit_gvar != g_f
            else:  # NOT_YET_TREATED or AUTO: units not yet treated at t
                control_mask = never_treated_mask | (unit_gvar > t)
            results[(g, t)] = control_mask
            t += 1
    return results
[docs]
def validate_control_group(
    control_mask: pd.Series,
    cohort: int | float,
    period: int | float,
    min_control_units: int = 1,
    aggregate_type: str | None = None,
    has_never_treated: bool = True,
    strategy: ControlGroupStrategy | None = None,
) -> tuple[bool, str]:
    """
    Decide whether a control group is usable for estimation.

    Runs a fixed sequence of checks on the control mask and reports the
    first failure. If every check passes, it reports success, with a
    recommendation attached when an aggregated estimate uses a strategy
    other than never-treated.

    Parameters
    ----------
    control_mask : pd.Series
        Boolean Series indexed by unit ID; True marks a control unit.
    cohort : int or float
        Treatment cohort being estimated (used in messages).
    period : int or float
        Time period being estimated (used in messages).
    min_control_units : int, default 1
        Smallest acceptable number of control units.
    aggregate_type : str, optional
        ``'cohort'`` or ``'overall'`` enables the aggregation checks. Any
        other value, including None, skips them.
    has_never_treated : bool, default True
        Whether the data contains any never-treated units.
    strategy : ControlGroupStrategy, optional
        Strategy in use. For aggregated estimation, anything other than
        ``ControlGroupStrategy.NEVER_TREATED`` adds a warning to the
        success message.

    Returns
    -------
    is_valid : bool
        True when all checks pass.
    message : str
        Success description (possibly with a warning) or failure reason.

    See Also
    --------
    get_valid_control_units : Generate control group masks.
    has_never_treated_units : Check for never-treated unit availability.

    Notes
    -----
    Checks run in this order and stop at the first failure:

    1. The control group must be non-empty.
    2. It must contain at least ``min_control_units`` units.
    3. For cohort-level or overall aggregation, never-treated units must
       exist. Not-yet-treated units leave the control group once treated,
       so they cannot provide one comparison group across periods.
    """
    n_controls = control_mask.sum()
    location = f"cohort={cohort}, period={period}"

    if n_controls == 0:
        return False, f"No control units found for {location}."
    if n_controls < min_control_units:
        return False, (
            f"Insufficient control units for {location}. "
            f"Found {n_controls}, required {min_control_units}."
        )

    aggregated = aggregate_type in ('cohort', 'overall')
    if aggregated and not has_never_treated:
        return False, (
            f"Cannot estimate {aggregate_type} effects without never-treated units. "
            f"When all units are eventually treated, only cohort-period-specific "
            f"effects can be estimated using not-yet-treated controls."
        )
    non_recommended = (
        strategy is not None and strategy != ControlGroupStrategy.NEVER_TREATED
    )
    if aggregated and non_recommended:
        # Still valid, but steer users toward the stable comparison group.
        return True, (
            f"Valid control group with {n_controls} units. "
            f"Warning: For {aggregate_type} effect estimation, it is recommended "
            f"to use 'never_treated' control group strategy for robustness."
        )
    return True, f"Valid control group with {n_controls} units for {location}."
[docs]
def count_control_units_by_strategy(
    data: pd.DataFrame,
    gvar: str,
    ivar: str,
    cohort: int | float,
    period: int | float,
    never_treated_values: list | None = None,
) -> dict[str, int]:
    """
    Report how many units each control group strategy would provide.

    A diagnostic helper for a single (cohort, period) cell. It lets users
    compare the size of the never-treated and not-yet-treated pools before
    choosing a strategy.

    Parameters
    ----------
    data : pd.DataFrame
        Panel dataset containing unit and treatment timing information.
    gvar : str
        Column name indicating first treatment period for each unit.
    ivar : str
        Column name containing unit identifiers.
    cohort : int or float
        Treatment cohort of interest.
    period : int or float
        Time period of interest.
    never_treated_values : list, optional
        Sentinel values in ``gvar`` meaning "never treated". ``None``
        selects ``[0, np.inf]``. NaN always counts as never treated.

    Returns
    -------
    dict
        Mapping with integer counts:

        - ``'never_treated'``: never-treated units.
        - ``'not_yet_treated_only'``: units first treated strictly after
          ``period``, excluding never-treated units.
        - ``'not_yet_treated_total'``: sum of the two counts above, i.e.
          the size of the not-yet-treated control pool.
        - ``'treatment_cohort'``: units whose timing equals ``cohort``.

    Raises
    ------
    ValueError
        If the input data is empty.

    See Also
    --------
    get_valid_control_units : Generate control masks for estimation.
    ControlGroupStrategy : Available control group selection strategies.

    Notes
    -----
    Strict inequality (gvar > period) keeps units that start treatment in
    ``period`` out of the not-yet-treated count.
    """
    if not len(data):
        raise ValueError("Input data is empty")

    # One timing value per unit; panel rows repeat per unit.
    gvar_by_unit = data.groupby(ivar)[gvar].first()
    sentinels = [0, np.inf] if never_treated_values is None else never_treated_values
    # Floats give consistent comparisons for int and float inputs.
    cohort_value, period_value = float(cohort), float(period)

    is_never = gvar_by_unit.isna()
    if len(sentinels) > 0:
        is_never = is_never | gvar_by_unit.isin(sentinels)
    # Later cohorts only; never-treated units are tallied separately.
    is_future_only = ~is_never & (gvar_by_unit > period_value)
    in_cohort = gvar_by_unit == cohort_value

    n_never = int(is_never.sum())
    n_future_only = int(is_future_only.sum())
    return {
        'never_treated': n_never,
        'not_yet_treated_only': n_future_only,
        'not_yet_treated_total': n_never + n_future_only,
        'treatment_cohort': int(in_cohort.sum()),
    }