# Source code for lwdid.clustering_diagnostics

"""
Clustering diagnostics and recommendations for difference-in-differences.

This module provides tools for analyzing clustering structure in panel data
and recommending appropriate clustering levels for standard error estimation
in difference-in-differences analysis.

When the policy or treatment varies at a level higher than the unit of
observation, standard errors should be clustered at the policy variation
level. This module helps identify the appropriate clustering level by:

- Analyzing hierarchical relationships between potential clustering variables
- Detecting the level at which treatment assignment varies
- Recommending clustering variables with sufficient cluster counts
- Checking consistency between clustering choice and treatment variation

For reliable cluster-robust inference, a minimum of 20-30 clusters is
generally recommended. When clusters are fewer, wild cluster bootstrap
methods provide more accurate inference.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats


# =============================================================================
# Enumeration Types
# =============================================================================

class ClusteringLevel(Enum):
    """
    Position of a clustering variable in the hierarchy, relative to the unit.

    Attributes
    ----------
    LOWER : str
        The clustering variable is finer than the unit variable, which makes
        it invalid for clustering (e.g. a sub-unit ID when the unit is an
        individual).
    SAME : str
        The clustering variable is as fine as the unit variable (e.g. an
        individual ID when the unit is an individual).
    HIGHER : str
        The clustering variable is coarser than the unit variable; this is
        the recommended situation (e.g. state when the unit is county).
    """

    LOWER = "lower"
    SAME = "same"
    HIGHER = "higher"
class ClusteringWarningLevel(Enum):
    """
    How serious a clustering warning is.

    Attributes
    ----------
    INFO : str
        Purely informational; nothing needs to be done.
    WARNING : str
        Something that may reduce the reliability of inference.
    ERROR : str
        A critical problem that prevents valid inference.
    """

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
# =============================================================================
# Data Classes
# =============================================================================
@dataclass
class ClusterVarStats:
    """
    Summary statistics describing one candidate clustering variable.

    Instances bundle the cluster counts, the cluster-size distribution and
    the flags needed to judge whether the variable can be used for
    cluster-robust inference.

    Attributes
    ----------
    var_name : str
        Name of the clustering variable.
    n_clusters : int
        Total number of unique clusters.
    n_treated_clusters : int
        Number of clusters containing treated units.
    n_control_clusters : int
        Number of clusters containing only control units.
    min_cluster_size : int
        Smallest number of observations in any cluster.
    max_cluster_size : int
        Largest number of observations in any cluster.
    mean_cluster_size : float
        Mean cluster size.
    median_cluster_size : float
        Median cluster size.
    cluster_size_cv : float
        Coefficient of variation of cluster sizes (std/mean).
    level_relative_to_unit : ClusteringLevel
        Whether the cluster is at a higher/same/lower level than the unit.
    units_per_cluster : float
        Average number of unique units per cluster.
    is_nested_in_unit : bool
        True if the cluster varies within unit (invalid for clustering).
    treatment_varies_within_cluster : bool
        True if treatment status varies within clusters.
    n_clusters_with_treatment_variation : int
        Number of clusters with within-cluster treatment variation.

    Notes
    -----
    Three derived read-only properties are provided: ``is_valid_cluster``,
    ``is_recommended`` and ``reliability_score`` (a value in [0, 1]).
    """

    var_name: str
    n_clusters: int
    n_treated_clusters: int
    n_control_clusters: int
    min_cluster_size: int
    max_cluster_size: int
    mean_cluster_size: float
    median_cluster_size: float
    cluster_size_cv: float
    level_relative_to_unit: ClusteringLevel
    units_per_cluster: float
    is_nested_in_unit: bool
    treatment_varies_within_cluster: bool
    n_clusters_with_treatment_variation: int = 0

    @property
    def is_valid_cluster(self) -> bool:
        """
        Tell whether the variable may be used for clustering at all.

        The variable is valid when it is not nested within units, defines at
        least two clusters, and is not at a lower level than the unit.

        Returns
        -------
        bool
            True if valid for clustering.
        """
        if self.is_nested_in_unit:
            return False
        if self.n_clusters < 2:
            return False
        return self.level_relative_to_unit != ClusteringLevel.LOWER

    @property
    def is_recommended(self) -> bool:
        """
        Tell whether clustering on this variable is recommended.

        On top of being valid (see ``is_valid_cluster``), the variable must
        define at least 20 clusters and treatment must not vary within
        clusters.

        Returns
        -------
        bool
            True if recommended for clustering.
        """
        if not self.is_valid_cluster:
            return False
        if self.n_clusters < 20:
            return False
        return not self.treatment_varies_within_cluster

    @property
    def reliability_score(self) -> float:
        """
        Score in [0, 1] for how reliable cluster-robust inference should be.

        The score is a weighted mix of three components: the cluster count
        (weight 0.5, saturating at 50 clusters), the treated/control balance
        (weight 0.3) and the cluster-size dispersion (weight 0.2, a lower CV
        scores higher).

        Returns
        -------
        float
            Reliability score between 0 and 1.
        """
        # Cluster-count component: grows linearly and is capped at 50 clusters.
        count_component = min(self.n_clusters / 50, 1.0)

        # Balance component: the smaller arm relative to a perfect 50/50 split.
        if self.n_clusters > 0:
            smaller_arm = min(self.n_treated_clusters, self.n_control_clusters)
            balance_component = min(smaller_arm / (self.n_clusters / 2), 1.0)
        else:
            balance_component = 0.0

        # Dispersion component: reaches zero once the CV is 2 or more.
        dispersion_component = max(0, 1 - self.cluster_size_cv / 2)

        return (
            0.5 * count_component
            + 0.3 * balance_component
            + 0.2 * dispersion_component
        )
@dataclass
class ClusteringDiagnostics:
    """
    Diagnostic results for clustering structure analysis.

    Holds the per-variable statistics together with the recommendation that
    was derived from them.

    Attributes
    ----------
    cluster_structure : Dict[str, ClusterVarStats]
        Statistics for each potential clustering variable.
    recommended_cluster_var : Optional[str]
        Recommended clustering variable name, or None when no candidate is
        valid.
    recommendation_reason : str
        Explanation for the recommendation.
    treatment_variation_level : str
        Detected level at which treatment varies.
    warnings : List[str]
        Warning messages about clustering choices.
    """

    cluster_structure: Dict[str, ClusterVarStats]
    recommended_cluster_var: Optional[str]
    recommendation_reason: str
    treatment_variation_level: str
    warnings: List[str] = field(default_factory=list)

    def summary(self) -> str:
        """
        Generate a human-readable summary of the diagnostics.

        The summary lists one row per candidate variable, then the
        recommendation with its reason, then any warnings.

        Returns
        -------
        str
            Formatted multi-line summary string.
        """
        lines = [
            "=" * 70,
            "CLUSTERING DIAGNOSTICS",
            "=" * 70,
            "",
            f"Treatment Variation Level: {self.treatment_variation_level}",
            "",
            "Potential Clustering Variables:",
            "-" * 60,
            f"{'Variable':>15} {'N Clusters':>12} {'Treated':>10} {'Control':>10} {'Valid':>8}",
            "-" * 60,
        ]

        for var_name, var_stats in self.cluster_structure.items():
            valid = "✓" if var_stats.is_valid_cluster else "✗"
            rec = " (rec)" if var_name == self.recommended_cluster_var else ""
            lines.append(
                f"{var_name:>15} {var_stats.n_clusters:>12} "
                f"{var_stats.n_treated_clusters:>10} {var_stats.n_control_clusters:>10} "
                f"{valid:>8}{rec}"
            )

        # Without this guard the None would be formatted as cluster_var='None',
        # which reads as if a variable literally named "None" were recommended.
        if self.recommended_cluster_var is None:
            headline = "RECOMMENDATION: no valid clustering variable available"
        else:
            headline = f"RECOMMENDATION: cluster_var='{self.recommended_cluster_var}'"

        lines.extend([
            "",
            "─" * 70,
            headline,
            f"Reason: {self.recommendation_reason}",
            "─" * 70,
        ])

        if self.warnings:
            lines.extend(["", "WARNINGS:"])
            lines.extend(f" ⚠ {w}" for w in self.warnings)

        lines.append("=" * 70)
        return "\n".join(lines)
@dataclass
class ClusteringRecommendation:
    """
    Detailed recommendation of which clustering variable to use.

    Besides the chosen variable, the object carries a confidence score, the
    supporting reasons, runner-up options and wild-bootstrap guidance.

    Attributes
    ----------
    recommended_var : str
        Recommended clustering variable name.
    n_clusters : int
        Number of clusters with the recommended variable.
    n_treated_clusters : int
        Number of treated clusters.
    n_control_clusters : int
        Number of control clusters.
    confidence : float
        Confidence in the recommendation (0-1).
    reasons : List[str]
        Reasons supporting the recommendation.
    alternatives : List[Dict[str, Any]]
        Alternative clustering options with their statistics.
    warnings : List[str]
        Warning messages.
    use_wild_bootstrap : bool
        Whether wild cluster bootstrap is recommended.
    wild_bootstrap_reason : Optional[str]
        Reason for the wild bootstrap recommendation.
    """

    recommended_var: str
    n_clusters: int
    n_treated_clusters: int
    n_control_clusters: int
    confidence: float
    reasons: List[str]
    alternatives: List[Dict[str, Any]] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    use_wild_bootstrap: bool = False
    wild_bootstrap_reason: Optional[str] = None

    def summary(self) -> str:
        """
        Render the recommendation as a human-readable report.

        Returns
        -------
        str
            Formatted multi-line summary string.
        """
        rule = "=" * 70
        report = [rule, "CLUSTERING LEVEL RECOMMENDATION", rule, ""]
        report.append(f"Recommended: cluster_var='{self.recommended_var}'")
        report.append(f" - Total clusters: {self.n_clusters}")
        report.append(f" - Treated clusters: {self.n_treated_clusters}")
        report.append(f" - Control clusters: {self.n_control_clusters}")
        report.append(f" - Confidence: {self.confidence:.1%}")
        report += ["", "Reasons:"]

        # Reasons are numbered starting from 1.
        report += [
            f" {position}. {text}"
            for position, text in enumerate(self.reasons, 1)
        ]

        if self.use_wild_bootstrap:
            report += [
                "",
                "⚠ WILD CLUSTER BOOTSTRAP RECOMMENDED",
                f" Reason: {self.wild_bootstrap_reason}",
            ]

        if self.alternatives:
            report += ["", "Alternatives:"]
            for option in self.alternatives:
                report.append(
                    f" - {option['var']}: {option['n_clusters']} clusters "
                    f"({option['reason']})"
                )

        if self.warnings:
            report += ["", "WARNINGS:"]
            report += [f" ⚠ {message}" for message in self.warnings]

        report.append(rule)
        return "\n".join(report)
@dataclass
class ClusteringConsistencyResult:
    """
    Outcome of checking a clustering choice against treatment variation.

    Attributes
    ----------
    is_consistent : bool
        Whether the clustering level is consistent with treatment variation.
    treatment_variation_level : str
        Detected level at which treatment varies.
    cluster_level : str
        Level of the clustering variable.
    n_clusters : int
        Number of clusters.
    n_treatment_changes_within_cluster : int
        Number of clusters with treatment variation within them.
    pct_clusters_with_variation : float
        Percentage of clusters with within-cluster treatment variation.
    recommendation : str
        Suggested action if inconsistent.
    details : str
        Detailed explanation of the consistency check.
    """

    is_consistent: bool
    treatment_variation_level: str
    cluster_level: str
    n_clusters: int
    n_treatment_changes_within_cluster: int
    pct_clusters_with_variation: float
    recommendation: str
    details: str

    def summary(self) -> str:
        """
        Render the consistency check as a human-readable report.

        Returns
        -------
        str
            Formatted multi-line summary string.
        """
        if self.is_consistent:
            verdict = "✓ Consistent"
        else:
            verdict = "⚠ Inconsistent"

        rule = "=" * 50
        return "\n".join((
            rule,
            "CLUSTERING CONSISTENCY CHECK",
            rule,
            "",
            f"Status: {verdict}",
            "",
            self.details,
            "",
            f"Recommendation: {self.recommendation}",
            rule,
        ))
@dataclass
class WildClusterBootstrapResult:
    """
    Result of wild cluster bootstrap inference.

    Wild cluster bootstrap gives more reliable inference when the number of
    clusters is small.

    Attributes
    ----------
    att : float
        Point estimate of ATT.
    se_bootstrap : float
        Bootstrap standard error.
    ci_lower : float
        Lower bound of bootstrap confidence interval.
    ci_upper : float
        Upper bound of bootstrap confidence interval.
    pvalue : float
        Bootstrap p-value (two-sided).
    n_clusters : int
        Number of clusters.
    n_bootstrap : int
        Number of bootstrap replications.
    weight_type : str
        Type of bootstrap weights used.
    t_stat_original : float
        Original t-statistic.
    t_stats_bootstrap : np.ndarray
        Bootstrap t-statistics.
    rejection_rate : float
        Proportion of bootstrap t-stats exceeding original.
    """

    att: float
    se_bootstrap: float
    ci_lower: float
    ci_upper: float
    pvalue: float
    n_clusters: int
    n_bootstrap: int
    weight_type: str
    t_stat_original: float
    t_stats_bootstrap: Any  # np.ndarray
    rejection_rate: float

    def summary(self) -> str:
        """
        Generate a human-readable summary of the bootstrap results.

        The ATT is followed by significance stars: ``***`` for p < 0.01,
        ``**`` for p < 0.05, ``*`` for p < 0.1, none otherwise.

        Returns
        -------
        str
            Formatted summary string.
        """
        if self.pvalue < 0.01:
            sig = "***"
        elif self.pvalue < 0.05:
            sig = "**"
        elif self.pvalue < 0.1:
            sig = "*"
        else:
            sig = ""

        rule = "=" * 40
        return "\n".join([
            "Wild Cluster Bootstrap Results",
            rule,
            f"ATT: {self.att:.4f} {sig}",
            f"Bootstrap SE: {self.se_bootstrap:.4f}",
            f"95% CI: [{self.ci_lower:.4f}, {self.ci_upper:.4f}]",
            f"P-value: {self.pvalue:.4f}",
            f"N clusters: {self.n_clusters}",
            f"N bootstrap: {self.n_bootstrap}",
            f"Weight type: {self.weight_type}",
            rule,
        ])


# =============================================================================
# Helper Functions
# =============================================================================

def _validate_clustering_inputs(
    data: pd.DataFrame,
    ivar: str,
    potential_cluster_vars: List[str],
    gvar: Optional[str],
    d: Optional[str],
) -> None:
    """
    Validate inputs for clustering diagnostics.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    ivar : str
        Unit identifier column name.
    potential_cluster_vars : List[str]
        List of potential clustering variable names.
    gvar : str, optional
        Cohort variable for staggered designs.
    d : str, optional
        Treatment indicator for common timing.

    Raises
    ------
    ValueError
        If ``data`` is not a DataFrame, a named column is missing, neither
        ``gvar`` nor ``d`` is given, or no candidate variable is supplied.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data must be a pandas DataFrame")

    if ivar not in data.columns:
        raise ValueError(f"Unit variable '{ivar}' not found in data")

    for var in potential_cluster_vars:
        if var not in data.columns:
            raise ValueError(f"Cluster variable '{var}' not found in data")

    if gvar is not None and gvar not in data.columns:
        raise ValueError(f"Cohort variable '{gvar}' not found in data")

    if d is not None and d not in data.columns:
        raise ValueError(f"Treatment variable '{d}' not found in data")

    if gvar is None and d is None:
        raise ValueError("Either gvar or d must be specified")

    if len(potential_cluster_vars) == 0:
        raise ValueError("At least one potential cluster variable must be specified")


