- [April 5, 2025] We open source ECG-Bench for training and evaluating ELMs!
- Overview
- Installation
- ECG Datasets
- Main Methods
- Known Issues + Tips
- Contributions
- TODO
- Acknowledgements
- License
- Citations
This repository is a unified framework for training and evaluating electrocardiogram-language models (ELMs). It is mainly intended for researchers interested in developing ELMs, with a particular focus on ECG representations and training paradigms. The code is designed to be modular and flexible, allowing researchers to easily extend the framework to their own needs and quickly iterate on their ELM designs. Given this audience and purpose, we aim to keep the code basic and flexible, with few abstractions, so that it can be easily extended. However, this goal is yet to be fully realized, and we are continuously working to improve the codebase.
Currently, we are working on a benchmarking paper for ELMs and different ECG input representations / training paradigms. We will update the repository with the results and more information soon!
This current repository considers 4 input representations of ECGs as defined below:
ECG Signal:
The raw ECG signal is represented as a matrix X_sig ∈ R^(C x L), where C denotes the number of leads and L is the number of time samples per lead. All other modalities are derived from X_sig.
ECG Image:
An ECG image is derived from X_sig via plotting and is represented as a tensor X_img ∈ R^(H x W x C′), where H and W denote the image height and width, respectively, and C′ is the number of color channels.
Stacked ECG Signal:
We also create a synthetic three-channel version of X_sig, denoted X_sig* ∈ R^(C x L x 3), by stacking X_sig three times along the color dimension (as seen in ECG Image).
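For concreteness, here is a minimal NumPy sketch of this stacking; shapes are illustrative and the exact implementation in ECG-Bench may differ:

```python
import numpy as np

x_sig = np.random.randn(12, 1250)                     # X_sig with C = 12 leads, L = 1250 samples
x_sig_star = np.repeat(x_sig[:, :, None], 3, axis=2)  # stack along a new color dimension
print(x_sig_star.shape)                               # (12, 1250, 3), i.e., R^(C x L x 3)
```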
ECG Text:
We use ECG-Byte's compression schema to convert ECG signals into text. First, a normalized and discretized ECG signal X_sig is mapped to a symbolic sequence using a set of symbols A = {a, b, …, z}. This sequence is then flattened into a one-dimensional array X_symb ∈ A^(C * L). Finally, a byte-pair encoding (BPE) process compresses X_symb into a sequence of tokens from an extended vocabulary V, resulting in the final textual representation X_ID ∈ V^m, where m is the length of the token sequence.
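As a rough illustration of the symbolic mapping step, here is a minimal sketch assuming min-max normalization and 26 uniform bins; the actual normalization, discretization, and BPE training in ECG-Byte may differ:

```python
import numpy as np

def ecg_to_symbols(x_sig: np.ndarray) -> str:
    """Map a (C, L) ECG signal to a flattened symbolic string over {a, ..., z}."""
    symbols = np.array(list("abcdefghijklmnopqrstuvwxyz"))
    x = (x_sig - x_sig.min()) / (x_sig.max() - x_sig.min() + 1e-8)  # assumed min-max normalization
    bins = np.clip((x * 26).astype(int), 0, 25)                     # discretize into 26 levels
    return "".join(symbols[bins.reshape(-1)])                       # flatten to length C * L

x_sig = np.random.randn(12, 1250)   # 12 leads, 1250 samples
x_symb = ecg_to_symbols(x_sig)      # len(x_symb) == 12 * 1250
# A trained BPE tokenizer (ECG-Byte) would then compress x_symb into token IDs X_ID.
```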
We consider 2 broadly defined training paradigms for ELMs in this repository:
- 2-Stage Training (can also be seen as Encoder methods)
- End-to-End Training (can also be seen as Encoder-Free methods)
However, we further break down 2-Stage Training into 4 sub-methods:
- 2-Stage Scratch
- 2-Stage LLaVA
- 2-Stage Finetune
- 2-Stage End-to-End
2-Stage Scratch
In this approach, we train an ECG-specific encoder f_ECG: R^(C x L) -> R^d using self-supervised learning (SSL) on the raw ECG signal X_sig. Here, C represents the number of channels, L the number of time steps, and d the dimension of the latent space into which the ECG data is encoded.
SSL approaches, such as masked image modeling or contrastive learning, are employed. In contrastive learning, a text encoder f_text may map textual reports to the same d-dimensional latent space for alignment with the ECG encodings.
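For intuition, here is a minimal CLIP-style contrastive alignment sketch between ECG and text embeddings; the encoder architectures and exact loss used by the SSL implementations in this repository may differ:

```python
import torch
import torch.nn.functional as F

B, d = 32, 256                                  # batch size, latent dimension
z_ecg = F.normalize(torch.randn(B, d), dim=-1)  # stand-in for f_ECG(X_sig)
z_txt = F.normalize(torch.randn(B, d), dim=-1)  # stand-in for f_text(report)

logits = z_ecg @ z_txt.t() / 0.07               # pairwise similarities with a temperature
labels = torch.arange(B)                        # matched ECG/report pairs lie on the diagonal
loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
```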
2-Stage LLaVA
In this approach, we utilize general pretrained image encoders, such as CLIP or ViT, as f_ECG. Since these encoders expect image inputs, the ECG data must be adapted: either by creating a synthetic three-channel ECG signal X_sig*, or by using an image representation X_img of the ECG.
In the LLaVA-style approach, f_ECG is frozen, and a learnable projection matrix W ∈ R^(h x d) is introduced, where h is the hidden dimension of the LLM. During the second stage, the latent vector z = f_ECG(X) is projected to z′ = W z, concatenated with the embedded query Q, and fed into the LLM to generate the response S. Only W and the LLM are trained, while f_ECG remains fixed.
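A minimal sketch of this projection-and-concatenation step is shown below; module names and shapes are illustrative rather than the exact ECG-Bench implementation:

```python
import torch
import torch.nn as nn

d, h = 768, 4096                  # encoder latent dim, LLM hidden dim (illustrative)
f_ecg = nn.Identity()             # stand-in for a frozen CLIP/ViT/SigLIP encoder
W = nn.Linear(d, h, bias=False)   # learnable projection, the only trained part besides the LLM

z = f_ecg(torch.randn(1, 197, d)) # latent ECG (image) features from the frozen encoder
z_proj = W(z)                     # z' = W z, projected into the LLM embedding space
q_embeds = torch.randn(1, 32, h)  # embedded query Q from the LLM's embedding layer
llm_inputs = torch.cat([z_proj, q_embeds], dim=1)  # prepend projected ECG features to the query
# llm_inputs is then fed to the LLM (via inputs_embeds) under an autoregressive objective.
```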
2-Stage Finetune
Another approach finetunes the general, pretrained image encoder f_ECG on either X_sig* or X_img before the second stage.
2-Stage End-to-End
In this approach, we train the LLM and f_ECG jointly with only an autoregressive objective. We find this approach to be less effective, but some previous works have used it.
NOTE: In all 2-stage approaches, the second stage trains the LLM with an autoregressive objective. The latent vector z from f_ECG is projected via W to z′, concatenated with the embedded query Q, and input to the LLM to generate the response S. Note that W is also trained for all 2-stage approaches.
End-to-End Training
For the End-to-End training setting, the ECG signal X_sig is transformed to tokens X_ID (similar to text) using methods from ECG-Byte. Therefore, one can directly train the LLM for autoregressive generation since both X_ID and text are tokenized.
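Conceptually, the End-to-End input is a single flat token sequence; below is a minimal sketch with hypothetical helper functions (not the actual ECG-Bench API):

```python
def ecg_byte_encode(x_sig) -> list[int]:
    """Hypothetical stand-in: compress an ECG into token IDs (X_ID) via ECG-Byte."""
    return [901, 532, 77, 998]

def text_encode(text: str) -> list[int]:
    """Hypothetical stand-in: tokenize text with the LLM tokenizer."""
    return list(range(len(text.split())))

x_id = ecg_byte_encode(None)                                  # ECG as tokens
query_ids = text_encode("What rhythm is shown in this ECG?")  # question as tokens
answer_ids = text_encode("Sinus rhythm.")                     # target as tokens

# One flat sequence; the LLM is trained autoregressively on it, no encoder required.
input_ids = x_id + query_ids + answer_ids
```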
We also provide preprocessing pipelines for various datasets in this repository.
Datasets:
- PTB-XL, a large publicly available electrocardiography dataset
- MIMIC-IV-ECG: Diagnostic Electrocardiogram Matched Subset
- CODE-15%: a large scale annotated dataset of 12-lead ECGs
- CPSC from Classification of 12-lead ECGs: The PhysioNet/Computing in Cardiology Challenge 2020
- CSN from A large scale 12-lead electrocardiogram database for arrhythmia study
- MIMIC-IV and PTB-XL variants of ECG-QA: A Comprehensive Question Answering Dataset Combined With Electrocardiogram
- Pretrain MIMIC-IV and ECG Instruct 45K from ECG-Chat: A Large ECG-Language Model for Cardiac Disease Diagnosis
- ECG Instruct Pulse and ECG Bench Pulse from Teach Multimodal LLMs to Comprehend Electrocardiographic Images
- ECG Grounding Datasets from GEM: Empowering MLLM for Grounded ECG Understanding with Time Series and Images
We implement the following ELMs:
We also provide implementations of the following ECG-specific encoders:
- Guiding Masked Representation Learning to Capture Spatio-Temporal Relationship of Electrocardiogram
- Zero-Shot ECG Classification with Multimodal Learning and Test-time Clinical Knowledge Enhancement
- MaeFE: Masked Autoencoders Family of Electrocardiogram for Self-Supervised Pretraining and Transfer Learning
Utilizing HuggingFace, we also provide general, pretrained models to serve as ECG encoders:
We utilize the HuggingFace API to create wrappers around the following pretrained LLMs:
We also have GPT-2 and OPT LLMs; however, we do not have chat templates for them yet.
We provide the following features for training and evaluating ELMs:
- Single and distributed training.
- We implemented an LLM judge with llm-blender and utilized DPO for post-training.
- Flash Attention 2 for faster training and inference.
- A Gradio-based demo for chatting with your own trained ELM and collecting preference data.
We hope to continuously update the repository to support more features, ELMs, and datasets. Please feel free to contribute to the repository! Please carefully read the documentation below to run the pipeline. If there are any questions or bugs, please do not hesitate to reach out to wjhan{@}andrew{dot}cmu{edu} or submit an issue with the corresponding details.
All installations and experiments were completed on Ubuntu 20.04.5 LTS with NVIDIA A5000 and A6000 GPUs.
- To install Rust: `curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain=1.79.0 -y`
- Open a new terminal to set the PATH for the Rust installation.
- After opening a new terminal, check the Rust installation by running `rustc --version`.
- Create the conda virtual environment via `conda create -n ecg python=3.10.15`.
- Activate the environment: `conda activate ecg`
- Run `pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121` (make sure that executing `nvcc --version` gives version 12.1).
- `git clone https://github.com/willxxy/ECG-Bench.git`
- `cd ECG-Bench`
- `git submodule init`
- `git submodule update`
- `cd` into the `ECG-Bench/transformers` directory and run `pip install -e .`.
- `cd ../`, then `cd` into the `ECG-Bench/ecg-plot` directory and run `pip install -e .`.
- `cd ../` and run `pip install -e .`.
- To install Flash Attention 2, use the following commands: `pip cache remove flash_attn` followed by `pip install flash-attn==2.7.4.post1 --no-cache-dir`.
- To install the `llm-blender` and `trl[judges]` packages, run `pip install git+https://github.com/yuchenlin/LLM-Blender.git` and `pip install trl[judges]`.
- `cd` into `ECG-Bench/ecg_bench/rust_bpe` and execute `maturin develop --release` to compile the tokenizer.
- Run all the tests by executing `python tests/run_all_tests.py`.
- Note that we use gated models (e.g., Llama 3.2, Gemma) from HuggingFace, so you will need an API key and must log in via `huggingface-cli login` in the terminal. We also require you to log in inside the main training `*.py` file via the login function `from huggingface_hub import login`.
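For reference, a minimal sketch of that programmatic login is shown below; reading the token from an environment variable named `HF_TOKEN` is just an assumption here, not something ECG-Bench requires:

```python
import os
from huggingface_hub import login

# Log in programmatically so gated models (e.g., Llama 3.2, Gemma) can be downloaded.
login(token=os.environ["HF_TOKEN"])  # assumes you have exported HF_TOKEN beforehand
```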
NOTE: From now on, all instructions will assume you are working from the `ECG-Bench/ecg_bench` directory.
We regard base datasets as datasets that are solely used for later mapping of external datasets.
- Please download the PTB-XL dataset through this link.
- Please create a `data` folder, unzip the zip file inside the `data` folder, and rename the folder as `ptb`.
- Please download the MIMIC-IV ECG dataset through this link.
- Unzip the zip file inside the `data` directory and rename the unzipped directory as `mimic`.
- First create a `code15` folder inside the `data` directory.
- Then, inside `data/code15`, execute the following bash script to download the data and unzip it:
#!/bin/bash
for i in {0..17}; do
    echo "Downloading part ${i}..."
    wget -O "exams_part${i}.zip" "https://zenodo.org/records/4916206/files/exams_part${i}.zip?download=1"
    if [ $? -eq 0 ]; then
        echo "Successfully downloaded part ${i}"
        echo "Extracting part ${i}..."
        unzip -q "exams_part${i}.zip"
        if [ $? -eq 0 ]; then
            echo "Successfully extracted part ${i}"
            rm "exams_part${i}.zip"
        else
            echo "Error extracting part ${i}"
        fi
    else
        echo "Error downloading part ${i}"
    fi
done
echo "All downloads and extractions completed"
- Create a `csn` folder inside the `data` directory.
- Inside `data/csn`, execute the following command in the terminal:
wget https://physionet.org/static/published-projects/ecg-arrhythmia/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0.zip
- Unzip the file, and inside of `data/csn/a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0`, move all of the contents out to `data/csn`. Then you may delete the `a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0` folder.
- Create a `cpsc` folder inside the `data` directory.
- Inside `data/cpsc`, execute the following command in the terminal:
wget https://physionet.org/static/published-projects/challenge-2020/classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2.zip
- Unzip the file, and inside of `data/cpsc/classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2/training`, move the `cpsc_2018` and `cpsc_2018_extra` folders into the `data/cpsc` directory. Then delete the `classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2` folder.
Mapping datasets are datasets that are mapped to the base datasets and subsequently used for all experiments.
ECG-QA dataset curated by ECG-QA, Oh et al.
- To download the ECG-QA dataset, please execute the following command in the `data` folder:
git clone https://github.com/Jwoo5/ecg-qa.git
- We exactly follow the instructions in this section of the repository for mapping the PTB-XL and MIMIC-IV ECG datasets to the questions and answers. `cd` into `ecg-qa` and execute the following commands in the terminal to prepare the ECG-QA dataset.
- To map the ECG-QA dataset to mimic and ptb, `cd` into the `data/ecg-qa` directory and execute the following scripts, respectively.
python mapping_ptbxl_samples.py ecgqa/ptbxl \
--ptbxl-data-dir ../ptb
python mapping_mimic_iv_ecg_samples.py ecgqa/mimic-iv-ecg \
--mimic-iv-ecg-data-dir ../mimic
- After mapping the datasets, you should have an output folder in the `data/ecg-qa` folder with the mapped `paraphrased` and `template` questions and answers.
Pretrain MIMIC dataset curated by ECG-Chat, Zhao et al.
- Next, create a `data/pretrain_mimic` directory and download the `pretrain_mimic.json` file from this dropbox link.
Instruct 45k MIMIC dataset curated by ECG-Chat, Zhao et al.
- Next, create a `data/ecg_instruct_45k` directory and download the `ecg_instruct_45k.json` file from this link.
ECG Instruct Pulse dataset curated by PULSE, Liu et al.
- Create a `data/ecg_instruct_pulse` directory and download `ECGInstruct.json` from this link. Then rename it to `ecg_instruct_pulse.json`.
ECG Bench Pulse dataset curated by PULSE, Liu et al.
- The ECG Bench Pulse dataset is exclusively on HuggingFace as `.parquet` files; therefore, we utilize the `datasets` library directly to download the dataset. All you have to do is simply define `map_data` in the preprocess script as `ecg_bench_pulse`.
ECG Grounding Datasets curated by GEM, Lan et al.
- Create a `data/ecg_grounding` directory and download the `ECG_Grounding_30k.json`, `ecg-grounding-test.json`, and `grounding_train_30k.json` files from this link. A quick note is that `grounding_train_30k.json` is a subset of `ECG_Grounding_30k.json`, where `ECG_Grounding_30k.json` contains all 30k ECG grounding samples found in `grounding_train_30k.json`, with additional ECG conversational data from the ECG Instruct PULSE dataset.
- Execute the preprocessing script by `bash scripts/preprocess.sh`. We have provided default configurations for all the datasets used in our study but feel free to experiment with others! We provide some example configurations for Base, Mapping, and RAG dataset curation.
Example configurations for Base dataset curation:
python preprocess_ecg.py \
--base_data=$base_data \
--seg_len=$seg_len \
--preprocess_files
where `$base_data` is one of `ptb`, `mimic`, `code15`, `csn`, or `cpsc`, and `$seg_len` is the segment length you want to use for training.
Example configurations for Mapping dataset curation:
python preprocess_ecg.py \
--map_data=$map_data \
--seg_len=$seg_len
where `$map_data` is one of `ecg_instruct_45k`, `pretrain_mimic`, `ecg_instruct_pulse`, `ecg_bench_pulse`, `ecg-qa_mimic-iv-ecg`, `ecg-qa_ptbxl`, `ecg_grounding_pulse`, `ecg_grounding`, or `ecg_grounding_test`.
You can also mix multiple datasets together by defining `mix_data` in the preprocess script. Here is an example configuration:
python preprocess_ecg.py \
--mix_data=$mix_data
where `$mix_data` can be `ecg_instruct_45k_mapped_1250,ecg_bench_pulse_mapped_1250`.
For RAG dataset curation, an example configuration looks like so:
python preprocess_ecg.py \
--base_data=$base_data \
--seg_len=$seg_len \
--create_rag_db
We encourage you to use the `mimic` base dataset for RAG curation since it has the largest number of ECGs.
After running the RAG dataset curation, you should have a `data/$base_data/combined.index` file and a `data/$base_data/rag_metadata.json` file.
You can then load in the RAG database and test it out by running the following command:
python preprocess_ecg.py \
--base_data=$base_data \
--seg_len=$seg_len \
--load_rag_db=./data/$base_data/rag_metadata.json \
--load_rag_db_idx=./data/$base_data/combined.index
Lastly, for sampling ECGs for training ECG-Byte and getting percentiles, an example configuration looks like so:
python preprocess_ecg.py \
--base_data=$base_data \
--seg_len=$seg_len \
--sample_files \
--$sampling_method \
--sample_percentiles
where `$sampling_method` is either `random_sampling` or `stratified_sampling`. ECG-Byte utilizes stratified sampling; however, you can use random sampling as well.
We provide the script for training the first stage in 2-stage scratch and 2-stage finetune in `scripts/train_1st.sh`. Single-GPU training looks like so:
python main.py \
--data=$data \
--model=$encoder \
--device=cuda:2 \
--train=first \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--instance_normalize \
--attn_implementation=flash_attention_2 \
--log
For multi-GPU training, it looks like so:
python main.py \
--data=mimic-iv-ecg_mapped_1250 \
--model=$encoder \
--dis \
--gpus=1,2,3,4 \
--train=first \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--instance_normalize \
--attn_implementation=flash_attention_2 \
--log
After training the first stage, you can train the second stage by running `scripts/train_2nd.sh` and defining the encoder checkpoint like so:
python main.py \
--data=$data \
--model=$encoder_$llm \
--dis \
--gpus=1,2,3,4 \
--train=second \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--instance_normalize \
--system_prompt=$system_prompt.txt \
--attn_implementation=flash_attention_2 \
--encoder_checkpoint=$encoder_checkpoint \
--log
For 2-stage LLaVA, we provide the training script in `scripts/train_2nd.sh`. As LLaVA directly utilizes the pretrained, general encoder and only updates the projection head, use either CLIP, ViT, or SigLIP for the encoder and do not pass in an encoder checkpoint.
For single GPU training, it looks like so:
python main.py \
--data=$data \
--model=$encoder_$llm \
--device=cuda:2 \
--train=second \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--attn_implementation=flash_attention_2 \
--system_prompt=$system_prompt.txt \
--log
For multi-GPU training, it looks like so:
python main.py \
--data=$data \
--model=$encoder_$llm \
--dis \
--gpus=1,2,3,4 \
--train=second \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--instance_normalize \
--system_prompt=$system_prompt.txt \
--attn_implementation=flash_attention_2 \
--log
If you want to utilize the image modality (plot of ECG), you can add the following argument:
python main.py \
--data=$data \
--model=$encoder_llm \
--device=cuda:2 \
--train=second \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--attn_implementation=flash_attention_2 \
--image \
--system_prompt=$system_prompt.txt \
--log
For image augmentation, you can add the following argument:
python main.py \
--data=$data \
--model=$encoder_llm \
--device=cuda:2 \
--train=second \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--attn_implementation=flash_attention_2 \
--image \
--augment_image \
--system_prompt=$system_prompt.txt \
--log
- For non-image and text representations of ECGs (e.g., ECG signal and stacked ECG signal), the choice between ECG signal and stacked ECG signal is automatically determined by the ECG encoder used. For general, pretrained encoders, the stacked ECG signal representation is used due to their input size requirements. For ECG-specific encoders, the original ECG signal representation is used.
- For image representations of ECGs (e.g., ECG image), the image representation is automatically plotted using the `ecg-plot` package.
- For any 2-stage method, if you want to fully finetune the encoder during the second stage (i.e., 2-stage End-to-End) with only the autoregressive objective, you can add the following argument:
python main.py \
--data=$data \
--model=$encoder_llm \
--device=cuda:2 \
--train=second \
--batch_size=64 \
--seg_len=1250 \
--epochs=50 \
--attn_implementation=flash_attention_2 \
--system_prompt=$system_prompt.txt \
--train_encoder \
--log
We provide the scripts for inferencing each type of 2-stage training method in `scripts/inference_2stage.sh`. For 2-stage finetune and 2-stage scratch, make sure to define the encoder checkpoint.
Example of 2-stage finetune or scratch:
python main.py \
--data=$data \
--model=$encoder_llm \
--device=cuda:7 \
--peft \
--inference=second \
--checkpoint=$checkpoint \
--system_prompt=$system_prompt.txt \
--encoder_checkpoint=$encoder_checkpoint
Example of 2-stage LLaVA:
python main.py \
--data=$data \
--model=$encoder_llm \
--device=cuda:7 \
--peft \
--inference=second \
--checkpoint=$checkpoint \
--system_prompt=$system_prompt.txt
Make sure to add the necessary arguments for your particular use case.
- During preprocessing, there is a sampling stage where we sample N ECGs using one of two techniques: random sampling or morphological clustering-based sampling. We found that random sampling is sufficient for our use case.
- After sampling, a sampled `.txt` file should appear under the `data` folder. These sampled files will be the ECGs considered during training of ECG-Byte.
- To train ECG-Byte, simply execute `sh scripts/train_tokenizer.sh`. We provide the default configurations utilized in the paper, but feel free to change them! Here is an example of training ECG-Byte:
python train_tokenizer.py \
--num_merges=$num_merges \
--sampled_files=$sampled_files.txt \
--num_processes=$num_processes \
--train
To load in a pre-trained ECG-Byte tokenizer and verify it, you can add the following argument:
python train_tokenizer.py \
--num_merges=$num_merges \
--sampled_files=$sampled_files.txt \
--num_processes=$num_processes \
--ecg_tokenizer=$ecg_tokenizer
For training End-to-End, we provide the script in `scripts/train_end2end.sh`. We provide the basic configurations in the file but feel free to modify it. Here is an example of training End-to-End:
python main.py \
--data=$data \
--model=$llm \
--device=cuda:5 \
--ecg_tokenizer=$ecg_tokenizer \
--seg_len=1250 \
--peft \
--train=end2end \
--system_prompt=$system_prompt.txt \
--batch_size=8 \
--pad_to_max=1024 \
--epochs=1 \
--attn_implementation=flash_attention_2 \
--log
For inferencing End-to-End, we provide the script in `scripts/inference_end2end.sh`. We provide the basic configurations in the file but feel free to modify it. Here is an example of inferencing End-to-End:
python main.py \
--data=$data \
--model=$llm \
--device=cuda:7 \
--ecg_tokenizer=$ecg_tokenizer \
--peft \
--inference=end2end \
--checkpoint=$checkpoint \
--system_prompt=$system_prompt.txt \
--attn_implementation=flash_attention_2 \
--batch_size=1
For all inferencing results, we run inference over multiple seeds and provide statistical results over those seeds (e.g., 95% confidence intervals, standard deviation, etc.). To organize the results and have them printed out in a nice format, we provide the script in `scripts/org_results.sh`. We provide the basic configurations in the file but feel free to modify it. Here is an example of organizing results:
python organize_results.py \
--checkpoint=$checkpoint
For RAG inferencing, we provide the script in `scripts/inference_end2end_rag.sh`. We provide the basic configurations in the file but feel free to modify it. Here is an example of inferencing End-to-End with RAG:
python main.py \
--data=$d \
--model=$llm \
--device=cuda:3 \
--ecg_tokenizer=$ecg_tokenizer \
--seg_len=1250 \
--peft \
--inference=$inference_method \
--checkpoint=$checkpoint \
--system_prompt=$system_prompt.txt \
--batch_size=1 \
--pad_to_max=1024 \
--instance_normalize \
--attn_implementation=flash_attention_2 \
--rag \
--rag_k=$k \
--load_rag_db=$rag_db.json \
--load_rag_db_idx=$rag_db_idx.index
where `$k` is the number of retrieved ECGs, and `$rag_db.json` and `$rag_db_idx.index` are the RAG database and index, respectively.
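If you want to inspect the retrieval database outside of `main.py`, the sketch below assumes that `combined.index` is a FAISS index and that `rag_metadata.json` holds per-entry metadata; this is an assumption based on the file names, and the actual retrieval code in ECG-Bench may differ:

```python
import json
import numpy as np
import faiss  # assumes faiss-cpu or faiss-gpu is installed

index = faiss.read_index("./data/mimic/combined.index")
with open("./data/mimic/rag_metadata.json") as f:
    metadata = json.load(f)

query = np.random.randn(1, index.d).astype("float32")  # stand-in query embedding
distances, ids = index.search(query, 5)                 # retrieve the top-5 nearest ECG entries
print(distances, ids)
```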
Although it's unconventional, you can also train with RAG by adding the same arguments during training:
python main.py \
--data=$data \
--model=$llm \
--device=cuda:5 \
--ecg_tokenizer=$ecg_tokenizer \
--seg_len=1250 \
--peft \
--train=$train_method \
--system_prompt=$system_prompt.txt \
--batch_size=8 \
--pad_to_max=1024 \
--epochs=1 \
--attn_implementation=flash_attention_2 \
--log \
--rag \
--rag_k=$k \
--load_rag_db=$rag_db.json \
--load_rag_db_idx=$rag_db_idx.index
We provide a demo for chatting with your own trained ELM! To run the demo, please execute the script in `scripts/run_demo.sh`. The demo uses the same command as the inference script but with the `demo.py` file. Currently, the demo only supports End-to-End methods.
python demo.py \
--data=$data \
--model=$llm \
--device=cuda:7 \
--ecg_tokenizer=$ecg_tokenizer \
--peft \
--inference=end2end \
--checkpoint=$checkpoint \
--system_prompt=$system_prompt.txt \
--attn_implementation=flash_attention_2 \
--batch_size=1
We provide attention visualizations and tokenization analysis scripts taken from ECG-Byte. Please view the README in the official ECG-Byte repository and the scripts `scripts/token_dist.sh` and `scripts/track_encode.sh` for more details.
We encountered some issues during the development of ECG-Bench (mostly carried over from ECG-Byte) and hope to contribute to the open source community by reporting them here and adding tips where possible. If you happen to know a good solution to any of them, please do not hesitate to open an issue or pull request!
- `tqdm` bar freezing with multiprocessing - We noticed that the tqdm bar sometimes freezes when it is placed inside a multiprocessing job (especially during preprocessing). We recommend adding print statements before and after the main operations inside the tqdm loop to ensure the operations are being executed. There is a thread about this issue in the tqdm repository; please feel free to look at it!
- Utilizing inputs_embeds for generation - We noticed that utilizing `inputs_embeds` as the primary input to the model for generation is quite unstable (e.g., example1 from HF, example2 from stackoverflow, example3 from vllm but related, example4 from HF). When we tried generating via only `inputs_embeds`, the model failed to generate anything coherent (i.e., mostly empty strings). Our current workaround is passing in both `input_ids` and `inputs_embeds` as inputs for generation (see the sketch after this list). The reasoning behind this comes from the GenerationMixin code and this thread. From the code, it seems that the model creates an empty input_ids tensor of shape (batch_size, 0) and uses the embeddings only for the first forward pass. However, this can be unstable because there is no explicit token mapping for the embeddings, making it harder for the model to maintain coherence between the embedded representation and subsequent token generation. The proper solution would be to create better `inputs_embeds` from the get-go. However, we wanted to add some guidance to the generation, so we provide embeddings for the initial forward pass while also passing input_ids that explicitly map to those embeddings, providing a more stable foundation for generation. This is not "true" generation using only `inputs_embeds`; therefore, we believe this reinforces our method of representing ECGs even more.
- HuggingFace API key not being recognized - We also noticed that the main training script sometimes crashes because the HuggingFace API key is not recognized. The current workaround is simply to relogin using your own personal API key.
- NaN values during preprocessing - We noticed that the MIMIC-IV ECG dataset has many NaN values during preprocessing, so we work around this by skipping them.
- Crash during ECG sampling - When sampling ECGs using morphological clustering during preprocessing, we currently use the following configurations for the number of threads:
os.environ['OPENBLAS_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'
os.environ['VECLIB_MAXIMUM_THREADS'] = '4'
os.environ['NUMEXPR_NUM_THREADS'] = '4'
We noticed that on some machines under computational constraints this number is too high when launching the PCA analysis, thus resulting in a crash. In this case, simply reduce the maximum number of threads for each os.environ entry to either 1 or 2. Reducing this number should solve the problem; however, if you continue to run into crashes, please feel free to report an issue!
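As referenced in the `inputs_embeds` bullet above, here is a minimal sketch of the workaround of passing both `input_ids` and `inputs_embeds` to `generate`; the model name is illustrative and behavior can vary across transformers versions:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative small model; ECG-Bench wraps larger (often gated) LLMs
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Describe the ECG:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# In ECG-Bench, these embeddings would additionally carry projected ECG features;
# here we simply embed the prompt tokens for illustration.
inputs_embeds = model.get_input_embeddings()(input_ids)

# Passing both gives the model explicit tokens that map to the embeddings,
# which we found more stable than passing inputs_embeds alone.
out = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```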
We welcome contributions to the repository! Please feel free to open an issue or pull request for any bugs or features you would like to add. We are always looking for new ECG datasets to benchmark our methods on. If you have any recommendations, please let us know! Also, a good place to start is by looking at the TODO section.
For most processes, we have a `--dev` flag to run at a smaller scale and add some verbosity for debugging. Feel free to add this flag when needed!
We thank the following people for their contributions to the repository:
This is a list of TODOs for the repository. If you are interested in contributing, please feel free to look at the list and open a PR! We are always looking for ways to add more documentation, examples, tests, and workflows for the codebase. Lastly, general improvements to the codebase are always welcome!
- Add default chat templates for LLMs without chat templates (e.g., GPT 2, OPT).
- Add GEM model
- Add ECG-Expert-QA dataset
- Add ECG-Grounding Dataset
- Provide HuggingFace dataset and model card push ability.
- Create an offline demo for ELMs with unified preference collection.
- Retrieval-Augmented Generation
- Make RAG searching faster.
- Make training with RAG faster.
- Add encoder-free VLMs such as Fuyu-8B, Vision as LoRA, and/or Unveiling Encoder-Free Vision-Language Models for ECGs. This could be extended for all training methods.
- Addition for new input representation: ECG features
- Reasoning ability for ELMs (akin to OpenAI o1, Deepseek R1, etc.).
- Curate higher quality instruction tuning and reasoning datasets for ELMs.
- Expand upon current naive distributed training setting to include more efficient and explicit distributed training strategies (i.e., data, tensor, context, pipeline, and expert parallelism as noted in here).
- Add option for data mixing.
This work is done in collaboration with the Mario Lemieux Center for Heart Rhythm Care at Allegheny General Hospital.
We thank Chaojing Duan, Michael A. Rosenberg, Emerson Liu, Ding Zhao, Hyoeun Kang, Wenhao Ding, Haohong Lin, Shiqi Liu, Xiaoyu (Simon) Song, Atharva Mhaskar, Zhepeng Cen, Yihang Yao, and Dylan Leong for their helpful discussions, feedback, and support in developing ECG-Bench.
We thank the authors of ECG-Byte, MERL, ST-MEM, ECG-QA, ECG-Chat, PULSE, and GEM for their code and publicly released datasets.
Lastly, we thank HuggingFace for providing the APIs for the models.
This repository contains code licensed under the MIT License, except for the following `.py` files in the `ecg_bench/models/encoder` directory: `st_mem.py`, `mlae.py`, `mtae.py`. These files are licensed under the Creative Commons Attribution-NonCommercial 4.0 International License. Please view the original licenses in their respective repositories for more details.
If this codebase or work has helped you please cite the following:
@misc{han2024ecgbytetokenizerendtoendgenerative,
title={ECG-Byte: A Tokenizer for End-to-End Generative Electrocardiogram Language Modeling},
author={William Han and Chaojing Duan and Michael A. Rosenberg and Emerson Liu and Ding Zhao},
year={2024},
eprint={2412.14373},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2412.14373},
}