Source code for lwdid.visualization

"""
Visualization utilities for difference-in-differences analysis.

This module provides plotting functions for visualizing transformed outcomes
in panel data difference-in-differences settings. The primary use case is
comparing the trajectory of residualized outcomes between treated units
(or their group average) and the control group mean across time periods.

The visualization functions support both single treated unit analysis and
aggregated treatment group comparisons. Plots display pre-treatment fit
quality and post-intervention treatment effect gaps.

Notes
-----
Requires matplotlib >= 3.3. Raises VisualizationError if matplotlib is
not available when plot generation is requested.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from .exceptions import InvalidParameterError, VisualizationError

if TYPE_CHECKING:
    from typing import Any


def _resolve_gid(
    data: pd.DataFrame,
    ivar_var: str,
    d_var: str,
    gid: str | int
) -> int:
    """
    Resolve a user-provided unit identifier to the internal representation.

    Maps user-specified unit identifiers to the internal numeric identifiers
    used in the transformed data. Handles string-to-numeric conversions and
    validates that the resolved identifier corresponds to a treated unit.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data containing the unit identifier column.
    ivar_var : str
        Name of the unit identifier column in the DataFrame.
    d_var : str
        Name of the binary treatment indicator column (1 = treated).
    gid : str or int
        User-specified unit identifier to resolve. Accepts both string
        and numeric formats.

    Returns
    -------
    int or str
        Resolved unit identifier matching the dtype of the ivar column.

    Raises
    ------
    InvalidParameterError
        If the unit identifier is not found in the data or does not
        correspond to a treated unit.

    Notes
    -----
    The resolution process handles two scenarios: (1) when an id_mapping
    dictionary exists in data.attrs (from string-to-numeric conversion
    during preprocessing), and (2) direct matching against the ivar column.
    Type coercion is applied when the input type differs from the column dtype.
    """
    original_gid = gid
    
    # Check for id_mapping from preprocessing (string IDs converted to numeric)
    mapping = data.attrs.get('id_mapping', None)
    if mapping and 'original_to_numeric' in mapping:
        # Attempt lookup in the original-to-numeric mapping first
        gid_str = str(gid) if not isinstance(gid, str) else gid
        gid_num = mapping['original_to_numeric'].get(gid_str)
        if gid_num is not None:
            gid_resolved = gid_num
        else:
            # Mapping lookup failed; fall back to direct column matching
            gid_to_match = gid
            # Coerce type to match column dtype for comparison
            if isinstance(gid, str) and pd.api.types.is_numeric_dtype(data[ivar_var]):
                try:
                    gid_to_match = pd.to_numeric(gid)
                except (ValueError, TypeError):
                    pass
            elif not isinstance(gid, str) and pd.api.types.is_string_dtype(data[ivar_var]):
                gid_to_match = str(gid)

            mask = (data[ivar_var] == gid_to_match)
            if not mask.any():
                raise InvalidParameterError(f"gid '{original_gid}' not found")
            gid_resolved = data.loc[mask, ivar_var].iloc[0]
    else:
        # No id_mapping; match directly against the ivar column
        gid_to_match = gid
        # Coerce type to match column dtype for comparison
        if isinstance(gid, str) and pd.api.types.is_numeric_dtype(data[ivar_var]):
            try:
                gid_to_match = pd.to_numeric(gid)
            except (ValueError, TypeError):
                pass
        elif not isinstance(gid, str) and pd.api.types.is_string_dtype(data[ivar_var]):
            gid_to_match = str(gid)

        mask = (data[ivar_var] == gid_to_match)
        if not mask.any():
            raise InvalidParameterError(f"gid '{original_gid}' not found")
        gid_resolved = data.loc[mask, ivar_var].iloc[0]

    # Validate that the resolved unit is a treated unit (d=1)
    unit_rows = data[data[ivar_var] == gid_resolved]
    if len(unit_rows) == 0:
        raise InvalidParameterError(f"gid '{original_gid}' not found")
    d_max = int(unit_rows[d_var].max())
    if d_max != 1:
        raise InvalidParameterError(f"'{original_gid}' is not a treated unit")

    return gid_resolved


[docs] def prepare_plot_data( data: pd.DataFrame, ydot_var: str, d_var: str, tindex_var: str, ivar_var: str, gid: str | int | None, tpost1: int, Tmax: int, period_labels: dict[int, str], ) -> dict[str, Any]: """ Prepare data structures for plotting transformed outcomes. Computes control group means and treated unit (or group average) series across all time periods. The output dictionary contains all necessary data for generating comparative time series plots. Parameters ---------- data : pd.DataFrame Transformed panel data containing residualized outcomes. ydot_var : str Name of the column containing the residualized outcome variable (unit-specific mean or trend removed). d_var : str Name of the binary treatment indicator column (1 = treated). tindex_var : str Name of the time period index column. ivar_var : str Name of the unit identifier column. gid : str, int, or None Unit identifier for a specific treated unit to plot. If None, computes the average across all treated units. tpost1 : int First post-treatment time period (intervention point). Tmax : int Final time period in the panel. period_labels : dict of {int: str} Mapping from time index values to display labels for the x-axis. Returns ------- dict Dictionary containing: - ``time`` : list of int Time period indices from 1 to Tmax. - ``control_mean`` : list of float Control group mean of the residualized outcome for each period. - ``treated_series`` : list of float Treated unit or group average of the residualized outcome. - ``intervention_point`` : int First post-treatment period for the vertical intervention line. - ``treated_label`` : str Label for the treated series in the plot legend. - ``period_labels`` : dict Time index to label mapping for x-axis tick labels. Raises ------ VisualizationError If required columns are missing from the data. InvalidParameterError If gid is specified but not found or not a treated unit. Notes ----- The returned dictionary provides all data needed by :func:`plot_results`. Control group means are computed by averaging the residualized outcome across all units with d=0 in each period. Treated series is either a single unit trajectory or the group average across all units with d=1. See Also -------- plot_results : Generate the visualization from prepared data. """ # Validate that all required columns exist required = {ydot_var, d_var, tindex_var, ivar_var} missing = required - set(data.columns) if missing: raise VisualizationError(f"Missing required columns: {sorted(missing)}") # Create time index sequence (1-indexed to match tindex convention) time = list(range(1, int(Tmax) + 1)) # Compute period-wise mean of residualized outcome for control units control_mean_series = ( data[data[d_var] == 0] .groupby(tindex_var)[ydot_var] .mean() .reindex(time) ) if gid is not None: # Plot single treated unit trajectory gid_resolved = _resolve_gid(data, ivar_var, d_var, gid) unit = data[data[ivar_var] == gid_resolved] treated_series = ( unit.set_index(tindex_var)[ydot_var] .reindex(time) ) treated_label = f"Unit {gid}" else: # Plot average across all treated units treated_series = ( data[data[d_var] == 1] .groupby(tindex_var)[ydot_var] .mean() .reindex(time) ) treated_label = "Treated (Average)" return { 'time': time, 'control_mean': control_mean_series.tolist(), 'treated_series': treated_series.tolist(), 'intervention_point': int(tpost1), 'treated_label': treated_label, 'period_labels': period_labels, }
[docs] def plot_results( plot_data: dict[str, Any], graph_options: dict[str, Any] | None = None, ): """ Generate a time series plot comparing treated and control outcomes. Creates a matplotlib figure displaying the residualized outcome trajectories for treated units (or their average) and control group mean across all time periods. A vertical line marks the intervention point. Parameters ---------- plot_data : dict Data dictionary from :func:`prepare_plot_data` containing time indices, outcome series, and labeling information. graph_options : dict, optional Customization options for the plot appearance: - ``figsize`` : tuple of (width, height), default (10, 6) - ``title`` : str or None, plot title - ``xlabel`` : str or None, x-axis label - ``ylabel`` : str, y-axis label, default 'Residualized Outcome' - ``legend_loc`` : str, legend position, default 'best' - ``dpi`` : int, figure resolution, default 100 - ``savefig`` : str or None, file path to save the figure Returns ------- matplotlib.figure.Figure Generated matplotlib figure object. Raises ------ VisualizationError If matplotlib is not installed. See Also -------- prepare_plot_data : Prepare the data dictionary for plotting. Notes ----- The plot displays three visual elements: (1) a dashed blue line for control group mean trajectory, (2) a solid red line for treated unit or group average trajectory, and (3) a vertical dashed black line marking the intervention point. Pre-treatment periods appear to the left of the intervention line, enabling visual assessment of parallel trends and pre-treatment fit quality. """ # Lazy import to allow package usage without matplotlib try: import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator except ImportError as exc: raise VisualizationError( 'Install required dependencies: matplotlib>=3.3.' ) from exc # Default plot configuration opts = { 'figsize': (10, 6), 'title': None, 'xlabel': None, 'ylabel': 'Residualized Outcome', 'legend_loc': 'best', 'dpi': 100, 'savefig': None, } if graph_options: opts.update(graph_options) # Extract plot data components time = plot_data['time'] ctrl = plot_data['control_mean'] trt = plot_data['treated_series'] tpost1 = plot_data['intervention_point'] tlabel = plot_data['treated_label'] period_labels = plot_data.get('period_labels', {}) fig, ax = plt.subplots(figsize=opts['figsize'], dpi=opts['dpi']) # Plot outcome trajectories ax.plot(time, ctrl, linestyle='--', color='blue', linewidth=1.5, label='Control') ax.plot(time, trt, linestyle='-', color='red', linewidth=2.0, label=tlabel) ax.axvline(x=tpost1, linestyle='--', color='black', linewidth=1.0, alpha=0.7, label='Intervention') # Configure x-axis with period labels (rotated for readability) ax.set_xticks(time) ax.set_xticklabels([period_labels.get(t, str(t)) for t in time], rotation=45, ha='right') if opts['xlabel'] is not None: ax.set_xlabel(opts['xlabel']) ax.set_ylabel(opts['ylabel']) ax.yaxis.set_major_locator(MaxNLocator(nbins='auto')) if opts['title']: ax.set_title(opts['title']) ax.legend(loc=opts['legend_loc'], frameon=True, shadow=True) fig.tight_layout() if opts['savefig']: fig.savefig(opts['savefig'], dpi=opts['dpi']) return fig