[CS598 DLH] Counterfactual VAE (CF-VAE) for binary healthcare prediction tasks #404
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Contributor Info
Name: Sharim Khan, Gabriel Lee
NetID: sharimk2, gjlee4
Paper: Explaining a Machine Learning Decision to Physicians via Counterfactuals
Paper Link: https://arxiv.org/abs/2306.06325
Type of Contribution
Model: New model addition
Task: Counterfactual VAE (CF-VAE) for binary healthcare prediction tasks
High-Level Description
This PR contributes a PyHealth-compatible implementation of the Counterfactual Variational Autoencoder (CFVAE) model, which is specifically designed to generate counterfactuals for binary prediction tasks. Specifically, the model is composed of:
The model output should generate a counterfactual based on the input and provided classifier.
This implementation is inspired by the model proposed in:
Explaining a Machine Learning Decision to Physicians via Counterfactuals
Nagesh et al., 2023 (arXiv link)
Testing
To aid in testing the model, we provide the below Google Colab Notebook.
https://colab.research.google.com/drive/1gRaE6QDYfgjhopzEaAgCy45WaEL9ON1D
The unit test is recommended to be run in the notebook. The notebook will handle reproducing the environment, such that the test case can be run as:
Unit Test Output
2025-05-06 07:32:07,322 - __main__ - INFO - ===== Starting CFVAE Unit Test ===== 2025-05-06 07:32:08,709 - numexpr.utils - INFO - NumExpr defaulting to 2 threads. Memory usage Starting MIMIC4Dataset init: 823.0 MB 2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Memory usage Starting MIMIC4Dataset init: 823.0 MB Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: False) 2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: False) Using default EHR config: /content/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml 2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Using default EHR config: /content/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml Memory usage Before initializing mimic4_ehr: 823.0 MB 2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Memory usage Before initializing mimic4_ehr: 823.0 MB Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False) 2025-05-06 07:32:20,090 - pyhealth.datasets.base_dataset - INFO - Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False) Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz 2025-05-06 07:32:20,090 - pyhealth.datasets.base_dataset - INFO - Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz 2025-05-06 07:32:21,556 - pyhealth.datasets.base_dataset - INFO - Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz 2025-05-06 07:32:22,449 - pyhealth.datasets.base_dataset - INFO - Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz 2025-05-06 07:32:23,331 - pyhealth.datasets.base_dataset - INFO - Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz Scanning table: prescriptions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/prescriptions.csv.gz 2025-05-06 07:32:23,766 - pyhealth.datasets.base_dataset - INFO - Scanning table: prescriptions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/prescriptions.csv.gz Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz 2025-05-06 07:32:26,186 - pyhealth.datasets.base_dataset - INFO - Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz 2025-05-06 07:32:31,908 - pyhealth.datasets.base_dataset - INFO - Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz 2025-05-06 07:32:32,846 - pyhealth.datasets.base_dataset - INFO - Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz 2025-05-06 07:32:33,719 - pyhealth.datasets.base_dataset - INFO - Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz 2025-05-06 07:32:34,201 - pyhealth.datasets.base_dataset - INFO - Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz Memory usage After initializing mimic4_ehr: 843.0 MB 2025-05-06 07:32:35,070 - pyhealth.datasets.mimic4 - INFO - Memory usage After initializing mimic4_ehr: 843.0 MB Memory usage After EHR dataset initialization: 843.0 MB 2025-05-06 07:32:35,070 - pyhealth.datasets.mimic4 - INFO - Memory usage After EHR dataset initialization: 843.0 MB Memory usage Before combining data: 843.0 MB 2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Memory usage Before combining data: 843.0 MB Combining data from ehr dataset 2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Combining data from ehr dataset Creating combined dataframe 2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Creating combined dataframe Memory usage After combining data: 843.0 MB 2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Memory usage After combining data: 843.0 MB Memory usage Completed MIMIC4Dataset init: 843.0 MB 2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Memory usage Completed MIMIC4Dataset init: 843.0 MB Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset... 2025-05-06 07:32:35,071 - pyhealth.datasets.base_dataset - INFO - Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset... Collecting global event dataframe... 2025-05-06 07:32:35,071 - pyhealth.datasets.base_dataset - INFO - Collecting global event dataframe... Collected dataframe with shape: (131557, 44) 2025-05-06 07:32:35,400 - pyhealth.datasets.base_dataset - INFO - Collected dataframe with shape: (131557, 44) Generating samples with 2 worker(s)... 2025-05-06 07:32:35,401 - pyhealth.datasets.base_dataset - INFO - Generating samples with 2 worker(s)... Generating samples for InHospitalMortalityMIMIC4 2025-05-06 07:32:35,401 - pyhealth.datasets.base_dataset - INFO - Generating samples for InHospitalMortalityMIMIC4 Label mortality vocab: {0: 0, 1: 1} 2025-05-06 07:32:36,705 - pyhealth.processors.label_processor - INFO - Label mortality vocab: {0: 0, 1: 1} Processing samples: 100% 216/216 [00:00<00:00, 610.73it/s] Generated 216 samples for task InHospitalMortalityMIMIC4 2025-05-06 07:32:37,059 - pyhealth.datasets.base_dataset - INFO - Generated 216 samples for task InHospitalMortalityMIMIC4 2025-05-06 07:32:37,064 - __main__ - INFO - ===== Loaded 216 samples. ===== 2025-05-06 07:32:37,065 - __main__ - INFO - ===== Preprocessing samples (mean over time) ===== 2025-05-06 07:32:37,090 - __main__ - INFO - ===== Stage 1: Train the dummy classifier ===== WrappedClassifier( (model): DummyClassifier( (model): Sequential( (0): Linear(in_features=27, out_features=64, bias=True) (1): ReLU() (2): Linear(in_features=64, out_features=1, bias=True) ) ) ) 2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - WrappedClassifier( (model): DummyClassifier( (model): Sequential( (0): Linear(in_features=27, out_features=64, bias=True) (1): ReLU() (2): Linear(in_features=64, out_features=1, bias=True) ) ) ) Metrics: ['roc_auc', 'accuracy'] 2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Metrics: ['roc_auc', 'accuracy'] Device: cuda 2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Device: cuda2025-05-06 07:32:37,351 - pyhealth.trainer - INFO -
Training:
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Training:
Batch size: 32
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Optimizer params: {'lr': 0.001}
Weight decay: 0.0
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Weight decay: 0.0
Max grad norm: None
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
Monitor: roc_auc
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Monitor: roc_auc
Monitor criterion: max
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Monitor criterion: max
Epochs: 5
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Epochs: 5
2025-05-06 07:32:37,354 - pyhealth.trainer - INFO -
Epoch 0 / 5: 100% 5/5 [00:00<00:00, 7.42it/s]
--- Train epoch-0, step-5 ---
2025-05-06 07:32:38,028 - pyhealth.trainer - INFO - --- Train epoch-0, step-5 ---
loss: 0.6972
2025-05-06 07:32:38,028 - pyhealth.trainer - INFO - loss: 0.6972
Evaluation: 100% 1/1 [00:00<00:00, 792.13it/s]
--- Eval epoch-0, step-5 ---
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - --- Eval epoch-0, step-5 ---
roc_auc: 0.1500
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - roc_auc: 0.1500
accuracy: 0.6190
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - accuracy: 0.6190
loss: 0.6875
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - loss: 0.6875
New best roc_auc score (0.1500) at epoch-0, step-5
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - New best roc_auc score (0.1500) at epoch-0, step-5
2025-05-06 07:32:38,054 - pyhealth.trainer - INFO -
Epoch 1 / 5: 100% 5/5 [00:00<00:00, 450.82it/s]
--- Train epoch-1, step-10 ---
2025-05-06 07:32:38,066 - pyhealth.trainer - INFO - --- Train epoch-1, step-10 ---
loss: 0.6614
2025-05-06 07:32:38,066 - pyhealth.trainer - INFO - loss: 0.6614
Evaluation: 100% 1/1 [00:00<00:00, 973.38it/s]
--- Eval epoch-1, step-10 ---
2025-05-06 07:32:38,073 - pyhealth.trainer - INFO - --- Eval epoch-1, step-10 ---
roc_auc: 0.3500
2025-05-06 07:32:38,073 - pyhealth.trainer - INFO - roc_auc: 0.3500
accuracy: 0.9048
2025-05-06 07:32:38,074 - pyhealth.trainer - INFO - accuracy: 0.9048
loss: 0.6495
2025-05-06 07:32:38,074 - pyhealth.trainer - INFO - loss: 0.6495
New best roc_auc score (0.3500) at epoch-1, step-10
2025-05-06 07:32:38,074 - pyhealth.trainer - INFO - New best roc_auc score (0.3500) at epoch-1, step-10
2025-05-06 07:32:38,075 - pyhealth.trainer - INFO -
Epoch 2 / 5: 100% 5/5 [00:00<00:00, 404.94it/s]
--- Train epoch-2, step-15 ---
2025-05-06 07:32:38,088 - pyhealth.trainer - INFO - --- Train epoch-2, step-15 ---
loss: 0.6310
2025-05-06 07:32:38,088 - pyhealth.trainer - INFO - loss: 0.6310
Evaluation: 100% 1/1 [00:00<00:00, 872.54it/s]
--- Eval epoch-2, step-15 ---
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - --- Eval epoch-2, step-15 ---
roc_auc: 0.5500
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 0.6134
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - loss: 0.6134
New best roc_auc score (0.5500) at epoch-2, step-15
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - New best roc_auc score (0.5500) at epoch-2, step-15
2025-05-06 07:32:38,095 - pyhealth.trainer - INFO -
Epoch 3 / 5: 100% 5/5 [00:00<00:00, 401.70it/s]
--- Train epoch-3, step-20 ---
2025-05-06 07:32:38,108 - pyhealth.trainer - INFO - --- Train epoch-3, step-20 ---
loss: 0.6013
2025-05-06 07:32:38,108 - pyhealth.trainer - INFO - loss: 0.6013
Evaluation: 100% 1/1 [00:00<00:00, 957.60it/s]
--- Eval epoch-3, step-20 ---
2025-05-06 07:32:38,113 - pyhealth.trainer - INFO - --- Eval epoch-3, step-20 ---
roc_auc: 0.6500
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - roc_auc: 0.6500
accuracy: 0.9524
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 0.5788
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - loss: 0.5788
New best roc_auc score (0.6500) at epoch-3, step-20
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - New best roc_auc score (0.6500) at epoch-3, step-20
2025-05-06 07:32:38,115 - pyhealth.trainer - INFO -
Epoch 4 / 5: 100% 5/5 [00:00<00:00, 420.70it/s]
--- Train epoch-4, step-25 ---
2025-05-06 07:32:38,127 - pyhealth.trainer - INFO - --- Train epoch-4, step-25 ---
loss: 0.5689
2025-05-06 07:32:38,127 - pyhealth.trainer - INFO - loss: 0.5689
Evaluation: 100% 1/1 [00:00<00:00, 951.95it/s]
--- Eval epoch-4, step-25 ---
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - --- Eval epoch-4, step-25 ---
roc_auc: 0.7500
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - roc_auc: 0.7500
accuracy: 0.9524
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 0.5457
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - loss: 0.5457
New best roc_auc score (0.7500) at epoch-4, step-25
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - New best roc_auc score (0.7500) at epoch-4, step-25
Loaded best model
2025-05-06 07:32:38,134 - pyhealth.trainer - INFO - Loaded best model
2025-05-06 07:32:38,138 - main - INFO - ===== Freezing the classifier... =====
2025-05-06 07:32:38,138 - main - INFO - ===== Stage 2: Train CFVAE with frozen classifier =====
CFVAE(
(external_classifier): DummyClassifier(
(model): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=1, bias=True)
)
)
(enc1): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(enc2): Linear(in_features=64, out_features=64, bias=True)
(dec1): Sequential(
(0): Linear(in_features=34, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(dec2): Linear(in_features=64, out_features=27, bias=True)
)
2025-05-06 07:32:38,140 - pyhealth.trainer - INFO - CFVAE(
(external_classifier): DummyClassifier(
(model): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=1, bias=True)
)
)
(enc1): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(enc2): Linear(in_features=64, out_features=64, bias=True)
(dec1): Sequential(
(0): Linear(in_features=34, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(dec2): Linear(in_features=64, out_features=27, bias=True)
)
Metrics: ['roc_auc', 'accuracy']
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Metrics: ['roc_auc', 'accuracy']
Device: cuda
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Device: cuda
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO -
Training:
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Training:
Batch size: 32
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Optimizer params: {'lr': 0.001}
Weight decay: 0.0
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Weight decay: 0.0
Max grad norm: None
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
Monitor: roc_auc
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Monitor: roc_auc
Monitor criterion: max
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Monitor criterion: max
Epochs: 10
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Epochs: 10
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO -
Epoch 0 / 10: 100% 5/5 [00:00<00:00, 15.67it/s]
--- Train epoch-0, step-5 ---
2025-05-06 07:32:38,462 - pyhealth.trainer - INFO - --- Train epoch-0, step-5 ---
loss: 1.4165
2025-05-06 07:32:38,462 - pyhealth.trainer - INFO - loss: 1.4165
Evaluation: 100% 1/1 [00:00<00:00, 538.98it/s]
--- Eval epoch-0, step-5 ---
2025-05-06 07:32:38,468 - pyhealth.trainer - INFO - --- Eval epoch-0, step-5 ---
roc_auc: 0.0000
2025-05-06 07:32:38,468 - pyhealth.trainer - INFO - roc_auc: 0.0000
accuracy: 0.9524
2025-05-06 07:32:38,469 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.3753
2025-05-06 07:32:38,469 - pyhealth.trainer - INFO - loss: 1.3753
New best roc_auc score (0.0000) at epoch-0, step-5
2025-05-06 07:32:38,469 - pyhealth.trainer - INFO - New best roc_auc score (0.0000) at epoch-0, step-5
2025-05-06 07:32:38,471 - pyhealth.trainer - INFO -
Epoch 1 / 10: 100% 5/5 [00:00<00:00, 287.57it/s]
--- Train epoch-1, step-10 ---
2025-05-06 07:32:38,488 - pyhealth.trainer - INFO - --- Train epoch-1, step-10 ---
loss: 1.3389
2025-05-06 07:32:38,488 - pyhealth.trainer - INFO - loss: 1.3389
Evaluation: 100% 1/1 [00:00<00:00, 667.14it/s]
--- Eval epoch-1, step-10 ---
2025-05-06 07:32:38,493 - pyhealth.trainer - INFO - --- Eval epoch-1, step-10 ---
roc_auc: 0.5500
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.3141
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - loss: 1.3141
New best roc_auc score (0.5500) at epoch-1, step-10
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - New best roc_auc score (0.5500) at epoch-1, step-10
2025-05-06 07:32:38,496 - pyhealth.trainer - INFO -
Epoch 2 / 10: 100% 5/5 [00:00<00:00, 296.64it/s]
--- Train epoch-2, step-15 ---
2025-05-06 07:32:38,513 - pyhealth.trainer - INFO - --- Train epoch-2, step-15 ---
loss: 1.2860
2025-05-06 07:32:38,513 - pyhealth.trainer - INFO - loss: 1.2860
Evaluation: 100% 1/1 [00:00<00:00, 669.05it/s]
--- Eval epoch-2, step-15 ---
2025-05-06 07:32:38,518 - pyhealth.trainer - INFO - --- Eval epoch-2, step-15 ---
roc_auc: 0.5500
2025-05-06 07:32:38,518 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,519 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2922
2025-05-06 07:32:38,519 - pyhealth.trainer - INFO - loss: 1.2922
2025-05-06 07:32:38,519 - pyhealth.trainer - INFO -
Epoch 3 / 10: 100% 5/5 [00:00<00:00, 304.42it/s]
--- Train epoch-3, step-20 ---
2025-05-06 07:32:38,536 - pyhealth.trainer - INFO - --- Train epoch-3, step-20 ---
loss: 1.2599
2025-05-06 07:32:38,536 - pyhealth.trainer - INFO - loss: 1.2599
Evaluation: 100% 1/1 [00:00<00:00, 619.27it/s]
--- Eval epoch-3, step-20 ---
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - --- Eval epoch-3, step-20 ---
roc_auc: 0.8500
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - roc_auc: 0.8500
accuracy: 0.9524
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2581
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - loss: 1.2581
New best roc_auc score (0.8500) at epoch-3, step-20
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - New best roc_auc score (0.8500) at epoch-3, step-20
2025-05-06 07:32:38,544 - pyhealth.trainer - INFO -
Epoch 4 / 10: 100% 5/5 [00:00<00:00, 295.69it/s]
--- Train epoch-4, step-25 ---
2025-05-06 07:32:38,561 - pyhealth.trainer - INFO - --- Train epoch-4, step-25 ---
loss: 1.2457
2025-05-06 07:32:38,561 - pyhealth.trainer - INFO - loss: 1.2457
Evaluation: 100% 1/1 [00:00<00:00, 648.07it/s]
--- Eval epoch-4, step-25 ---
2025-05-06 07:32:38,566 - pyhealth.trainer - INFO - --- Eval epoch-4, step-25 ---
roc_auc: 0.9500
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - roc_auc: 0.9500
accuracy: 0.9524
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2414
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - loss: 1.2414
New best roc_auc score (0.9500) at epoch-4, step-25
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - New best roc_auc score (0.9500) at epoch-4, step-25
2025-05-06 07:32:38,569 - pyhealth.trainer - INFO -
Epoch 5 / 10: 100% 5/5 [00:00<00:00, 252.61it/s]
--- Train epoch-5, step-30 ---
2025-05-06 07:32:38,589 - pyhealth.trainer - INFO - --- Train epoch-5, step-30 ---
loss: 1.2208
2025-05-06 07:32:38,589 - pyhealth.trainer - INFO - loss: 1.2208
Evaluation: 100% 1/1 [00:00<00:00, 685.90it/s]
--- Eval epoch-5, step-30 ---
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - --- Eval epoch-5, step-30 ---
roc_auc: 0.3500
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - roc_auc: 0.3500
accuracy: 0.9524
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2270
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - loss: 1.2270
2025-05-06 07:32:38,595 - pyhealth.trainer - INFO -
Epoch 6 / 10: 100% 5/5 [00:00<00:00, 294.96it/s]
--- Train epoch-6, step-35 ---
2025-05-06 07:32:38,612 - pyhealth.trainer - INFO - --- Train epoch-6, step-35 ---
loss: 1.1976
2025-05-06 07:32:38,612 - pyhealth.trainer - INFO - loss: 1.1976
Evaluation: 100% 1/1 [00:00<00:00, 660.42it/s]
--- Eval epoch-6, step-35 ---
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - --- Eval epoch-6, step-35 ---
roc_auc: 0.6000
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - roc_auc: 0.6000
accuracy: 0.9524
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2117
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - loss: 1.2117
2025-05-06 07:32:38,618 - pyhealth.trainer - INFO -
Epoch 7 / 10: 100% 5/5 [00:00<00:00, 300.62it/s]
--- Train epoch-7, step-40 ---
2025-05-06 07:32:38,635 - pyhealth.trainer - INFO - --- Train epoch-7, step-40 ---
loss: 1.1940
2025-05-06 07:32:38,635 - pyhealth.trainer - INFO - loss: 1.1940
Evaluation: 100% 1/1 [00:00<00:00, 666.93it/s]
--- Eval epoch-7, step-40 ---
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - --- Eval epoch-7, step-40 ---
roc_auc: 0.1000
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - roc_auc: 0.1000
accuracy: 0.9524
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.1942
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - loss: 1.1942
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO -
Epoch 8 / 10: 100% 5/5 [00:00<00:00, 296.75it/s]
--- Train epoch-8, step-45 ---
2025-05-06 07:32:38,657 - pyhealth.trainer - INFO - --- Train epoch-8, step-45 ---
loss: 1.1885
2025-05-06 07:32:38,658 - pyhealth.trainer - INFO - loss: 1.1885
Evaluation: 100% 1/1 [00:00<00:00, 641.43it/s]
--- Eval epoch-8, step-45 ---
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - --- Eval epoch-8, step-45 ---
roc_auc: 0.5500
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.1871
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - loss: 1.1871
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO -
Epoch 9 / 10: 100% 5/5 [00:00<00:00, 289.75it/s]
--- Train epoch-9, step-50 ---
2025-05-06 07:32:38,681 - pyhealth.trainer - INFO - --- Train epoch-9, step-50 ---
loss: 1.1748
2025-05-06 07:32:38,681 - pyhealth.trainer - INFO - loss: 1.1748
Evaluation: 100% 1/1 [00:00<00:00, 621.56it/s]
--- Eval epoch-9, step-50 ---
2025-05-06 07:32:38,686 - pyhealth.trainer - INFO - --- Eval epoch-9, step-50 ---
roc_auc: 0.6500
2025-05-06 07:32:38,686 - pyhealth.trainer - INFO - roc_auc: 0.6500
accuracy: 0.9524
2025-05-06 07:32:38,687 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.1849
2025-05-06 07:32:38,687 - pyhealth.trainer - INFO - loss: 1.1849
Loaded best model
2025-05-06 07:32:38,687 - pyhealth.trainer - INFO - Loaded best model
2025-05-06 07:32:38,691 - main - INFO - ===== Test set evaluation =====
Evaluation: 100% 2/2 [00:00<00:00, 72.16it/s]
{'roc_auc': np.float64(0.23577235772357724), 'accuracy': 0.9318181818181818, 'loss': 1.5149085521697998}
2025-05-06 07:32:38,722 - main - INFO - ===== Successfully completed CFVAE unit test! =====
Files to Review
pyhealth/models/cfvae.pypyhealth/unittests/test_cfvae_mortality_prediction.pyExtension
An extension is made to parametrize the external classifier to de-couple it as much as possible from the CFVAE itself. This deviates from the authors' original code to tightly couple their original MLP model with the CFVAE.
While the binary classifier needs to be part of the internal CFVAE layers, there are many types of binary prediction tasks in health care (mortality prediction, readmission prediction, etc.). The extra burden should not be placed on researchers to hand-craft each use-case, so we opt for more re-usability.
To do this, we enhance their described model to accept a frozen classifier to be passed as an argument to the CFVAE model, which also more closely aligns with the original paper's description of an external, black-box Binary Prediction (BP) model as seen in Figure 2 of the paper.
The PyHealth model in this PR outputs
loss,y_true, andy_probin alignment with PyHealth’s training pipeline expectations.Important Design Notes
vae.pymodel, which is designed for image signals and uses 2D convolutional blocks.