# ==============================================================================
# Author(s): Sharim Khan, Gabriel Lee
# NetID(s): sharimk2, gjlee4
# Paper title:
#           Explaining A Machine Learning Decision to Physicians via Counterfactuals
# Paper link: https://arxiv.org/abs/2306.06325
# Description: This file defines the Counterfactual Variational Autoencoder (CFVAE)
#              model, which reconstructs input data while generating counterfactual
#              examples that flip the prediction of a frozen classifier.
# ==============================================================================

from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.models import BaseModel


class CFVAE(BaseModel):
    """Counterfactual Variational Autoencoder (CFVAE) for binary prediction tasks.

    This is a parametrized version of the CFVAE model described by Nagesh et al.

    The CFVAE learns to reconstruct inputs while generating counterfactual samples
    that flip the output of a fixed, externally trained binary classifier. It
    combines VAE reconstruction and KL-divergence losses with a classifier-based
    counterfactual loss.

    NOTE: A binary classifier MUST be passed as an argument.
    NOTE: The sparsity constraint should be implemented in the training loop;
          a sketch is given at the bottom of this file.

    Attributes:
        feature_keys: Feature keys used as inputs.
        label_keys: A list containing the label key.
        mode: Task mode (must be 'binary').
        latent_dim: Latent dimensionality of the VAE.
        external_classifier: Frozen external classifier for guiding counterfactuals.
        enc1: First encoder layer.
        enc2: Layer projecting to latent mean and log-variance.
        dec1: First decoder layer.
        dec2: Layer projecting to reconstructed input space.

    Example:
        cfvae = CFVAE(
            dataset=samples,
            feature_keys=["labs"],
            label_key="mortality",
            mode="binary",
            feat_dim=27,
            latent_dim=32,
            hidden_dim=64,
            external_classifier=frozen_classifier,
        )
    """

    def __init__(
        self,
        dataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        feat_dim: int,
        latent_dim: int = 32,
        hidden_dim: int = 64,
        external_classifier: Optional[nn.Module] = None,
    ):
        """
        Initializes the CFVAE model and freezes the external classifier.

        Args:
            dataset: PyHealth-compatible dataset object.
            feature_keys: List of input feature keys.
            label_key: Output label key (must be binary).
            mode: Task mode (only 'binary' is supported).
            feat_dim: Input feature dimensionality.
            latent_dim: Latent space dimensionality.
            hidden_dim: Hidden layer size in encoder/decoder.
            external_classifier: Frozen binary classifier to guide counterfactuals.
        """
        super().__init__(dataset)
        self.feature_keys = feature_keys
        self.label_keys = [label_key]
        self.mode = mode

        assert mode == "binary", "Only binary classification is supported."
        assert external_classifier is not None, "external_classifier must be provided."

        self.latent_dim = latent_dim
        # Freeze the external classifier: it guides the counterfactuals but is
        # never updated.
        self.external_classifier = external_classifier.eval()
        for param in self.external_classifier.parameters():
            param.requires_grad = False

        # Encoder: input -> hidden -> (mu, log_var), packed as 2 * latent_dim.
        self.enc1 = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )
        self.enc2 = nn.Linear(hidden_dim, 2 * latent_dim)

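        # Decoder: the latent code is concatenated with a 2-dim one-hot of the
        # flipped label, hence the latent_dim + 2 input width.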
        self.dec1 = nn.Sequential(
            nn.Linear(latent_dim + 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )
        self.dec2 = nn.Linear(hidden_dim, feat_dim)

    def reparameterize(
        self, mu: torch.Tensor, log_var: torch.Tensor
    ) -> torch.Tensor:
        """
        Applies the reparameterization trick to sample z from N(mu, sigma^2)
        while keeping the sample differentiable with respect to mu and log_var.

        Args:
            mu: Mean of the latent distribution, shape (B, latent_dim).
            log_var: Log variance of the latent distribution, shape (B, latent_dim).

        Returns:
            z: Sampled latent variable, shape (B, latent_dim).
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Forward pass for CFVAE: encodes the input, reparameterizes, decodes
        conditioned on the flipped labels, and computes the reconstruction, KL,
        and counterfactual classifier losses.

        Args:
            kwargs: Dict of inputs including:
                - feature_keys[0]: Input tensor (B, feat_dim)
                - label_keys[0]: Ground truth label tensor (B,)

        Returns:
            Dictionary containing:
                - loss: Total training loss (reconstruction + KL + counterfactual
                  classifier loss).
                - y_prob: Classifier output probabilities for the reconstructed
                  (counterfactual) inputs.
                - y_true: Ground truth labels.
                - x_recon: Reconstructed counterfactual inputs, returned so the
                  training loop can apply the sparsity constraint.
        """
        x = kwargs[self.feature_keys[0]].to(self.device)
        y = kwargs[self.label_keys[0]].to(self.device)

        # Encode inputs
        h = self.enc1(x)
        h = self.enc2(h).view(-1, 2, self.latent_dim)
        mu, log_var = h[:, 0, :], h[:, 1, :]
        z = self.reparameterize(mu, log_var)

        # Flip labels to condition decoder on opposite class (counterfactual)
        y_cf = 1 - y
        y_cf_onehot = F.one_hot(y_cf.view(-1).long(), num_classes=2).float()
        z_cond = torch.cat([z, y_cf_onehot], dim=1)

        h_dec = self.dec1(z_cond)
        x_recon = torch.sigmoid(self.dec2(h_dec))

        # Evaluate the frozen external classifier on the counterfactual. Its
        # parameters are frozen, but the call must stay inside the autograd
        # graph (no torch.no_grad()) so that the counterfactual loss can push
        # gradients back into the decoder.
        logits = self.external_classifier(x_recon)

        # Compute losses. The classifier loss targets the flipped label, so
        # minimizing it trains the decoder to flip the classifier's prediction.
        clf_loss = self.get_loss_function()(logits, y_cf)
        recon_loss = F.mse_loss(x_recon, x, reduction="mean")
        # KL divergence to the standard normal prior, averaged over batch and
        # latent dimensions.
        kld_loss = -0.5 * torch.mean(
            1 + log_var - mu.pow(2) - log_var.exp()
        )
        total_loss = recon_loss + kld_loss + clf_loss

        return {
            "loss": total_loss,
            "y_prob": self.prepare_y_prob(logits),
            "y_true": y,
            # Exposed so the training loop can apply the sparsity constraint.
            "x_recon": x_recon,
        }
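

# ------------------------------------------------------------------------------
# Training-loop sketch for the sparsity constraint.
#
# The class docstring notes that the sparsity constraint belongs in the training
# loop, so a minimal sketch is given below. It is illustrative only: the
# function name, `sparsity_weight`, the Adam learning rate, and the use of an
# L1 penalty on (x_recon - x) are assumptions made for this example, not values
# taken from the paper. Batches are assumed to be dicts of tensors keyed by the
# model's feature/label keys, matching the forward pass above.
# ------------------------------------------------------------------------------
def train_cfvae_with_sparsity(model, dataloader, epochs=1, sparsity_weight=0.1):
    """Trains a CFVAE with an added L1 sparsity penalty on counterfactual edits."""
    # Optimize only the VAE parameters; the frozen classifier is skipped.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=1e-3)
    model.train()
    # Keep the frozen classifier in eval mode even while the VAE trains.
    model.external_classifier.eval()
    for _ in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            output = model(**batch)
            x = batch[model.feature_keys[0]].to(model.device)
            # L1 penalty: encourage the counterfactual to change only a few
            # input features.
            sparsity = torch.mean(torch.abs(output["x_recon"] - x))
            loss = output["loss"] + sparsity_weight * sparsity
            loss.backward()
            optimizer.step()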