@prunbun commented on May 7, 2025

Contributor Information

Contribution Summary

This pull request contributes the MedFuse model, a PyHealth-compatible architecture for multimodal learning by fusing Electronic Health Records (EHR) and Chest X-ray (CXR) data using an LSTM-based mechanism.

Detailed Description: MedFuse Model and Unit Tests

MedFuse integrates Electronic Health Records (EHR) and Chest X-ray (CXR) data for predictive tasks and is adapted here for the PyHealth library. This version implements the LSTM-based fusion mechanism.

Model Architecture Overview

The MedFuse model processes EHR and CXR data through distinct pathways before fusing their representations:

  • EHR Branch:
    • Accepts sequences of EHR codes (as integer indices).
    • Employs an embedding layer to convert codes into dense vectors.
    • Utilizes an LSTM (_EHR_LSTM_Core) to capture sequential patterns and extract EHR features.
  • CXR Branch:
    • Processes preprocessed CXR image tensors.
    • Leverages a pre-trained convolutional neural network (CNN) from torchvision.models (e.g., ResNet) as a backbone (_CXR_Core) to extract visual features. The original classifier of the backbone is replaced.
    • A linear projection layer standardizes the dimensionality of CXR features.
  • LSTM-Based Fusion & Classification:
    • The core of this implementation is its fusion strategy. It constructs a 2-step sequence input for a dedicated fusion LSTM:
      1. The first step combines EHR features with zero-vectors (representing the CXR modality, effectively an "EHR-only" view).
      2. The second step combines EHR features with the actual (or zeroed, if missing) projected CXR features.
    • This approach allows the model to explicitly handle samples where CXR data might be missing by adjusting the fusion sequence length and feature content.
    • The output from the fusion LSTM represents the integrated multimodal features.
    • A final linear classifier maps these fused features to the prediction task's output logits.
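
The fusion step described above can be sketched roughly as follows. This is a minimal illustration, not the code in this pull request: the class name FusionSketch, the feature sizes, and the use of packed sequences to shorten the fusion sequence to one step when the CXR is missing are all assumptions made for the example.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class FusionSketch(nn.Module):
    """Hypothetical 2-step LSTM fusion; names and sizes are illustrative."""

    def __init__(self, ehr_dim: int, cxr_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.fusion_lstm = nn.LSTM(ehr_dim + cxr_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, ehr_feat, cxr_feat, has_cxr):
        # ehr_feat: [B, ehr_dim], cxr_feat: [B, cxr_dim], has_cxr: [B] bool
        # Zero the projected CXR features for samples without an image.
        cxr_feat = cxr_feat * has_cxr.unsqueeze(1).to(cxr_feat.dtype)

        # Step 1: EHR features plus zeros (the "EHR-only" view).
        step1 = torch.cat([ehr_feat, torch.zeros_like(cxr_feat)], dim=1)
        # Step 2: EHR features plus the actual (or zeroed) CXR features.
        step2 = torch.cat([ehr_feat, cxr_feat], dim=1)
        seq = torch.stack([step1, step2], dim=1)  # [B, 2, ehr_dim + cxr_dim]

        # Assumption: samples without a CXR use a sequence of length 1.
        lengths = 1 + has_cxr.long()
        packed = pack_padded_sequence(
            seq, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.fusion_lstm(packed)

        # The last hidden state is the fused representation.
        return self.classifier(h_n[-1])


torch.manual_seed(0)
model = FusionSketch(ehr_dim=8, cxr_dim=4, hidden_dim=16, num_classes=3)
ehr = torch.randn(5, 8)
cxr = torch.randn(5, 4)
has_cxr = torch.tensor([True, False, True, False, False])
logits = model(ehr, cxr, has_cxr)  # shape [5, 3]
```

With this layout, the logits for a sample whose has_cxr flag is False do not depend on whatever tensor is passed in the CXR slot, which is the behavior the description above calls for.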

The model is designed to integrate with pyhealth.datasets.SampleDataset and expects specific data keys for EHR sequences, CXR images, EHR sequence lengths, and a boolean flag indicating the availability of a paired CXR image.
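
As a rough illustration of these four inputs, the sketch below builds a toy batch and runs an EHR encoder that uses the sequence lengths to ignore padding. The key names (ehr_codes, ehr_lengths, cxr_image, has_cxr), the class EHREncoderSketch, and the choice of index 0 as padding are hypothetical; the actual keys and modules are defined in the pull request's code.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class EHREncoderSketch(nn.Module):
    """Hypothetical EHR branch: embedding followed by an LSTM."""

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        # Assumption: index 0 is reserved for padding.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, codes, lengths):
        embedded = self.embedding(codes)  # [B, T, embed_dim]
        # Packing makes the LSTM stop at each sample's true length.
        packed = pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)
        return h_n[-1]  # [B, hidden_dim]


# Toy batch with hypothetical key names.
batch = {
    "ehr_codes": torch.tensor([[5, 3, 9, 0], [7, 2, 0, 0]]),  # padded code indices
    "ehr_lengths": torch.tensor([3, 2]),                      # true sequence lengths
    "cxr_image": torch.zeros(2, 3, 224, 224),                 # preprocessed images
    "has_cxr": torch.tensor([True, False]),                   # CXR availability flag
}

torch.manual_seed(0)
encoder = EHREncoderSketch(vocab_size=10, embed_dim=4, hidden_dim=6)
ehr_features = encoder(batch["ehr_codes"], batch["ehr_lengths"])  # shape [2, 6]
```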

Unit Test: TestMedFuse

A unit test (TestMedFuse) checks that the model runs end to end and returns correctly shaped outputs:

  • Purpose: The primary goal is to verify that the model's forward() pass executes correctly and produces outputs of the expected shape under various conditions.
  • Methodology:
    • A UnitTestSampleDataset (a mock dataset) is used. This allows testing the model's internal logic without the overhead of a full data loading and preprocessing pipeline, by providing the necessary metadata (e.g., vocabulary sizes, output dimensions) directly.
    • The test initializes the MedFuse model with this mock dataset and predefined hyperparameters.
  • Scenarios Tested:
    • The test_medfuse_forward_pass method specifically focuses on the model's ability to handle different CXR data availability scenarios by varying the has_cxr boolean flag:
      • Mixed Availability: Some samples in the batch have CXR images, while others do not.
      • No CXR Data: All samples in the batch lack corresponding CXR images.
      • Full CXR Data: All samples in the batch have paired CXR images.
  • Validation: For each scenario, the test generates dummy input tensors, performs a forward pass through the model, and asserts that the output logits have the correct dimensions ([BATCH_SIZE, NUM_CLASSES]).
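
The three scenarios can be expressed as a small shape test along the following lines. This is a sketch only: StandInModel is a placeholder that merely mimics the has_cxr masking, and the constants and test class name are illustrative rather than taken from the pull request.

```python
import unittest

import torch
import torch.nn as nn

BATCH_SIZE, NUM_CLASSES, FEAT_DIM = 4, 2, 8


class StandInModel(nn.Module):
    """Placeholder for the real model; only mimics the has_cxr masking."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(2 * FEAT_DIM, NUM_CLASSES)

    def forward(self, ehr_feat, cxr_feat, has_cxr):
        # Zero the CXR features wherever no image is available.
        cxr_feat = cxr_feat * has_cxr.unsqueeze(1).to(cxr_feat.dtype)
        return self.head(torch.cat([ehr_feat, cxr_feat], dim=1))


class TestForwardShapes(unittest.TestCase):
    def test_forward_pass(self):
        model = StandInModel().eval()
        scenarios = {
            "mixed": torch.tensor([True, False, True, False]),
            "no_cxr": torch.zeros(BATCH_SIZE, dtype=torch.bool),
            "full_cxr": torch.ones(BATCH_SIZE, dtype=torch.bool),
        }
        for name, has_cxr in scenarios.items():
            with self.subTest(scenario=name):
                # Dummy inputs: only the output shape is checked.
                ehr = torch.randn(BATCH_SIZE, FEAT_DIM)
                cxr = torch.randn(BATCH_SIZE, FEAT_DIM)
                with torch.no_grad():
                    logits = model(ehr, cxr, has_cxr)
                self.assertEqual(
                    tuple(logits.shape), (BATCH_SIZE, NUM_CLASSES)
                )
```

Saved to a file, a test like this can be run with python -m unittest. Using subTest keeps the three scenarios in one method while still reporting each failure separately.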

This unit test confirms that the forward pass, including the LSTM fusion mechanism and its handling of missing CXR data, runs without error and produces outputs of the expected shape. It checks output dimensions only, not the values of the predictions.

Running Unit Tests

The unit tests for the MedFuse model can be executed from the root directory of the pyhealth library with the following command:

python -m pyhealth.models.medfuse

@linjc16 added the "Highlight for TAs to highlight" label on May 10, 2025