def _treated_mask(
    data: pd.DataFrame,
    gvar: Optional[str],
    d: Optional[str],
) -> pd.Series:
    """Return a row-wise boolean Series marking observations of treated units."""
    if gvar is not None:
        # Staggered design: 0, inf and missing cohort values mean never treated.
        cohort = data[gvar]
        return ~cohort.isin([0, np.inf]) & cohort.notna()
    if d is not None:
        return data[d] == 1
    return pd.Series(False, index=data.index)


def _treatment_nunique_by_group(
    data: pd.DataFrame,
    group_var: str,
    gvar: Optional[str],
    d: Optional[str],
) -> pd.Series:
    """Count distinct treatment values inside each group of ``group_var``."""
    if gvar is not None:
        # Collapse every never-treated encoding (0, inf, NaN) to NaN and count
        # NaN as a value of its own. The default nunique() drops NaN, so a
        # never-treated-only group would count 0 and a group mixing one cohort
        # with NaN never-treated units would count 1.
        cohort = data[gvar].where(_treated_mask(data, gvar, d))
        return cohort.groupby(data[group_var]).nunique(dropna=False)
    if d is not None:
        return data.groupby(group_var)[d].nunique()
    return pd.Series(1, index=data[group_var].unique())


def _analyze_cluster_var(
    data: pd.DataFrame,
    ivar: str,
    cluster_var: str,
    gvar: Optional[str],
    d: Optional[str],
) -> ClusterVarStats:
    """
    Analyze a single potential clustering variable.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    ivar : str
        Unit identifier.
    cluster_var : str
        Clustering variable to analyze.
    gvar : str, optional
        Cohort variable (staggered).
    d : str, optional
        Treatment indicator (common timing).

    Returns
    -------
    ClusterVarStats
        Statistics for the clustering variable.
    """
    # Basic cluster statistics.
    cluster_sizes = data.groupby(cluster_var).size()
    n_clusters = len(cluster_sizes)

    # A cluster counts as treated if any of its observations is treated.
    treated = _treated_mask(data, gvar, d)
    cluster_has_treated = treated.groupby(data[cluster_var]).any()
    n_treated_clusters = int(cluster_has_treated.sum())
    n_control_clusters = n_clusters - n_treated_clusters

    # Nested in unit: every cluster holds a single unit and there are more
    # clusters than units, so the variable splits units (invalid).
    units_per_cluster = data.groupby(cluster_var)[ivar].nunique()
    n_unique_units = data[ivar].nunique()
    is_nested_in_unit = bool(
        (units_per_cluster == 1).all() and n_clusters > n_unique_units
    )

    # Determine level relative to unit.
    if is_nested_in_unit:
        level = ClusteringLevel.LOWER
    elif n_clusters == n_unique_units:
        level = ClusteringLevel.SAME
    else:
        level = ClusteringLevel.HIGHER

    # Check if treatment varies within clusters.
    treatment_per_cluster = _treatment_nunique_by_group(data, cluster_var, gvar, d)
    has_variation = treatment_per_cluster > 1
    treatment_varies = bool(has_variation.any())
    n_with_variation = int(has_variation.sum())

    # Coefficient of variation of cluster sizes (0.0 when the mean is not positive).
    mean_size = cluster_sizes.mean()
    if mean_size > 0:
        cluster_size_cv = float(cluster_sizes.std() / mean_size)
    else:
        cluster_size_cv = 0.0

    return ClusterVarStats(
        var_name=cluster_var,
        n_clusters=n_clusters,
        n_treated_clusters=n_treated_clusters,
        n_control_clusters=n_control_clusters,
        min_cluster_size=int(cluster_sizes.min()),
        max_cluster_size=int(cluster_sizes.max()),
        mean_cluster_size=float(mean_size),
        median_cluster_size=float(cluster_sizes.median()),
        cluster_size_cv=cluster_size_cv,
        level_relative_to_unit=level,
        units_per_cluster=float(units_per_cluster.mean()),
        is_nested_in_unit=is_nested_in_unit,
        treatment_varies_within_cluster=treatment_varies,
        n_clusters_with_treatment_variation=n_with_variation,
    )


