Contributing MedFuse model #472
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Contributor Information
Contribution Summary
This pull request contributes the
MedFusemodel, a PyHealth-compatible architecture for multimodal learning by fusing Electronic Health Records (EHR) and Chest X-ray (CXR) data using an LSTM-based mechanism.Detailed Description: MedFuse Model and Unit Tests
The
MedFusemodel is designed to integrate Electronic Health Records (EHR) and Chest X-ray (CXR) data for predictive tasks, adapted for the PyHealth library. This specific version implements an LSTM-based fusion mechanism.Model Architecture Overview
The
MedFusemodel processes EHR and CXR data through distinct pathways before fusing their representations:_EHR_LSTM_Core) to capture sequential patterns and extract EHR features.torchvision.models(e.g., ResNet) as a backbone (_CXR_Core) to extract visual features. The original classifier of the backbone is replaced.The model is designed to integrate with
pyhealth.datasets.SampleDatasetand expects specific data keys for EHR sequences, CXR images, EHR sequence lengths, and a boolean flag indicating the availability of a paired CXR image.Unit Test:
TestMedFuseTo ensure the robustness and correctness of the
MedFusemodel, a unit test (TestMedFuse) has been implemented:forward()pass executes correctly and produces outputs of the expected shape under various conditions.UnitTestSampleDataset(a mock dataset) is used. This allows testing the model's internal logic without the overhead of a full data loading and preprocessing pipeline, by providing the necessary metadata (e.g., vocabulary sizes, output dimensions) directly.MedFusemodel with this mock dataset and predefined hyperparameters.test_medfuse_forward_passmethod specifically focuses on the model's ability to handle different CXR data availability scenarios by varying thehas_cxrboolean flag:[BATCH_SIZE, NUM_CLASSES]).This unit test confirms that the model's architecture, particularly the LSTM fusion mechanism and its handling of potentially missing CXR data, operates as intended and produces structurally valid outputs.
Running Unit Tests
The unit tests for the
MedFusemodel can be executed from the root directory of thepyhealthlibrary with the following command: