This repository contains the source code and trained models for a large-scale pretrained dialogue response generation model. See our project page for more details.
The repository is based on huggingface pytorch-transformer and OpenAI GPT-2, and contains a data extraction script, model training code, and pretrained small (117M), medium (345M) and large (762M) model checkpoints.
The model is trained on 147M multi-turn dialogues from Reddit discussion threads. The largest model can be trained in several hours on a machine with 8 V100 GPUs (this setup is not required), using distributed training and the FP16 option.
The included scripts can be used to reproduce the results of the DSTC-7 grounded dialogue generation challenge and of a 6k multi-reference dataset created from Reddit data.
Project webpage: https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/
This github repository will be updated soon. Please stay tuned.
This code can run on CPU, but it will be slow. We recommend using GPUs to train and fine-tune all models. There is no minimum number of GPUs. However, when using distributed training across multiple GPUs, the speed-up is sub-linear in the number of GPUs. To simulate the same batch size with fewer GPUs, use a larger gradient_accumulation_steps in model training.
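The batch-size advice above can be sketched as simple arithmetic. This is a minimal illustration that assumes the effective batch size is the per-GPU batch size times the number of GPUs times gradient_accumulation_steps, with the per-GPU batch held fixed; the helper name and the example numbers are hypothetical and not taken from the training script.

```python
def accumulation_steps_for(target_gpus, target_steps, available_gpus):
    """Return accumulation steps that keep the effective batch size unchanged.

    Assumption (not confirmed by the training script):
    effective_batch = per_gpu_batch * n_gpu * gradient_accumulation_steps,
    with the per-GPU batch size held fixed.
    """
    total = target_gpus * target_steps
    if total % available_gpus != 0:
        raise ValueError("cannot match the batch size exactly with this GPU count")
    return total // available_gpus


# Example: a recipe tuned for 8 GPUs with 2 accumulation steps, run on 2 GPUs.
steps = accumulation_steps_for(target_gpus=8, target_steps=2, available_gpus=2)
print(steps)  # 8
```

Under this assumption, halving the number of GPUs means doubling gradient_accumulation_steps.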
The 117M and 345M models can be loaded on a single GPU with 12G memory. The 762M model requires a single GPU with more than 16G memory for efficient training. Training speed on benchmark data with 50M training instances on V100 GPUs:
n_gpu | epoch time (min) | token/sec |
---|---|---|
1 | 158 | 25466 |
2 | 96 | 41861 |
4 | 73 | 54994 |
8 | 65 | 63612 |
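To make the sub-linear scaling concrete, the short sketch below derives the speed-up and per-GPU efficiency from the token/sec column of the table above. The numbers are copied from the table; the variable names are illustrative.

```python
# (n_gpu, epoch time in minutes, tokens/sec), copied from the table above.
benchmark = [(1, 158, 25466), (2, 96, 41861), (4, 73, 54994), (8, 65, 63612)]

base_tokens_per_sec = benchmark[0][2]
scaling = {}
for n_gpu, _, tokens_per_sec in benchmark:
    speedup = tokens_per_sec / base_tokens_per_sec
    # Store (speed-up over one GPU, efficiency per GPU).
    scaling[n_gpu] = (speedup, speedup / n_gpu)

for n_gpu, (speedup, efficiency) in scaling.items():
    print(f"{n_gpu} GPU(s): {speedup:.2f}x speed-up, {efficiency:.0%} efficiency")
```

On these numbers, throughput grows about 2.5x from one to eight GPUs, which corresponds to roughly 31% per-GPU efficiency.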
Fine-tuning from our pretrained model on a new dataset typically requires 1-2 epochs.
We created a demo script, demo.py, to ease deployment of this system. demo.py contains a pipeline of model downloading, data extraction, data preprocessing and model training over a dummy dataset, all within one command line.
Use the command lines below to clone the repository, install the requirements and load the Conda environment (note that CUDA 10 is required):
sudo apt-get install -y make wget gzip bzip2 xz-utils zstd
git clone https://github.com/microsoft/DialoGPT.git
cd DialoGPT
conda env create -f LSP-linux.yml -n LSP
conda activate LSP
If you run this on an architecture other than Linux, use LSP-generic.yml instead of LSP-linux.yml. Note that the generic one is not tested on all platforms, so its stability cannot be guaranteed.
To use fp16 training, install apex using the commands below:
conda activate LSP
git clone https://github.com/NVIDIA/apex
cd apex
git reset --hard 3d01e4a0a188cc8df54bc6e44cf5eb40ff6b4cc5
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
python3.6 demo.py
To start, first install Docker and Nvidia-docker from their official repos. The image environment for running the code can be loaded as below:
Nvidia-docker v2.*
$ docker run --gpus all --ipc=host --rm -it -v $PWD:/workspace --network=host icaruszyz/large-scale-training:dialogpt bash
Nvidia-docker v1.*
$ nvidia-docker run --rm -it -v $PWD:/workspace --network=host icaruszyz/large-scale-training:dialogpt bash
Inside the docker container, run
python demo.py
This section explains all components in demo.py.
Before running demo.py, you can set DATA_FOLDER (default value ./models) in demo.py to the location where you want to download all the data and pretrained/fine-tuned models. Then simply run
python demo.py
to
- automatically download models and data,
- prepare raw data into db that is ready to use for the program,
- generate a training script.
Note that by default demo.py uses dummy data. Specify the Reddit training data with the --data option. Three options are available: dummy, small and full.
python demo.py --data small
python demo.py --data full
The small Reddit data is around 140MB and the full Reddit data is more than 30GB. You can prepare a cup of coffee when processing with the full Reddit data because it takes a long time!
The pretrained and fine-tuned models are available on Azure Blob Storage here.
Please run/see demo.py for more details about how to download/use those models.
First, use prepare4db.sh to convert a tsv data file into the format that the following script can recognize.
The training data then needs to be processed into a database file with the command line below:
python prepro.py --corpus $DATA_PATH
The training script can be used in single GPU or multiple GPU settings (distributed training across multiple GPUs within a single node):
python ./LSP_train.py # Single GPU training
python -m torch.distributed.launch --nproc_per_node=8 ./LSP_train.py # Training on 8 GPUs
The training script accepts several arguments to tweak the training:
Argument | Type | Default value | Description |
---|---|---|---|
max_seq_length | int | 128 | Maximum number of tokens for each training instance. |
train_input_file | str | "" | Path of the training dataset in a .db format |
eval_input_file | str | "" | Path of the validation set in a tsv format |
continue_from | int | 0 | Resuming the training after a specified number of steps |
fp16 | boolean | True | Whether to use 16-bits floating point for model training. |
train_batch_size | int | 4 | Batch size for training |
valid_batch_size | int | 4 | Batch size for validation |
gradient_accumulation_steps | int | 2 | Accumulate gradients on several steps |
learning_rate | float | 1e-5 | Learning rate |
lr_schedule | str | noam | Learning rate schedule can be chosen from [noam, noamwd, BERT, None] |
num_optim_steps | int | 1000000 | Number of training optimization steps |
no_token_id | boolean | True | If set True, using all-zeros token-type embedding. |
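For readers who want to script around these options, the argument table above could be mirrored with argparse roughly as follows. This is a sketch only: the names and defaults come from the table, but the str2bool helper and the way booleans are parsed are assumptions, and the actual training script may declare its options differently.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans (hypothetical helper, not from the repo)."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Build a parser whose names and defaults follow the argument table."""
    parser = argparse.ArgumentParser(description="Sketch of the documented training options")
    parser.add_argument("--max_seq_length", type=int, default=128)
    parser.add_argument("--train_input_file", type=str, default="")
    parser.add_argument("--eval_input_file", type=str, default="")
    parser.add_argument("--continue_from", type=int, default=0)
    # Booleans accept an optional value; a bare flag means True (assumed behavior).
    parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, default=True)
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--valid_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--lr_schedule", type=str, default="noam",
                        choices=["noam", "noamwd", "BERT", "None"])
    parser.add_argument("--num_optim_steps", type=int, default=1000000)
    parser.add_argument("--no_token_id", type=str2bool, nargs="?", const=True, default=True)
    return parser


# Parsing an empty argument list yields the documented defaults.
args = build_parser().parse_args([])
print(args.max_seq_length, args.lr_schedule, args.learning_rate)
```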
During training, two log files are updated: train_log.txt and eval_log.txt contain the model loss, perplexity and training speed (tokens/sec) statistics for the training and dev sets.
The log files and saved model checkpoints can be found in ./models/output_model
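When reading the logs, it can help to recall how loss and perplexity usually relate. The sketch below assumes the logged loss is the mean per-token cross-entropy in nats, which this README does not state explicitly; the function name is illustrative.

```python
import math


def perplexity(mean_token_loss):
    """Convert a mean per-token cross-entropy (in nats) into perplexity."""
    return math.exp(mean_token_loss)


print(round(perplexity(2.5), 2))  # 12.18
```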
We note that even with a properly filtered Reddit dataset, our model can sometimes still generate moderately toxic/inappropriate responses. For this reason, we are unable to provide the decoding script at this time (the live demo and decoding script access are by invitation only for now). We are still working on a controlled decoding method to prevent this system from toxic generation. Please stay tuned.
We release 6 fine-tuned models which can be further fine-tuned on low-resource user-customized datasets. The total parameters in these models range from 117M to 762M, in accord with OpenAI GPT-2 model sizes.
Model | Download |
---|---|
DialoGPT 762M model | link |
DialoGPT 345M model | link |
DialoGPT 117M model | link |
The model files can be loaded exactly as the GPT-2 model checkpoints from the Huggingface pytorch-transformer repository.
Our model achieved state-of-the-art results in the DSTC-7 Challenge response generation task.
Experiment | NIST1 | NIST2 | NIST3 | NIST4 | BLEU1 | BLEU2 | BLEU3 | BLEU4 | METEOR | entropy1 | entropy2 | entropy3 | entropy4 | diversity1 | diversity2 | avg_len |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Human | 2.4237 | 2.6244 | 2.6472 | 2.65 | 0.3408 | 0.1235 | 0.0572 | 0.0313 | 0.0831 | 6.5893 | 9.7423 | 10.4101 | 10.4450 | 0.1666 | 0.6701 | 18.7568 |
DSTC-7 Winner | 2.3408 | 2.5102 | 2.522 | 2.523 | 0.4122 | 0.1435 | 0.0501 | 0.0183 | 0.0807 | 5.3832 | 7.6065 | 8.5304 | 9.0298 | 0.1089 | 0.3249 | 15.1327 |
DialoGPT | 2.5863 | 2.804 | 2.823 | 2.8246 | 0.3927 | 0.1416 | 0.0555 | 0.0231 | 0.0851 | 5.5791 | 8.5109 | 9.6872 | 10.0765 | 0.0913 | 0.3973 | 16.9484 |
DialoGPT(beam search) | 2.5943 | 2.9163 | 2.9624 | 2.9681 | 0.4238 | 0.1918 | 0.1027 | 0.0605 | 0.0929 | 6.0815 | 8.7379 | 9.4037 | 9.5697 | 0.1573 | 0.5103 | 14.1603 |
Note that superior automatic evaluation scores compared to human responses do not necessarily imply that our model achieves human parity. Please check out our paper for a more detailed analysis.
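As a reading aid for the entropy and diversity columns, the sketch below shows one common way such scores are computed. It assumes diversityN is the Distinct-N ratio (unique n-grams over total n-grams) and entropyN is the Shannon entropy of the n-gram distribution, using whitespace tokenization and natural logarithms; the official evaluation script may tokenize, normalize or scale differently, and the function names here are illustrative.

```python
import math
from collections import Counter


def ngram_counts(responses, n):
    """Count n-grams over whitespace-tokenized responses."""
    counts = Counter()
    for response in responses:
        tokens = response.split()
        counts.update(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
    return counts


def distinct_n(responses, n):
    """Ratio of unique n-grams to total n-grams (assumed meaning of diversityN)."""
    counts = ngram_counts(responses, n)
    total = sum(counts.values())
    return len(counts) / total if total else 0.0


def entropy_n(responses, n):
    """Shannon entropy of the n-gram distribution (assumed meaning of entropyN)."""
    counts = ngram_counts(responses, n)
    total = sum(counts.values())
    if not total:
        return 0.0
    return -sum(c / total * math.log(c / total) for c in counts.values())


sample = ["a b a b", "a c"]
print(distinct_n(sample, 1))  # 0.5
print(distinct_n(sample, 2))  # 0.75
```

Higher values of either score indicate less repetitive output, which is why they are reported alongside overlap metrics such as BLEU and NIST.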
To fine-tune the 345M DialoGPT model on the DSTC-7 challenge data on a server with 8 V100 GPUs, run the following command line (the DSTC data can be found at the DSTC-7 repo):
python3 -m torch.distributed.launch --nproc_per_node=8 LSP_train.py --init_checkpoint ./models/medium/medium_ft.pkl --train_input_file ./data/DSTC_train.db --eval_input_file ./data/DSTC_valid.tsv --model_name_or_path ./models/medium/ --learning_rate 1e-4 --train_batch_size 64 --eval_batch_size 64 --no_token_id
The trained model can be found at DSTC medium model
- Download the following 3rd-party packages and save them into the empty folder 3rdparty:
  - mteval-v14c.pl to compute NIST. You may need to install the following perl modules (e.g. by cpan install): XML::Twig, Sort::Naturally and String::Util.
  - meteor-1.5 to compute METEOR. It requires Java.
- Follow the DSTC-7 official repo to extract the data, and put data-official-test/test.refs.txt into the ./dstc/data/ folder.
- Run the extraction script below to produce the human response hypothesis file human.resp.txt:
  python extract_human.py
- Finally, to reproduce the results of the human hypothesis on the DSTC dataset, run the following command under the repo folder:
  python batch_eval.py
The evaluation results will be generated in the folder ./dstc/eval/
We test on a 6K multi-reference dataset from Reddit (this test data will be released soon). The results are summarized below:
Experiment | NIST1 | NIST2 | NIST3 | NIST4 | BLEU1 | BLEU2 | BLEU3 | BLEU4 | METEOR | entropy4 | diversity1 | diversity2 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
Human response | 2.99 | 3.41 | 3.83 | 4.25 | 39.61% | 17.90% | 10.71% | 7.48% | 10.64% | 11 | 14.50% | 63.00% |
DialoGPT 117M | 2.25 | 2.39 | 2.41 | 2.41 | 35.43% | 10.54% | 3.85% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% |
DialoGPT 345M | 2.67 | 3 | 3.05 | 3.06 | 40.97% | 16.96% | 8.31% | 4.56% | 9.81% | 9.13 | 6.80% | 26.30% |
DialoGPT 762M | 2.51 | 2.84 | 2.89 | 2.9 | 44.45% | 18.66% | 9.32% | 5.25% | 9.66% | 9.72 | 7.76% | 29.93% |
- Multi-turn generation examples from an interactive environment:
Role | Response |
---|---|
User | Does money buy happiness? |
Bot | Depends how much money you spend on it . |
User | What is the best way to buy happiness ? |
Bot | You just have to be a millionaire by your early 20s, then you can be happy . |
User | This is so difficult ! |
Bot | You have no idea how hard it is to be a millionaire and happy . There is a reason the rich have a lot of money |
- Commonsense question answering examples
User | Bot |
---|---|
who is the first president of the United States | George Washington |
what is the boiling point of water? | I think it's about 212 F. |
which one is bigger, sun or moon? | The sun . |
which animal has black and white stripes? | A black and white striped cat . |
- Handling context that does not appear in training examples (note that training examples are sampled from 2005 to 2017)
User | Bot |
---|---|
The trading war between China and US is still happening . | I'm sure the Chinese government and the U . S . will be fine with that as long as they keep their arms on the table and the money flowing . |
who won the world cup in 2018 ? | Germany , I believe . |
Nvidia's Titan RTX is really good . | I have this card and i love it . No complaints . |
Can Porsche beat Tesla with its new Taycan EV ? | I don't think so . The Tesla is more powerful than the porsche . |
- Our bot's response to philosophy questions
User | Bot |
---|---|
What is the meaning of a good life ? | I think it's that we're all connected to our past lives and the meaning of life is to live the way you want to and how you want to . |
Which is more important -- being successful or being happy ? | Being successful is just endless . Being happy is the only thing that matters . |
Is peace the only way to stop war ? | Not if the world is not willing to work with you . Peace is the ultimate goal , but peace requires more than that . |
How much freedom should people have ? | The right amount of freedom is the freedom to do as you please , as long as you don't hurt people or property . |
- Finally
User | Bot |
---|---|
What do you think of the DialoGPT repo ? | It's pretty good , but it's very easy to find a bug . |
Please start an issue if you spot any :)
The live demo and decoding script access are by invitation only for now. Please stay tuned for the full release.
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
This repository aims to facilitate research in large-scale pretraining for conversational data. This toolkit contains only part of the modeling machinery needed to actually produce a model weight file in a running dialog. On its own, this model provides only information about the weights of various text spans; in order for a researcher to actually use it, they will need to bring conversational data of their own and decode the response generation from the pretrained system. Microsoft is not responsible for any generation from the 3rd party utilization of the pretrained system.
To be updated.