def _detect_treatment_variation_level(
    data: pd.DataFrame,
    ivar: str,
    potential_cluster_vars: List[str],
    gvar: Optional[str],
    d: Optional[str],
) -> str:
    """
    Detect the level at which treatment varies.

    Candidates are scanned from the fewest unique values to the most; the
    first one within whose groups treatment is constant is returned.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    ivar : str
        Unit identifier.
    potential_cluster_vars : List[str]
        List of potential clustering variables.
    gvar : str, optional
        Cohort variable. Values 0, inf and NaN are all treated as the single
        "never treated" status.
    d : str, optional
        Treatment indicator.

    Returns
    -------
    str
        Name of the variable at which treatment varies, ``ivar`` when
        treatment varies within every candidate, or ``"unknown"`` when
        neither ``gvar`` nor ``d`` is given.
    """
    if gvar is None and d is None:
        return "unknown"

    # Check from highest level (fewest unique values) to lowest level.
    sorted_vars = sorted(potential_cluster_vars, key=lambda v: data[v].nunique())

    for var in sorted_vars:
        treatment_per_group = _treatment_nunique_by_group(data, var, gvar, d)
        if (treatment_per_group == 1).all():
            return var

    # If treatment varies at all levels, return the unit level.
    return ivar


def _generate_clustering_recommendation(
    cluster_structure: Dict[str, ClusterVarStats],
    treatment_level: str,
) -> Tuple[Optional[str], str, List[str]]:
    """
    Generate clustering recommendation based on diagnostics.

    The treatment variation level is preferred when it is valid and has at
    least 20 clusters. Otherwise the valid option that first satisfies the
    20-cluster threshold and then has the highest reliability score is chosen.

    Parameters
    ----------
    cluster_structure : Dict[str, ClusterVarStats]
        Statistics for each potential clustering variable.
    treatment_level : str
        Detected treatment variation level.

    Returns
    -------
    Tuple[Optional[str], str, List[str]]
        Tuple of (recommended_var, reason, warnings).
    """
    messages: List[str] = []

    # Filter to valid options.
    valid_options = {
        var: stats
        for var, stats in cluster_structure.items()
        if stats.is_valid_cluster
    }

    if not valid_options:
        return None, "No valid clustering options available.", [
            "All potential cluster variables are invalid (nested within units or < 2 clusters)."
        ]

    # Prefer clustering at the treatment variation level.
    if treatment_level in valid_options:
        stats = valid_options[treatment_level]
        if stats.n_clusters >= 20:
            return treatment_level, (
                f"Clustering at treatment variation level ({treatment_level}) "
                f"with {stats.n_clusters} clusters."
            ), messages
        messages.append(
            f"Treatment varies at {treatment_level} level but only "
            f"{stats.n_clusters} clusters available."
        )

    # Otherwise pick the best-ranked option; max() keeps the first of equally
    # ranked candidates, like a stable descending sort would.
    best_var, best_stats = max(
        valid_options.items(),
        key=lambda item: (item[1].n_clusters >= 20, item[1].reliability_score),
    )

    if best_stats.n_clusters < 20:
        messages.append(
            f"Recommended clustering has only {best_stats.n_clusters} clusters. "
            f"Consider wild cluster bootstrap for reliable inference."
        )

    if best_stats.treatment_varies_within_cluster:
        messages.append(
            f"Treatment varies within {best_stats.n_clusters_with_treatment_variation} "
            f"clusters. Standard errors may be conservative."
        )

    reason = (
        f"Best available option with {best_stats.n_clusters} clusters "
        f"(reliability score: {best_stats.reliability_score:.2f})."
    )

    return best_var, reason, messages


