(Unofficial) PyTorch implementation of JointBERT: BERT for Joint Intent Classification and Slot Filling
- Predict intent and slot at the same time from one BERT model (= joint model)
- total_loss = intent_loss + coef * slot_loss (change coef with the --slot_loss_coef option)
- If you want to use a CRF layer, give the --use_crf option
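As a rough illustration of the loss formula above, here is a minimal plain-Python sketch for a single sentence. It is not the repository's code: the function names, the per-token averaging of the slot loss, and the default value of `slot_loss_coef` are all assumptions made for the example.

```python
import math

def cross_entropy(logits, target):
    """Softmax cross-entropy for one example: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def joint_loss(intent_logits, intent_label, slot_logits, slot_labels,
               slot_loss_coef=1.0):  # default value is an assumption
    """total_loss = intent_loss + coef * slot_loss for one sentence."""
    intent_loss = cross_entropy(intent_logits, intent_label)
    # Assumption: the slot loss is the mean of the per-token losses.
    token_losses = [cross_entropy(l, y) for l, y in zip(slot_logits, slot_labels)]
    slot_loss = sum(token_losses) / len(token_losses)
    return intent_loss + slot_loss_coef * slot_loss

# Toy input: 3 intent classes, 2 tokens with 4 slot classes each.
intent_logits = [2.0, 0.5, -1.0]
slot_logits = [[1.0, 0.0, 0.0, 0.0], [0.0, 3.0, 0.0, 0.0]]
loss_default = joint_loss(intent_logits, 0, slot_logits, [0, 1])
loss_half = joint_loss(intent_logits, 0, slot_logits, [0, 1], slot_loss_coef=0.5)
```

Lowering the coefficient shrinks the slot term's share of the total, so `loss_half` is smaller than `loss_default`; a coefficient of 0 would leave only the intent loss.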
- python>=3.5
- torch==1.1.0
- transformers==2.2.2
- seqeval==0.0.12
- pytorch-crf==0.7.2
| | Train | Dev | Test | Intent Labels | Slot Labels |
| --- | --- | --- | --- | --- | --- |
| ATIS | 4,478 | 500 | 893 | 21 | 120 |
| Snips | 13,084 | 700 | 700 | 7 | 72 |
- The number of labels is based on the train dataset.
- Add UNK for labels (for intent and slot labels that appear only in the dev and test datasets)
- Add PAD for slot labels
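The label handling described above could look roughly like the following sketch. The helper names, the ordering of the special labels (PAD first, then UNK), and the toy slot labels are hypothetical; only the idea of building the vocabulary from the train set and mapping unseen labels to UNK comes from the notes above.

```python
PAD = "PAD"
UNK = "UNK"

def build_label_vocab(train_labels, add_pad=False):
    """Build a label -> id map from train labels, with UNK (and optionally PAD)."""
    # Assumption: special labels come first, PAD before UNK.
    specials = ([PAD] if add_pad else []) + [UNK]
    vocab = specials + sorted(set(train_labels) - set(specials))
    return {label: idx for idx, label in enumerate(vocab)}

def encode(labels, label_to_id):
    """Map labels to ids; labels not seen in train fall back to UNK."""
    unk_id = label_to_id[UNK]
    return [label_to_id.get(label, unk_id) for label in labels]

# Slot labels get both PAD and UNK; intent labels only get UNK.
train_slots = ["O", "B-city", "I-city", "O", "B-date"]
slot_vocab = build_label_vocab(train_slots, add_pad=True)
intent_vocab = build_label_vocab(["flight", "airfare"])

# "B-airline" never appears in the train labels, so it maps to UNK.
encoded = encode(["O", "B-airline", "B-city"], slot_vocab)
```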
$ python3 main.py --task {task_name} \
--model_type {model_type} \
--model_dir {model_dir_name} \
--do_train --do_eval \
--use_crf
# For ATIS
$ python3 main.py --task atis \
--model_type bert \
--model_dir atis_model \
--do_train --do_eval
# For Snips
$ python3 main.py --task snips \
--model_type bert \
--model_dir snips_model \
--do_train --do_eval
- A trained model must exist before running prediction.
- Write the sentences to predict in preds.txt in the preds directory.
- If your model was trained with CRF, you must give the --use_crf option when running prediction.
$ python3 main.py --task snips \
--model_type bert \
--model_dir snips_model \
--do_pred \
--pred_dir preds \
--pred_input_file preds.txt
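The input file for the command above can be prepared by hand or with a small script such as the sketch below. The one-sentence-per-line layout is an assumption (the notes above only say to write sentences in the file), and the helper name and example sentences are made up for illustration.

```python
import tempfile
from pathlib import Path

def write_pred_input(sentences, pred_dir="preds", filename="preds.txt"):
    """Write sentences to pred_dir/filename, one per line (assumed format)."""
    path = Path(pred_dir) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    text = "\n".join(s.strip() for s in sentences) + "\n"
    path.write_text(text, encoding="utf-8")
    return path

# Demo in a temporary directory so nothing in the repo is overwritten.
with tempfile.TemporaryDirectory() as tmp:
    out = write_pred_input(
        ["play some jazz", "book a table for two"],
        pred_dir=Path(tmp) / "preds",
    )
    written_lines = out.read_text(encoding="utf-8").splitlines()
```

Calling `write_pred_input(sentences)` with the defaults writes to preds/preds.txt, which matches the --pred_dir and --pred_input_file values used in the command.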
- Run 5 ~ 10 epochs (Record the best result)
- RoBERTa takes more epochs to reach its best result compared to the other models.
- ALBERT xxlarge sometimes can't converge well for slot prediction.
| | | Intent acc (%) | Slot F1 (%) | Sentence acc (%) |
| --- | --- | --- | --- | --- |
| ATIS | BERT | 97.87 | 95.46 | |
| | BERT + CRF | 97.76 | 96.04 | |
| | DistilBERT | 97.54 | 94.89 | |
| | DistilBERT + CRF | 97.42 | 95.89 | |
| | RoBERTa | 97.64 | 95.72 | |
| | RoBERTa + CRF | 97.64 | 95.63 | |
| | ALBERT | 98.20 | 95.59 | |
| | ALBERT + CRF | | | |
| Snips | BERT | 98.29 | 96.05 | |
| | BERT + CRF | | | |
| | DistilBERT | 98.42 | 94.10 | |
| | DistilBERT + CRF | | | |
| | RoBERTa | 98.14 | 94.60 | |
| | RoBERTa + CRF | | | |
| | ALBERT | 98.57 | 97.48 | |
| | ALBERT + CRF | | | |
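For readers unfamiliar with the column names, the sketch below shows one common way to compute intent accuracy and sentence-level accuracy. It assumes that a sentence counts as correct only when its intent and every one of its slot labels match the reference; the function names and toy data are hypothetical and this is not the repository's evaluation code. Slot F1 is left out because it depends on entity-level matching (the dependency list includes seqeval for that).

```python
def intent_accuracy(intent_preds, intent_labels):
    """Fraction of sentences whose predicted intent matches the reference."""
    correct = sum(p == y for p, y in zip(intent_preds, intent_labels))
    return correct / len(intent_labels)

def sentence_frame_accuracy(intent_preds, intent_labels, slot_preds, slot_labels):
    """Fraction of sentences with a correct intent AND all slots correct."""
    correct = 0
    for ip, il, sp, sl in zip(intent_preds, intent_labels, slot_preds, slot_labels):
        if ip == il and sp == sl:  # list equality checks every token's slot
            correct += 1
    return correct / len(intent_labels)

# Toy data: sentence 1 is fully correct, sentence 2 has one wrong slot,
# sentence 3 has the wrong intent.
intent_labels = ["flight", "airfare", "flight"]
intent_preds = ["flight", "airfare", "airfare"]
slot_labels = [["O", "B-city"], ["O", "B-date"], ["B-city", "O"]]
slot_preds = [["O", "B-city"], ["O", "O"], ["B-city", "O"]]

intent_acc = intent_accuracy(intent_preds, intent_labels)
sentence_acc = sentence_frame_accuracy(intent_preds, intent_labels,
                                       slot_preds, slot_labels)
```

In this toy case two of three intents are right but only one sentence is right end to end, which is why sentence accuracy is usually the lowest of the three metrics.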
- 2019/12/03: Add DistilBERT and RoBERTa results
- 2019/12/14: Add ALBERT (large v1) result
- 2019/12/22: Add prediction on input sentences
- 2019/12/26: Add ALBERT (xxlarge v1) result
- 2019/12/29: Add CRF option
- 2019/12/30: Add sentence-level semantic frame accuracy check