This repository provides the code for MedViLL (Medical Vision Language Learner).
Our proposed architecture, MedViLL, is a single BERT-based model that learns a unified contextualized vision-language (VL) representation for both Vision-Language Understanding (VLU) and Vision-Language Generation (VLG). MedViLL performs pre-training with a CNN-based visual encoder and a cross-modal Transformer for VL joint representation learning. After pre-training, the model can easily be applied to VLU and VLG tasks with task-specific fine-tuning. Please refer to our paper "Multi-modal Understanding and Generation for Medical Images and Text via Vision-Language Pre-Training" for more details.
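For intuition about this kind of single-stream architecture, here is a minimal sketch, assuming a ResNet-50 backbone from torchvision, bert-base-uncased from the Hugging Face transformers library, and simple concatenation of projected visual features with word embeddings. It is only a rough sketch of the general idea, not the actual MedViLL implementation in this repository.

```python
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import BertModel

class VLEncoderSketch(nn.Module):
    """Single-stream vision-language encoder sketch (illustrative, not the MedViLL code)."""

    def __init__(self):
        super().__init__()
        cnn = models.resnet50(pretrained=False)
        # Keep the convolutional trunk so we get a spatial grid of visual features.
        self.visual_backbone = nn.Sequential(*list(cnn.children())[:-2])
        self.visual_proj = nn.Linear(2048, 768)  # project CNN features to BERT's hidden size
        self.bert = BertModel.from_pretrained("bert-base-uncased")

    def forward(self, image, input_ids, attention_mask):
        # (B, 2048, 7, 7) feature map -> (B, 49, 768) sequence of visual "tokens"
        feats = self.visual_backbone(image).flatten(2).transpose(1, 2)
        vis_embeds = self.visual_proj(feats)
        # Word embeddings only; BERT adds position/segment embeddings over the joint sequence.
        txt_embeds = self.bert.embeddings.word_embeddings(input_ids)
        joint_embeds = torch.cat([vis_embeds, txt_embeds], dim=1)
        vis_mask = torch.ones(vis_embeds.shape[:2],
                              dtype=attention_mask.dtype, device=image.device)
        joint_mask = torch.cat([vis_mask, attention_mask], dim=1)
        out = self.bert(inputs_embeds=joint_embeds, attention_mask=joint_mask)
        return out.last_hidden_state  # contextualized VL representation
```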
We provide five versions of BERT-based pre-trained weights with different types of self-attention masks. Pre-training for the joint embedding was built on the BERT-Base architecture (12 hidden layers, 12 attention heads, 768 hidden size), and training details are described in our paper. The currently available versions of pre-trained weights are as follows (a toy sketch of these mask patterns follows the list):
- MedViLL: BERT-Base model with a bidirectional auto-regressive attention mask.
- Bi & Seq2Seq: BERT-Base model with a Seq2Seq attention mask (75%) and a bidirectional attention mask (25%) in every mini-batch.
- Bidirectional: BERT-Base model with a bidirectional attention mask.
- Seq2Seq: BERT-Base model with a Seq2Seq attention mask.
- Non-cross: BERT-Base model with a non-cross-modality attention mask.
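As a rough picture of how these variants differ, the toy function below builds 2D attention masks over a concatenated [visual tokens; text tokens] sequence for the bidirectional, Seq2Seq, and non-cross cases. It is only an illustration; the exact masks (including the bidirectional auto-regressive variant used by MedViLL) are defined in the paper and the pre-training code.

```python
import torch

def toy_attention_mask(num_visual, num_text, mode="bidirectional"):
    """Toy 2D self-attention masks over a [visual; text] sequence (1 = may attend)."""
    n = num_visual + num_text
    mask = torch.ones(n, n)
    if mode == "bidirectional":
        pass                                               # every token attends everywhere
    elif mode == "seq2seq":
        # Text positions attend causally to text; visual tokens never peek at text.
        mask[num_visual:, num_visual:] = torch.tril(torch.ones(num_text, num_text))
        mask[:num_visual, num_visual:] = 0
    elif mode == "non_cross":
        # Each modality attends only within itself.
        mask = torch.zeros(n, n)
        mask[:num_visual, :num_visual] = 1
        mask[num_visual:, num_visual:] = 1
    else:
        raise ValueError(f"unknown mode: {mode}")
    return mask

print(toy_attention_mask(3, 4, mode="seq2seq"))
```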
We provide pre-processed versions of multiple datasets, one per task, as follows. Download each dataset to the path /data/[dataset] (a quick check of the resulting layout is sketched after the list).
- MIMIC-CXR (2.27 GB): 91,685 unique studies of AP-view images paired with their associated reports.
- OPEN-I (74.1 MB): 3,547 unique studies of AP and PA image-report pairs from the official Open-I dataset.
- VQA-RAD (402 MB): 3,515 question-answer pairs on 315 images (104 head CTs or MRIs, 107 chest X-rays, and 104 abdominal CTs).
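To confirm the layout after downloading, a quick check along these lines can be used; the folder names here are placeholders, so adjust them to whatever the extracted archives are actually called.

```python
from pathlib import Path

DATA_ROOT = Path("data")                 # the /data directory described above (repo-relative assumption)
expected = ["mimic", "openi", "vqa_rad"]  # placeholder dataset folder names

for name in expected:
    path = DATA_ROOT / name
    print(f"{path}: {'found' if path.is_dir() else 'missing'}")
```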
We also provide JSON files with the image/report paths used for validation in the retrieval task. Download each file to the path /data/[dataset]; a minimal way to inspect these files is sketched after the lists.

Image-to-report retrieval: 1) MIMIC valid, 2) MIMIC test, 3) OpenI test

Report-to-image retrieval: 1) MIMIC valid, 2) MIMIC test, 3) OpenI test
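A minimal way to peek into one of these retrieval JSON files is shown below. The file name is a placeholder and the entry structure is whatever the downloaded file defines, so this just prints a couple of entries for inspection.

```python
import json
from pathlib import Path

# Placeholder file name -- point this at the retrieval JSON you downloaded to /data/[dataset].
json_path = Path("data") / "mimic_retrieval_test.json"

with open(json_path) as f:
    entries = json.load(f)

print(type(entries).__name__, "with", len(entries), "top-level entries")
# Show a couple of entries to see how image paths and report texts are paired.
for entry in (entries[:2] if isinstance(entries, list) else list(entries.items())[:2]):
    print(entry)
```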
The sections below describe virtual environment installation and the fine-tuning process of MedViLL, based on PyTorch 1.7 and Python 3.8. To fine-tune MedViLL, you first need to download its pre-trained weights. After downloading the pre-trained weights, use medvill.yaml to create the conda-based virtual environment as follows:
$ git clone https://github.com/SuperSupermoon/MedViLL.git
$ cd MedViLL; conda env create --file medvill.yaml
Note that all fine-tuning experiments were conducted on 8 GeForce RTX 3090 GPUs, each with 24 GB of VRAM.
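After activating the environment (for example, conda activate medvill, assuming that is the name defined in medvill.yaml), a quick sanity check confirms that PyTorch can see the GPUs:

```python
import torch

print("PyTorch version:", torch.__version__)       # expected to be in the 1.7.x range
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Device 0:", torch.cuda.get_device_name(0))
```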
Unzip the MIMIC, Open-I, and VQA-RAD tar.gz files:
$ cd MedViLL; tar -zxvf [file_name.tar.gz]
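The same extraction can also be scripted with Python's standard tarfile module; the glob pattern below assumes the downloaded archives sit in the repository root.

```python
import glob
import tarfile

# Extract every downloaded archive in the current directory (equivalent to tar -zxvf).
for archive in glob.glob("*.tar.gz"):
    print("Extracting", archive)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall()
```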
- Pre-training Example:
$ cd MedViLL
$ python main.py
- Diagnosis Classification Example:
$ cd MedViLL/downstream_task/classification
$ python cls.py
- Image-Report Retrieval Example:
$ cd MedViLL/downstream_task/retrieval
$ python retrieval.py
- Medical Visual Question Answering Example:
$ python -m torch.distributed.launch --nproc_per_node=1 --master_port 9872 --use_env downstream_task/report_generation_and_vqa/finetune.py --model_recover_path pt_model/model.50.bin --tasks vqa --s2s_prob 0 --bi_prob 1 --mask_prob 0 --vqa_rad chest --vqa_eval
- Report Generation Example:
$ cd MedViLL/downstream_task/report_generation_and_vqa
$ python finetune.py --tasks report_generation --mask_prob 0.15 --s2s_prob 1 --bi_prob 0
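To make the --mask_prob flag concrete: in BERT-style masked language modelling, each text token is selected for masking with that probability. The repository's actual masking code also handles special tokens and replacement rules, so treat the snippet below as a toy illustration only.

```python
import torch

def sample_mask_positions(num_tokens, mask_prob=0.15, seed=0):
    """Toy illustration: each position is independently chosen for masking with mask_prob."""
    torch.manual_seed(seed)
    return torch.rand(num_tokens) < mask_prob

positions = sample_mask_positions(20, mask_prob=0.15)
print(positions)                                    # boolean mask over 20 token positions
print(int(positions.sum()), "of 20 positions selected for masking")
```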