def _generate_recommendation_reasons(
    var: str,
    stats: ClusterVarStats,
    diag: ClusteringDiagnostics,
) -> List[str]:
    """
    Generate list of reasons for clustering recommendation.

    Parameters
    ----------
    var : str
        Recommended variable name.
    stats : ClusterVarStats
        Statistics for the recommended variable.
    diag : ClusteringDiagnostics
        Full diagnostics results.

    Returns
    -------
    List[str]
        List of reasons.
    """
    reasons = []

    # Treatment variation level.
    if var == diag.treatment_variation_level:
        reasons.append(f"Treatment varies at {var} level - clustering at this level is appropriate")

    # Cluster count.
    if stats.n_clusters >= 30:
        reasons.append(f"Sufficient clusters ({stats.n_clusters}) for reliable inference")
    elif stats.n_clusters >= 20:
        reasons.append(f"Adequate clusters ({stats.n_clusters}) for inference")
    else:
        reasons.append(f"Limited clusters ({stats.n_clusters}) - consider wild bootstrap")

    # Balance: only reported when the smaller arm exceeds half the larger one.
    if stats.n_treated_clusters > 0 and stats.n_control_clusters > 0:
        smaller = min(stats.n_treated_clusters, stats.n_control_clusters)
        larger = max(stats.n_treated_clusters, stats.n_control_clusters)
        if smaller / larger > 0.5:
            reasons.append(f"Good balance between treated ({stats.n_treated_clusters}) and control ({stats.n_control_clusters}) clusters")

    # Hierarchy level.
    if stats.level_relative_to_unit == ClusteringLevel.HIGHER:
        reasons.append("Clustering at higher level than unit of observation")

    return reasons


