Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
from .vae import VAE
from .cfvae import CFVAE
175 changes: 175 additions & 0 deletions pyhealth/models/cfvae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# ==============================================================================
# Author(s): Sharim Khan, Gabriel Lee
# NetID(s): sharimk2, gjlee4
# Paper title:
# Explaining A Machine Learning Decision to Physicians via Counterfactuals
# Paper link: https://arxiv.org/abs/2306.06325
# Description: This file defines the Counterfactual Variational Autoencoder (CFVAE)
# model, which reconstructs input data while generating counterfactual
# examples that flip the prediction of a frozen classifier.
# ==============================================================================

from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.models import BaseModel


class CFVAE(BaseModel):
"""Counterfactual Variational Autoencoder (CFVAE) for binary prediction tasks.

This is a parametrized version of the CFVAE model described by Nagesh et al.

The CFVAE learns to reconstruct inputs while generating counterfactual samples
that flip the output of a fixed, externally trained binary classifier. It combines
VAE reconstruction and KL divergence losses with a classifier-based loss.

NOTE: A binary classifier MUST be passed as an argument.
NOTE: The sparsity constraint should be implemented in the training loop.

Attributes:
feature_keys: Feature keys used as inputs.
label_keys: A list containing the label key.
mode: Task mode (must be 'binary').
latent_dim: Latent dimensionality of the VAE.
external_classifier: Frozen external classifier for guiding counterfactuals.
enc1: First encoder layer.
enc2: Layer projecting to latent mean and log-variance.
dec1: First decoder layer.
dec2: Layer projecting to reconstructed input space.

Example:
cfvae = CFVAE(
dataset=samples,
feature_keys=["labs"],
label_key="mortality",
mode="binary",
feat_dim=27,
latent_dim=32,
hidden_dim=64,
external_classifier=frozen_classifier
)
"""

def __init__(
self,
dataset,
feature_keys: List[str],
label_key: str,
mode: str,
feat_dim: int,
latent_dim: int = 32,
hidden_dim: int = 64,
external_classifier: nn.Module = None,
):
"""
Initializes the CFVAE model and freezes the external classifier.

Args:
dataset: PyHealth-compatible dataset object.
feature_keys: List of input feature keys.
label_key: Output label key (must be binary).
mode: Task mode ('binary' only supported).
feat_dim: Input feature dimensionality.
latent_dim: Latent space dimensionality.
hidden_dim: Hidden layer size in encoder/decoder.
external_classifier: Frozen binary classifier to guide counterfactuals.
"""
super().__init__(dataset)
self.feature_keys = feature_keys
self.label_keys = [label_key]
self.mode = mode

assert mode == "binary", "Only binary classification is supported."
assert external_classifier is not None, "external_classifier must be provided."

self.latent_dim = latent_dim
self.external_classifier = external_classifier.eval()
for param in self.external_classifier.parameters():
param.requires_grad = False

