Diaformer is an efficient model for automatic diagnosis via symptom sequence generation. It takes a sequence of symptoms as input and predicts the symptoms to inquire about next by generating them as a sequence.
Figure 1: Illustration of symptom attention framework.
Our experiments were conducted with Python 3.8 and PyTorch 1.8.0. The main requirements are:
- transformers==2.1.1
- torch
- numpy
- tqdm
- sklearn
- keras
- boto3
In the root directory, run the following command to install the required libraries.
pip install -r requirement.txt
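After installation, a quick check that the listed packages can be imported may save a failed training run later. The following is a minimal sketch, not part of the repository: the `find_missing` helper is hypothetical, and the module names are assumed to match the requirement names listed above.

```python
import importlib.util

# Import names assumed to match the requirement names listed above.
REQUIRED_MODULES = ["transformers", "torch", "numpy", "tqdm", "sklearn", "keras", "boto3"]


def find_missing(modules):
    """Return the names in `modules` that cannot be located for import."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

This only checks that each module can be found; it does not verify the pinned versions.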
- Download data
Download the datasets, decompress them, and put them in the corresponding directories under data/. For example, put the Synthetic Dataset under data/synthetic_dataset. The datasets can be downloaded from the following links:
- Build data
Switch to the corresponding dataset directory and run preprocess.py to preprocess the data and generate a vocabulary of symptoms.
- Train and test
Train and test the models with the following commands.
Diaformer
# Train and test on Diaformer

# Run on MuZhi dataset
python Diaformer.py --dataset_path data/muzhi_dataset --batch_size 16 --lr 5e-5 --min_probability 0.009 --max_turn 20 --start_test 10

# Run on Dxy dataset
python Diaformer.py --dataset_path data/dxy_dataset --batch_size 16 --lr 5e-5 --min_probability 0.012 --max_turn 20 --start_test 10

# Run on Synthetic dataset
python Diaformer.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10
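The three commands above differ only in the dataset path and the `--min_probability` value, so they can be generated from a small table. The sketch below is one way to do that; the `build_command` and `run_all` helpers are hypothetical and not part of the repository, and the script is assumed to be run from the repository root.

```python
import subprocess
import sys

# Per-dataset --min_probability values, copied from the commands above.
MIN_PROBABILITY = {
    "muzhi_dataset": 0.009,
    "dxy_dataset": 0.012,
    "synthetic_dataset": 0.01,
}


def build_command(dataset, script="Diaformer.py"):
    """Build the argument list for one train-and-test run on `dataset`."""
    return [
        sys.executable, script,
        "--dataset_path", f"data/{dataset}",
        "--batch_size", "16",
        "--lr", "5e-5",
        "--min_probability", str(MIN_PROBABILITY[dataset]),
        "--max_turn", "20",
        "--start_test", "10",
    ]


def run_all(dry_run=True):
    """Print each command; run it as a subprocess when dry_run is False."""
    for dataset in MIN_PROBABILITY:
        command = build_command(dataset)
        print(" ".join(command))
        if not dry_run:
            subprocess.run(command, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

With `dry_run=True` the script only prints the commands, which makes it easy to compare them against the ones listed above before launching anything.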
Diaformer_GPT2
# Train and test on GPT2 variant of Diaformer
python GPT2_variant.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10
Diaformer_UniLM
# Train and test on UniLM variant of Diaformer
python UniLM_variant.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10
Ablation study
# run ablation study

# w/o Sequence Shuffle
python Diaformer.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10 --no_sequence_shuffle

# w/o Synchronous Learning
python Diaformer.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10 --no_synchronous_learning

# w/o Repeated Sequence
python Diaformer.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10 --no_repeated_sequence
Generative inference
# save the model
python Diaformer.py --dataset_path data/synthetic_dataset --batch_size 16 --lr 5e-5 --min_probability 0.01 --max_turn 20 --start_test 10 --model_output_path models

# use the trained model to output the results
python predict.py --dataset_path data/synthetic_dataset --min_probability 0.01 --max_turn 20 --pretrained_model models/ --result_output_path results.json
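The `--min_probability` and `--max_turn` flags suggest that inference asks about one symptom per turn and stops once no candidate is probable enough or the turn budget is used up. The sketch below illustrates only that stopping rule under this assumed reading of the flags. It is not the code in predict.py: `inquire`, `score_candidates`, and `toy_scores` are hypothetical names, a fixed lookup table stands in for the trained model, and patient answers are ignored for simplicity.

```python
def inquire(known, score_candidates, min_probability=0.01, max_turn=20):
    """Greedy inquiry loop: ask about the most probable unseen symptom each turn.

    `known` is the list of symptoms observed so far. `score_candidates`
    (hypothetical) maps that list to a {symptom: probability} dict.
    """
    known = list(known)  # copy so the caller's list is not modified
    asked = []
    for _ in range(max_turn):
        # Only consider symptoms that have not been seen or asked yet.
        scores = {s: p for s, p in score_candidates(known).items() if s not in known}
        if not scores:
            break
        symptom = max(scores, key=scores.get)
        if scores[symptom] < min_probability:
            break  # no remaining symptom is likely enough to ask about
        asked.append(symptom)
        known.append(symptom)
    return asked


def toy_scores(known):
    """Toy stand-in for the model: fixed probabilities, independent of the input."""
    return {"cough": 0.6, "fever": 0.3, "rash": 0.005}


if __name__ == "__main__":
    print(inquire(["headache"], toy_scores))  # prints ['cough', 'fever']
```

In this toy run, "rash" is never asked because its probability of 0.005 falls below the 0.01 threshold; lowering `min_probability` or `max_turn` changes how many inquiries are made.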