def _get_alternative_reason(stats: ClusterVarStats) -> str:
    """
    Get reason string for an alternative clustering option.

    Parameters
    ----------
    stats : ClusterVarStats
        Statistics for the alternative.

    Returns
    -------
    str
        Reason string.
    """
    if stats.n_clusters < 10:
        return "too few clusters"
    if stats.n_clusters < 20:
        return "marginal cluster count"
    if stats.treatment_varies_within_cluster:
        return "treatment varies within clusters"
    return f"reliability score: {stats.reliability_score:.2f}"


def _generate_clustering_warnings(
    stats: ClusterVarStats,
    diag: ClusteringDiagnostics,
) -> List[str]:
    """
    Generate warning messages for clustering choice.

    Parameters
    ----------
    stats : ClusterVarStats
        Statistics for the recommended variable.
    diag : ClusteringDiagnostics
        Full diagnostics results (not read by the current implementation).

    Returns
    -------
    List[str]
        List of warning messages.
    """
    messages = []

    if stats.n_clusters < 10:
        messages.append(
            f"Only {stats.n_clusters} clusters - cluster-robust inference may be unreliable"
        )
    elif stats.n_clusters < 20:
        messages.append(
            f"Only {stats.n_clusters} clusters - consider wild cluster bootstrap"
        )

    if stats.cluster_size_cv > 1.0:
        messages.append(
            f"Highly variable cluster sizes (CV={stats.cluster_size_cv:.2f})"
        )

    if stats.treatment_varies_within_cluster:
        messages.append(
            f"Treatment varies within {stats.n_clusters_with_treatment_variation} clusters"
        )

    return messages