self.enc1 = nn.Sequential(
nn.Linear(feat_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU()
)
self.enc2 = nn.Linear(hidden_dim, 2 * latent_dim)

self.dec1 = nn.Sequential(
nn.Linear(latent_dim + 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU()
)
self.dec2 = nn.Linear(hidden_dim, feat_dim)

def reparameterize(
self, mu: torch.Tensor, log_var: torch.Tensor
) -> torch.Tensor:
"""
Applies the reparameterization trick to sample z from Gaussian N.

Args:
mu: Mean of the latent distribution, shape (B, latent_dim).
log_var: Log variance of the latent distribution, shape (B, latent_dim).

Returns:
z: Sampled latent variable, shape (B, latent_dim).
"""
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std

def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""
Forward pass for CFVAE: encodes input, reparameterizes, decodes with flipped
labels, and computes reconstruction, KL, and classifier-based losses.

Args:
kwargs: Dict of inputs including:
- feature_keys[0]: Input tensor (B, feat_dim)
- label_keys[0]: Ground truth label tensor (B,)

Returns:
Dictionary containing:
- loss: Total training loss (recon + KL + classifier disagreement).
- y_prob: Classifier output probabilities for reconstructed inputs.
- y_true: Ground truth labels.
"""
x = kwargs[self.feature_keys[0]].to(self.device)
y = kwargs[self.label_keys[0]].to(self.device)

# Encode inputs
h = self.enc1(x)
h = self.enc2(h).view(-1, 2, self.latent_dim)
mu, log_var = h[:, 0, :], h[:, 1, :]
z = self.reparameterize(mu, log_var)

# Flip labels to condition decoder on opposite class (counterfactual)
y_cf = 1 - y
y_cf_onehot = F.one_hot(y_cf.view(-1).long(), num_classes=2).float()
z_cond = torch.cat([z, y_cf_onehot], dim=1)

h_dec = self.dec1(z_cond)
x_recon = torch.sigmoid(self.dec2(h_dec))

# Evaluate external classifier on counterfactual
with torch.no_grad():
logits = self.external_classifier(x_recon)

# Compute losses
clf_loss = self.get_loss_function()(logits, y)
recon_loss = F.mse_loss(x_recon, x, reduction="mean")
kld_loss = -0.5 * torch.mean(
1 + log_var - mu.pow(2) - log_var.exp()
)
total_loss = recon_loss + kld_loss + clf_loss

return {
"loss": total_loss,
"y_prob": self.prepare_y_prob(logits),
"y_true": y,
}

190 changes: 190 additions & 0 deletions pyhealth/unittests/test_cfvae_mortality_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# ==============================================================================
# Author(s): Sharim Khan, Gabriel Lee
# NetID(s): sharimk2, gjlee4
# Paper title:
# Explaining A Machine Learning Decision to Physicians via Counterfactuals
# Paper link: https://arxiv.org/abs/2306.06325
# Description: Test script to train and evaluate a Counterfactual VAE (CFVAE) on
# MIMIC-IV for mortality prediction using PyHealth, including training
# a frozen dummy classifier and then CFVAE with that classifier.
# ==============================================================================

import logging
import os
import sys
from typing import Any

import torch
import torch.nn as nn

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Add parent directory to sys.path for relative imports
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)


def test_cfvae_mortality_prediction_mimic4() -> None:
"""Trains a CFVAE model on MIMIC-IV demo data with a frozen dummy classifier.

Steps:
- Load and preprocess MIMIC-IV lab data.
- Train a binary classifier on in-hospital mortality.
- Freeze the classifier.
- Train a CFVAE model to produce counterfactuals.
- Evaluate CFVAE on test data.
"""
logger.info("===== Starting CFVAE Unit Test =====")
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import InHospitalMortalityMIMIC4
from pyhealth.datasets import split_by_sample, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.models import BaseModel, CFVAE

# Load MIMIC-IV demo dataset
dataset = MIMIC4Dataset(
ehr_root="https://physionet.org/files/mimic-iv-demo/2.2/",
ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
)

task = InHospitalMortalityMIMIC4()
samples = dataset.set_task(task)
logger.info(f"===== Loaded {len(samples)} samples. ===== ")

# Preprocessing: mean over time, normalize across samples
logger.info("===== Preprocessing samples (mean over time) =====")
for sample in samples:
sample["labs"] = torch.mean(sample["labs"], dim=0)

labs_tensor = torch.stack([s["labs"] for s in samples])
feature_mean = labs_tensor.mean(dim=0)
feature_std = labs_tensor.std(dim=0) + 1e-6

for sample in samples:
sample["labs"] = (sample["labs"] - feature_mean) / feature_std

# Split data
train_dataset, val_dataset, test_dataset = split_by_sample(
dataset=samples,
ratios=[0.7, 0.1, 0.2]
)

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

logger.info("===== Stage 1: Train the dummy classifier =====")

class DummyClassifier(nn.Module):
"""Simple feedforward binary classifier."""

def __init__(self, input_dim: int = 27, hidden_dim: int = 64):
"""
Args:
input_dim: Dimension of input feature vector.
hidden_dim: Size of hidden layer.
"""
super().__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.

Args:
x: Tensor of shape [batch_size, input_dim].

Returns:
Output logits as a tensor of shape [batch_size, 1].
"""
return self.model(x)

class WrappedClassifier(BaseModel):
"""Wraps a PyTorch classifier into the PyHealth BaseModel interface."""

def __init__(self, dataset: Any, model: nn.Module):
"""
Args:
dataset: PyHealth dataset object.
model: PyTorch model to be wrapped.
"""
super().__init__(dataset)
self.model = model
self.mode = self.dataset.output_schema[self.label_keys[0]]

def forward(self, **kwargs) -> dict:
"""Forward pass and loss computation.

Args:
kwargs: Dict containing "labs" and "mortality".

Returns:
Dictionary with keys "loss", "y_prob", and "y_true".
"""
x = kwargs[self.feature_keys[0]].to(self.device)
y = kwargs[self.label_keys[0]].to(self.device)
logits = self.model(x)
loss = self.get_loss_function()(logits, y)
y_prob = self.prepare_y_prob(logits)
return {
"loss": loss,
"y_prob": y_prob,
"y_true": y
}

clf = DummyClassifier(input_dim=27)
wrapped_model = WrappedClassifier(dataset=samples, model=clf)

trainer = Trainer(model=wrapped_model, metrics=["roc_auc", "accuracy"])
trainer.train(
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
epochs=5,
monitor="roc_auc"
)

logger.info("===== Freezing the classifier... =====")
clf.eval()
for param in clf.parameters():
param.requires_grad = False

logger.info("===== Stage 2: Train CFVAE with frozen classifier =====")

cfvae_model = CFVAE(
dataset=samples,
feature_keys=["labs"],
label_key="mortality",
mode="binary",
feat_dim=27,
latent_dim=32,
hidden_dim=64,
external_classifier=clf
)

cfvae_trainer = Trainer(model=cfvae_model, metrics=["roc_auc", "accuracy"])
cfvae_trainer.train(
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
epochs=10,
monitor="roc_auc",
optimizer_params={"lr": 1e-3}
)

logger.info("===== Test set evaluation =====")
print(cfvae_trainer.evaluate(test_dataloader))
logger.info("===== Successfully completed CFVAE unit test! =====")


if __name__ == "__main__":
test_cfvae_mortality_prediction_mimic4()