After fine-tuning your own model, you need to convert the resulting checkpoint to a format supported by PyTorch. The Transformers library1 provides a conversion script. Since we are not using bert-base-uncased as the pretrained model, we need to change:
model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
with:
model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")[0]))
Similarly, we need to change the code for DPRQuestionEncoder to:
model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")[0]))
and for DPRReader to:
model = DPRReader(DPRConfig(**BertConfig.get_config_dict("gdario/biobert_bioasq")[0]))
The convert_dpr_original_checkpoint_to_pytorch.py script has already been modified in the above-mentioned way. With the edited script in place, we run it three times: first to convert the ctx_encoder:
python convert_dpr_original_checkpoint_to_pytorch.py --type ctx_encoder --src pipeline1_baseline/cp_models/dpr_biencoder.29 --dest SleepQA/models/pytorch/ctx_encoder
then question_encoder:
python convert_dpr_original_checkpoint_to_pytorch.py --type question_encoder --src pipeline1_baseline/cp_models/dpr_biencoder.29 --dest SleepQA/models/pytorch/question_encoder
and finally reader:
python convert_dpr_original_checkpoint_to_pytorch.py --type reader --src pipeline1_baseline/cp_models/dpr_extractive_reader.1.250 --dest SleepQA/models/pytorch/reader
After running the three conversions above, we need to download the tokenizer_config.json and vocab.txt files from the respective Hugging Face repositories: PubMedBERT2 for the ctx_encoder and question_encoder, and BioBERT BioASQ3 for the reader.
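These files can also be fetched programmatically. The sketch below builds the standard https://huggingface.co/&lt;repo&gt;/resolve/main/&lt;file&gt; download URLs and retrieves both files for each model using only the Python standard library; the destination directories mirror the --dest paths used above, and the function names are illustrative, not part of the repository.

```python
import os
import urllib.request

# Repositories holding the tokenizer files for each converted model.
REPOS = {
    "ctx_encoder": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "question_encoder": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "reader": "gdario/biobert_bioasq",
}
FILES = ["tokenizer_config.json", "vocab.txt"]

def file_url(repo_id, filename):
    """Build the raw-file download URL used by the Hugging Face Hub."""
    return f"https://huggingface.co/{repo_id}/resolve/main/{filename}"

def fetch_tokenizer_files(dest_root="SleepQA/models/pytorch"):
    """Download tokenizer files next to each converted model."""
    for model, repo_id in REPOS.items():
        dest_dir = os.path.join(dest_root, model)
        os.makedirs(dest_dir, exist_ok=True)
        for filename in FILES:
            urllib.request.urlretrieve(file_url(repo_id, filename),
                                       os.path.join(dest_dir, filename))

if __name__ == "__main__":
    fetch_tokenizer_files()
```

Alternatively, the files can be downloaded manually from the model pages on the Hugging Face Hub and placed in the same directories as the converted encoders.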
The qa_system.py script allows us to use the fine-tuned models in a QA pipeline:
- the generate_dense_encodings function generates encodings for the text corpus,
- the dense_retriever function retrieves the most relevant passage for a given question, and
- the extractive_reader function extracts the most relevant answer span for a given question.
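Conceptually, the retrieval step reduces to a maximum inner product search over the corpus encodings: the passage whose dense encoding scores highest against the question encoding is returned. A minimal, dependency-free sketch of this scoring (with toy vectors standing in for the DPR encoder outputs; the function names are illustrative and not taken from qa_system.py):

```python
def dot(u, v):
    """Inner product between two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def retrieve_best_passage(question_vec, passage_vecs):
    """Return the index of the passage whose encoding has the highest
    inner product with the question encoding (DPR's relevance score)."""
    scores = [dot(question_vec, p) for p in passage_vecs]
    return max(range(len(scores)), key=scores.__getitem__)

# Toy 3-dimensional encodings standing in for DPR outputs.
passages = [
    [0.1, 0.9, 0.0],   # passage 0
    [0.8, 0.1, 0.1],   # passage 1
    [0.2, 0.2, 0.9],   # passage 2
]
question = [0.9, 0.0, 0.1]  # most similar to passage 1

best = retrieve_best_passage(question, passages)  # -> 1
```

In the actual pipeline, the passage vectors come from generate_dense_encodings and the question vector from the fine-tuned question encoder; at corpus scale, this search is typically done with an index such as FAISS rather than a Python loop.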