def _determine_cluster_level(
    data: pd.DataFrame,
    ivar: str,
    cluster_var: str,
) -> str:
    """
    Determine the level of clustering variable relative to unit.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    ivar : str
        Unit identifier.
    cluster_var : str
        Clustering variable.

    Returns
    -------
    str
        Level description: 'higher', 'same', or 'lower'.
    """
    n_units = data[ivar].nunique()
    n_clusters = data[cluster_var].nunique()

    # Check how many units belong to each cluster.
    units_per_cluster = data.groupby(cluster_var)[ivar].nunique()

    if n_clusters < n_units and (units_per_cluster > 1).any():
        return "higher"
    if n_clusters == n_units:
        return "same"
    return "lower"


# =============================================================================
# Main Public Functions
# =============================================================================
def diagnose_clustering(
    data: pd.DataFrame,
    ivar: str,
    potential_cluster_vars: List[str],
    gvar: Optional[str] = None,
    d: Optional[str] = None,
    verbose: bool = True,
) -> ClusteringDiagnostics:
    """
    Diagnose clustering structure and recommend a clustering level.

    Each candidate clustering variable is analyzed relative to the unit of
    observation and to treatment assignment. The level at which treatment
    varies is then detected and a recommendation is derived.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    ivar : str
        Unit identifier column name.
    potential_cluster_vars : List[str]
        Candidate clustering variable column names to evaluate.
    gvar : str, optional
        Cohort variable for staggered designs, where treatment timing varies
        across units.
    d : str, optional
        Treatment indicator variable for common-timing designs, where all
        treated units receive treatment at the same time.
    verbose : bool, default True
        Whether to print the diagnostic summary.

    Returns
    -------
    ClusteringDiagnostics
        Diagnostic results containing:

        - cluster_structure: statistics for each candidate variable
        - recommended_cluster_var: recommended clustering variable name
        - recommendation_reason: explanation for the recommendation
        - treatment_variation_level: detected level at which treatment varies
        - warnings: warning messages about clustering choices

    Raises
    ------
    ValueError
        If inputs are invalid (missing columns, no treatment variable, etc.)

    Notes
    -----
    When the policy or treatment varies at a level higher than the unit of
    observation, standard errors should be clustered at the policy variation
    level to account for within-cluster correlation.

    Candidates are evaluated on:

    - Number of clusters (more is better, 20-30 minimum recommended)
    - Balance between treated and control clusters
    - Whether treatment varies within clusters
    - Cluster size variation (coefficient of variation)

    See Also
    --------
    recommend_clustering_level : Get detailed recommendation with alternatives.
    check_clustering_consistency : Validate clustering choice against treatment.
    """
    _validate_clustering_inputs(data, ivar, potential_cluster_vars, gvar, d)

    # One statistics record per candidate, in the order the caller listed them.
    per_variable = {
        candidate: _analyze_cluster_var(data, ivar, candidate, gvar, d)
        for candidate in potential_cluster_vars
    }

    detected_level = _detect_treatment_variation_level(
        data, ivar, potential_cluster_vars, gvar, d
    )

    chosen_var, explanation, notes = _generate_clustering_recommendation(
        per_variable, detected_level
    )

    diagnostics = ClusteringDiagnostics(
        cluster_structure=per_variable,
        recommended_cluster_var=chosen_var,
        recommendation_reason=explanation,
        treatment_variation_level=detected_level,
        warnings=notes,
    )

    if verbose:
        print(diagnostics.summary())

    return diagnostics
def recommend_clustering_level(
    data: pd.DataFrame,
    ivar: str,
    tvar: str,
    potential_cluster_vars: List[str],
    gvar: Optional[str] = None,
    d: Optional[str] = None,
    min_clusters: int = 20,
    verbose: bool = True,
) -> ClusteringRecommendation:
    """
    Recommend the clustering level best supported by the data.

    The result names the chosen variable and carries a confidence score, up
    to two alternatives, warnings, and guidance on whether wild cluster
    bootstrap is needed.

    Algorithm:

    1. Run ``diagnose_clustering`` on the candidates
    2. Keep only valid clustering options
    3. Rank them by reliability score and take the best
    4. Flag wild bootstrap when the best option has < ``min_clusters`` clusters

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    ivar : str
        Unit identifier column name.
    tvar : str
        Time variable column name. The current implementation accepts it but
        does not read it.
    potential_cluster_vars : List[str]
        Candidate clustering variable column names.
    gvar : str, optional
        Cohort/treatment timing variable column name (staggered designs).
    d : str, optional
        Treatment indicator variable (common timing designs).
    min_clusters : int, default 20
        Minimum recommended number of clusters. If the recommended variable
        has fewer clusters, wild cluster bootstrap is recommended.
    verbose : bool, default True
        Whether to print the recommendation summary.

    Returns
    -------
    ClusteringRecommendation
        Recommendation containing:

        - recommended_var: recommended clustering variable name
        - n_clusters: number of clusters with the recommended variable
        - n_treated_clusters: number of treated clusters
        - n_control_clusters: number of control clusters
        - confidence: confidence in the recommendation (0-1)
        - reasons: reasons supporting the recommendation
        - alternatives: alternative clustering options
        - warnings: warning messages
        - use_wild_bootstrap: whether wild cluster bootstrap is recommended
        - wild_bootstrap_reason: reason for that recommendation

    Raises
    ------
    ValueError
        If no valid clustering options are found.

    Notes
    -----
    The reliability score is a weighted combination of:

    - Number of clusters (50% weight, saturates at 50 clusters)
    - Balance of treated/control clusters (30% weight)
    - Cluster size variation (20% weight, lower CV is better)

    See Also
    --------
    diagnose_clustering : Get detailed diagnostics for clustering structure.
    check_clustering_consistency : Check if clustering is consistent with treatment.
    wild_cluster_bootstrap : Bootstrap inference for small cluster counts.
    """
    diag = diagnose_clustering(
        data, ivar, potential_cluster_vars,
        gvar=gvar, d=d, verbose=False
    )

    # Keep only candidates that can be used for clustering at all.
    candidates = []
    for name, var_stats in diag.cluster_structure.items():
        if var_stats.is_valid_cluster:
            candidates.append((name, var_stats))

    if not candidates:
        raise ValueError(
            "No valid clustering options found. "
            "All potential cluster variables are either nested within units "
            "or have fewer than 2 clusters."
        )

    # Highest reliability first; the in-place sort is stable, so ties keep
    # their original order.
    candidates.sort(key=lambda pair: pair[1].reliability_score, reverse=True)
    top_var, top_stats = candidates[0]

    supporting_reasons = _generate_recommendation_reasons(top_var, top_stats, diag)

    needs_wild_bootstrap = top_stats.n_clusters < min_clusters
    bootstrap_note = (
        (
            f"Only {top_stats.n_clusters} clusters available (< {min_clusters}). "
            f"Wild cluster bootstrap recommended for reliable inference."
        )
        if needs_wild_bootstrap
        else None
    )

    # The two runners-up are reported as alternatives.
    runner_ups = [
        {
            'var': name,
            'n_clusters': var_stats.n_clusters,
            'reliability_score': var_stats.reliability_score,
            'reason': _get_alternative_reason(var_stats),
        }
        for name, var_stats in candidates[1:3]
    ]

    caveats = _generate_clustering_warnings(top_stats, diag)

    recommendation = ClusteringRecommendation(
        recommended_var=top_var,
        n_clusters=top_stats.n_clusters,
        n_treated_clusters=top_stats.n_treated_clusters,
        n_control_clusters=top_stats.n_control_clusters,
        confidence=top_stats.reliability_score,
        reasons=supporting_reasons,
        alternatives=runner_ups,
        warnings=caveats,
        use_wild_bootstrap=needs_wild_bootstrap,
        wild_bootstrap_reason=bootstrap_note,
    )

    if verbose:
        print(recommendation.summary())

    return recommendation
def check_clustering_consistency(
    data: pd.DataFrame,
    ivar: str,
    cluster_var: str,
    gvar: Optional[str] = None,
    d: Optional[str] = None,
    verbose: bool = True,
) -> ClusteringConsistencyResult:
    """
    Check if clustering level is consistent with treatment variation level.

    A consistent clustering choice means:

    1. Treatment does not vary within clusters (or varies minimally)
    2. Cluster level is at or above the unit level

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    ivar : str
        Unit identifier.
    cluster_var : str
        Clustering variable to check.
    gvar : str, optional
        Treatment timing variable (for staggered designs). Values 0, inf and
        NaN are all treated as the single "never treated" status.
    d : str, optional
        Treatment indicator (for common timing designs).
    verbose : bool, default True
        Whether to print results.

    Returns
    -------
    ClusteringConsistencyResult
        Consistency check results containing:

        - is_consistent: whether clustering level is consistent
        - treatment_variation_level: detected treatment variation level
        - cluster_level: level of the clustering variable
        - n_clusters: number of clusters
        - n_treatment_changes_within_cluster: clusters with treatment variation
        - pct_clusters_with_variation: percentage with variation
        - recommendation: suggested action if inconsistent
        - details: detailed explanation

    Raises
    ------
    ValueError
        If a named column is missing or neither ``gvar`` nor ``d`` is given.

    Notes
    -----
    A clustering choice is considered consistent if:

    - Less than 5% of clusters have within-cluster treatment variation
    - The cluster level is at the same level or higher than the unit

    See Also
    --------
    diagnose_clustering : Get detailed diagnostics for clustering structure.
    recommend_clustering_level : Get recommendation for clustering level.
    """
    # Validate inputs.
    if cluster_var not in data.columns:
        raise ValueError(f"Cluster variable '{cluster_var}' not found in data")

    if ivar not in data.columns:
        raise ValueError(f"Unit variable '{ivar}' not found in data")

    if gvar is None and d is None:
        raise ValueError("Either gvar or d must be specified")

    treatment_var = gvar if gvar is not None else d
    if treatment_var not in data.columns:
        raise ValueError(f"Treatment variable '{treatment_var}' not found in data")

    # Count distinct treatment values inside each cluster.
    if gvar is not None:
        # Collapse the never-treated encodings (0, inf, NaN) to NaN and count
        # NaN as a value. The default nunique() drops NaN, so a cluster mixing
        # a treated cohort with NaN-coded never-treated units would otherwise
        # look as if treatment were constant within it.
        cohort = data[gvar]
        ever_treated = ~cohort.isin([0, np.inf]) & cohort.notna()
        cluster_treatment = (
            cohort.where(ever_treated)
            .groupby(data[cluster_var])
            .nunique(dropna=False)
        )
    else:
        cluster_treatment = data.groupby(cluster_var)[treatment_var].nunique()

    n_clusters_with_variation = int((cluster_treatment > 1).sum())
    n_clusters = len(cluster_treatment)
    pct_with_variation = (
        n_clusters_with_variation / n_clusters * 100 if n_clusters > 0 else 0
    )

    # Determine treatment variation level.
    treatment_level = _detect_treatment_variation_level(
        data, ivar, [cluster_var], gvar, d
    )

    # Determine cluster level.
    cluster_level = _determine_cluster_level(data, ivar, cluster_var)

    # Consistent when fewer than 5% of clusters show variation and the
    # cluster level is not below the unit level.
    is_consistent = (
        pct_with_variation < 5
        and cluster_level in ['same', 'higher']
    )

    # Generate recommendation.
    if is_consistent:
        recommendation = "Clustering choice is appropriate."
    elif pct_with_variation >= 5:
        recommendation = (
            f"Treatment varies within {pct_with_variation:.1f}% of clusters. "
            f"Consider clustering at a higher level where treatment is constant."
        )
    else:
        recommendation = (
            f"Cluster level ({cluster_level}) may be inappropriate. "
            f"Consider clustering at the treatment variation level ({treatment_level})."
        )

    # Generate details.
    details = (
        f"Analyzed {n_clusters} clusters.\n"
        f"Treatment varies within {n_clusters_with_variation} clusters "
        f"({pct_with_variation:.1f}%).\n"
        f"Treatment variation level: {treatment_level}\n"
        f"Cluster level: {cluster_level}"
    )

    result = ClusteringConsistencyResult(
        is_consistent=is_consistent,
        treatment_variation_level=treatment_level,
        cluster_level=cluster_level,
        n_clusters=n_clusters,
        n_treatment_changes_within_cluster=n_clusters_with_variation,
        pct_clusters_with_variation=pct_with_variation,
        recommendation=recommendation,
        details=details,
    )

    if verbose:
        print(result.summary())

    return result