
Commit 4bb5dc6
Author: yunfan
Message: update
Parent: 65efd48

21 files changed: +296 -2096 lines

README.md

Lines changed: 31 additions & 12 deletions
@@ -22,29 +22,48 @@ The architecture of CPT is a variant of the full Transformer and consists of three parts:
 2. **Understanding Decoder** (U-Dec): a shallow Transformer encoder with fully-connected self-attention, which is designed for NLU tasks. The input of U-Dec is the output of S-Enc.
 3. **Generation Decoder** (G-Dec): a Transformer decoder with masked self-attention, which is designed for generation tasks in an auto-regressive fashion. G-Dec utilizes the output of S-Enc with cross-attention.
 
-## Downloads & Usage
+## Pre-Trained Models
+We provide the pre-trained weights of CPT and Chinese BART together with the source code; both can be used directly in Huggingface-Transformers.
 
-Coming soon.
+- **`Chinese BART-base`**: 6-layer encoder, 6-layer decoder, 12 heads, 768 model dim.
+- **`Chinese BART-large`**: 12-layer encoder, 12-layer decoder, 16 heads, 1024 model dim.
+- **`CPT-base`**: 10-layer S-Enc, 2-layer U-Dec/G-Dec, 12 heads, 768 model dim.
+- **`CPT-large`**: 20-layer S-Enc, 4-layer U-Dec/G-Dec, 16 heads, 1024 model dim.
 
-## Chinese BART
+The pre-trained weights can be downloaded here:
+| Model | `MODEL_NAME` |
+| --- | --- |
+| **`Chinese BART-base`** | [fnlp/bart-base-chinese](https://huggingface.co/fnlp/bart-base-chinese) |
+| **`Chinese BART-large`** | [fnlp/bart-large-chinese](https://huggingface.co/fnlp/bart-large-chinese) |
+| **`CPT-base`** | [fnlp/cpt-base](https://huggingface.co/fnlp/cpt-base) |
+| **`CPT-large`** | [fnlp/cpt-large](https://huggingface.co/fnlp/cpt-large) |
 
-We also provide a pre-trained Chinese BART as a byproduct. The BART models is pre-trained with the same corpora, tokenization and hyper-parameters of CPT.
 
-#### Load with Huggingface-Transformers
-
-Chinese BART is available in **base** and **large** versions, and can be loaded with Huggingface-Transformers. The example code is as follows, where `MODEL_NAME` is `fnlp/bart-base-chinese` or `fnlp/bart-large-chinese` for **base** or **large** size of BART, respectively.
+To use CPT, import `finetune/modeling_cpt.py`, which defines the CPT architecture, into your project.
+Then load the pre-trained models as in the following examples, where `MODEL_NAME` is the corresponding model identifier from the table above.
 
+For CPT:
 ```python
->>> tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
->>> model = BartForConditionalGeneration.from_pretrained("MODEL_NAME")
+from modeling_cpt import BertTokenizer, CPTForConditionalGeneration
+tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
+model = CPTForConditionalGeneration.from_pretrained("MODEL_NAME")
+print(model)
 ```
 
-The checkpoints of Chinese BART can be downloaded here.
+For Chinese BART:
+```python
+from transformers import BertTokenizer, BartForConditionalGeneration
+tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
+model = BartForConditionalGeneration.from_pretrained("MODEL_NAME")
+print(model)
+```
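
Either of the loaded models can then be used through the standard Hugging Face `generate` API. Below is a minimal usage sketch building on the snippets above; the input sentence is only illustrative, `fnlp/bart-base-chinese` is shown, and a CPT checkpoint can be swapped in the same way.

```python
import torch
from transformers import BertTokenizer, BartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")
model.eval()

# Encode an illustrative sentence with a masked span for the model to fill in.
inputs = tokenizer("北京是[MASK]的首都", return_tensors="pt")

# Standard beam-search decoding via the Hugging Face `generate` API.
with torch.no_grad():
    output_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=20)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```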
 
-- [fnlp/bart-base-chinese](https://huggingface.co/fnlp/bart-base-chinese): 6 layers encoder, 6 layers decoder, 12 heads and 768 model dim.
-- [fnlp/bart-large-chinese](https://huggingface.co/fnlp/bart-large-chinese): 12 layers encoder, 12 layers decoder, 16 heads and 1024 model dim.
+## Pre-Training
+Pre-training code and examples can be found [here](pretrain/README.md).
 
 
+## Fine-Tuning
+Fine-tuning code and examples can be found [here](finetune/README.md).
 
 ## Citation

finetune/REAMDE.md

Lines changed: 25 additions & 1 deletion
@@ -1 +1,25 @@
-# Fine-Tuning of CPT
+# Fine-Tuning CPT
+
+This repo contains the fine-tuning code for CPT on multiple NLU and NLG tasks, such as text classification, machine reading comprehension (MRC), sequence labeling and text generation.
+
+## Requirements
+- pytorch==1.8.1
+- transformers==4.2.0
+
+## Run
+The code and running examples are listed in the corresponding folders of the fine-tuning tasks.
+
+- **`classification`**: [Fine-tuning](classification/REAMDE.md) for sequence classification with either external classifiers or prompt-based learning.
+- **`cws`**: [Fine-tuning](cws/REAMDE.md) for Chinese Word Segmentation with external classifiers.
+- **`generation`**: [Fine-tuning](generation/REAMDE.md) for abstractive summarization and data-to-text generation.
+- **`mrc`**: [Fine-tuning](mrc/REAMDE.md) for span-based Machine Reading Comprehension with external classifiers.
+- **`ner`**: [Fine-tuning](ner/REAMDE.md) for Named Entity Recognition.
+
+You can also fine-tune CPT on other tasks by adding `modeling_cpt.py` to your project and loading CPT with the following code.
+
+```python
+from modeling_cpt import BertTokenizer, CPTForConditionalGeneration
+tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
+model = CPTForConditionalGeneration.from_pretrained("MODEL_NAME")
+print(model)
+```
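
For tasks not covered by the folders above, the loaded model can be fine-tuned like any Hugging Face seq2seq model. The following is a minimal single-step sketch, assuming `CPTForConditionalGeneration` follows the standard interface that accepts `labels` and returns a cross-entropy loss; the texts and hyper-parameters are only illustrative.

```python
import torch
from transformers import BertTokenizer
from modeling_cpt import CPTForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained("fnlp/cpt-base")
model = CPTForConditionalGeneration.from_pretrained("fnlp/cpt-base")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# One illustrative (source, target) pair; real fine-tuning would iterate over a dataset.
src = tokenizer("输入文本", return_tensors="pt", max_length=512, truncation=True)
tgt = tokenizer("目标文本", return_tensors="pt", max_length=128, truncation=True)

model.train()
outputs = model(input_ids=src["input_ids"],
                attention_mask=src["attention_mask"],
                labels=tgt["input_ids"])
outputs.loss.backward()   # seq2seq cross-entropy over the target tokens
optimizer.step()
optimizer.zero_grad()
```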

finetune/__init__.py

Whitespace-only changes.

finetune/classification/REAMDE.md

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# Fine-tuning CPT for Sequence Classification

## Dataset
The **CLUE** datasets can be downloaded [HERE](https://github.com/CLUEbenchmark/CLUE).

## Train and Evaluate
To train and evaluate **CPT$_u$**, **CPT$_g$** and **CPT$_{ug}$**, run `run_clue_classifier.py` with the argument `--cls_mode` set to `1`, `2` and `3`, respectively. The following example script runs the base version of **CPT$_u$** on the **AFQMC** dataset.

```bash
export MODEL_TYPE=cpt-base
export MODEL_NAME=fnlp/cpt-base
export CLUE_DATA_DIR=/path/to/clue_data_dir
export TASK_NAME=afqmc
export CLS_MODE=1
python run_clue_classifier.py \
    --model_type=$MODEL_TYPE \
    --model_name_or_path=$MODEL_NAME \
    --cls_mode=$CLS_MODE \
    --task_name=$TASK_NAME \
    --do_train=True \
    --do_predict=1 \
    --no_tqdm=False \
    --data_dir=$CLUE_DATA_DIR/${TASK_NAME}/ \
    --max_seq_length=512 \
    --per_gpu_train_batch_size=16 \
    --gradient_accumulation_steps 1 \
    --per_gpu_eval_batch_size=64 \
    --weight_decay=0.1 \
    --adam_epsilon=1e-6 \
    --adam_beta1=0.9 \
    --adam_beta2=0.999 \
    --max_grad_norm=1.0 \
    --learning_rate=1e-5 \
    --power=1.0 \
    --num_train_epochs=5.0 \
    --warmup_steps=0.1 \
    --logging_steps=200 \
    --save_steps=999999 \
    --output_dir=output/ft/$MODEL_TYPE/${TASK_NAME}/ \
    --overwrite_output_dir=True \
    --seed=42
```
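
Under the hood, `--cls_mode` is written onto the model config before the classifier is instantiated (see `get_model` in the `run_clue_classifier.py` diff later in this commit). A minimal sketch of that step, assuming the `CPTConfig` and `CPTForSequenceClassification` classes exported by `finetune/modeling_cpt.py`; the checkpoint and label count are only illustrative.

```python
from modeling_cpt import CPTConfig, CPTForSequenceClassification

# cls_mode follows the mapping described above: 1 = CPT_u, 2 = CPT_g, 3 = CPT_ug.
config = CPTConfig.from_pretrained("fnlp/cpt-base", num_labels=2)
config.cls_mode = 1

model = CPTForSequenceClassification.from_pretrained("fnlp/cpt-base", config=config)
print(model.config.cls_mode)
```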

## Prompt-based Fine-Tuning
To train and evaluate **CPT$_{u+p}$** and **CPT$_{g+p}$**, run `run_clue_prompt.py` with the argument `--cls_mode` set to `1` and `2`, respectively. The following example script runs the base version of **CPT$_{u+p}$** on the **AFQMC** dataset.

```bash
export MODEL_TYPE=cpt-base
export MODEL_NAME=fnlp/cpt-base
export CLUE_DATA_DIR=/path/to/clue_data_dir
export TASK_NAME=afqmc
export NUM_TRAIN=-1
export PATTERN_IDS=0
export CLS_MODE=1
python run_clue_prompt.py \
    --pattern_ids $PATTERN_IDS \
    --cls_mode 1 \
    --data_dir=$CLUE_DATA_DIR/${TASK_NAME}/ \
    --model_type $MODEL_TYPE \
    --model_name_or_path $MODEL_NAME \
    --max_seq_length 512 \
    --task_name $TASK_NAME \
    --output_dir output/prompt/$MODEL_TYPE/${TASK_NAME}/ \
    --train_examples $NUM_TRAIN \
    --weight_decay 0.1 \
    --learning_rate 1e-5 \
    --power 1.0 \
    --warmup_steps 0.1 \
    --split_examples_evenly \
    --num_train_epochs 5 \
    --eval_steps 200 \
    --per_gpu_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --per_gpu_eval_batch_size 32 \
    --do_train \
    --do_eval
```
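
Conceptually, the prompt-based variants score label words at a `[MASK]` position instead of training a separate classification head. A rough sketch of that scoring step, assuming `CPTForMaskedLM` (imported by `run_clue_prompt.py`) follows the standard masked-LM interface; the pattern and verbalizer below are purely illustrative and are not the templates defined in `prompt.py`.

```python
import torch
from transformers import BertTokenizer
from modeling_cpt import CPTForMaskedLM

tokenizer = BertTokenizer.from_pretrained("fnlp/cpt-base")
model = CPTForMaskedLM.from_pretrained("fnlp/cpt-base")
model.eval()

# Illustrative AFQMC-style pattern: do the two sentences ask the same thing?
text_a, text_b = "花呗如何还款", "花呗怎么还钱"
prompt_text = f"{text_a}?[MASK],{text_b}"      # hypothetical pattern
verbalizer = {"是": 1, "否": 0}                  # hypothetical label words: same / different

inputs = tokenizer(prompt_text, return_tensors="pt")
mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero().item()

with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"],
                   attention_mask=inputs["attention_mask"]).logits[0, mask_pos]

scores = {w: logits[tokenizer.convert_tokens_to_ids(w)].item() for w in verbalizer}
print(scores, verbalizer[max(scores, key=scores.get)])
```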

finetune/classification/run_clue_classifier.py

Lines changed: 11 additions & 8 deletions
@@ -43,6 +43,9 @@
 from transformers import glue_processors as processors
 from transformers.models.bert.tokenization_bert import BertTokenizer
 from data_processors import clue_output_modes, clue_processors
+
+import sys
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
 from modeling_cpt import CPTForSequenceClassification, CPTConfig
 
 
@@ -620,14 +623,19 @@ def get_dataset(args, task, tokenizer, part='train'):
     return dataset
 
 def get_model(args, model_name_or_path, num_labels):
-    if 'cpt' in model_name_or_path:
+    tokenizer = BertTokenizer.from_pretrained(
+        args.tokenizer_name if args.tokenizer_name else model_name_or_path,
+        do_lower_case=args.do_lower_case,
+        cache_dir=args.cache_dir if args.cache_dir else None,
+    )
+    if 'cpt' in args.model_type:
         config = CPTConfig.from_pretrained(
             model_name_or_path,
             num_labels=num_labels,
             finetuning_task=args.task_name,
             cache_dir=args.cache_dir if args.cache_dir else None)
         # config.consist_lambda = args.consist_lambda
-        config.cls_mode = args.ft_mode
+        config.cls_mode = args.cls_mode
         model = CPTForSequenceClassification.from_pretrained(
             model_name_or_path,
             from_tf=bool(".ckpt" in model_name_or_path),
@@ -641,11 +649,6 @@ def get_model(args, model_name_or_path, num_labels):
             finetuning_task=args.task_name,
             cache_dir=args.cache_dir if args.cache_dir else None,
         )
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer_name if args.tokenizer_name else model_name_or_path,
-            do_lower_case=args.do_lower_case,
-            cache_dir=args.cache_dir if args.cache_dir else None,
-        )
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_name_or_path=args.config_name if args.config_name else model_name_or_path,
             from_tf=bool(".ckpt" in model_name_or_path),
@@ -778,7 +781,7 @@ def main():
         type=str2bool, default=True,
         help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
     )
-    parser.add_argument("--ft_mode", default=1, type=int, help="CPT fine-tune `mode`")
+    parser.add_argument("--cls_mode", default=1, type=int, help="CPT fine-tune `mode`")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--no_tqdm", type=str2bool, default=False, help="Avoid using tqdm when available")
    parser.add_argument("--sample_tokenize", type=str2bool, default=False, help="using sampling when tokenize")

finetune/classification/run_clue_prompt.py

Lines changed: 4 additions & 2 deletions
@@ -10,10 +10,12 @@
 from prompt import prompt, templates, log
 import json
 from transformers import glue_processors, WEIGHTS_NAME
-import sys
-sys.path.append('..')
 from data_processors import clue_processors
+
+import sys
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
 from modeling_cpt import CPTForMaskedLM
+
 import glob
 
 import torch.multiprocessing

finetune/cws/REAMDE.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Fine-tuning CPT for CWS

## Dataset
The **MSR** and **PKU** datasets are from **SIGHAN2005** and can be downloaded [HERE](http://sighan.cs.uchicago.edu/bakeoff2005/).

## Train and Evaluate

To train and evaluate CPT on the CWS datasets, run `run_cws.py`. The following example script runs the base version of **CPT$_u$** on the **MSR** dataset.

```bash
export MODEL_TYPE=cpt-base
export MODEL_NAME=fnlp/cpt-base
export DATA_DIR=/path/to/cws_data_dir
python run_cws.py \
    --bert_name=$MODEL_NAME \
    --data_dir=$DATA_DIR \
    --dataset=msr \
    --lr=2e-5 \
    --batch_size=16 \
    --epoch=10
```
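
"External classifiers" here means a token-level tagging head on top of the pre-trained model; `run_cws.py` wraps `CPTModel` in a `CWSModel` for exactly this. The following schematic sketch is not the repo's `CWSModel`; it assumes `CPTModel` exposes `last_hidden_state` and `config.d_model` like other Hugging Face encoder-decoder models, and uses a 4-tag BMES scheme only as one common labeling choice for CWS.

```python
import torch.nn as nn
from modeling_cpt import CPTModel

class ToyCWSTagger(nn.Module):
    """Schematic segmentation tagger: contextual hidden states -> per-token tag logits."""

    def __init__(self, model_name: str, num_tags: int = 4):  # e.g. B/M/E/S tags
        super().__init__()
        self.encoder = CPTModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_tags)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask).last_hidden_state
        return self.classifier(hidden)  # (batch, seq_len, num_tags)
```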

finetune/cws/run_cws.py

Lines changed: 5 additions & 2 deletions
@@ -16,15 +16,18 @@
 
 from model import CWSModel
 from utils import DataTrainingArguments, ModelArguments, load_json
+
+import sys
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
 from modeling_cpt import CPTModel
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--bert_name",default='/remote-home/share/yfshao/bart-zh/arch24-4-new-iter10w',type=str)
+parser.add_argument("--bert_name",default='/path/to/model/',type=str)
 parser.add_argument("--dataset", default="msr",type=str)
 parser.add_argument("--lr",default=2e-5,type=float)
 parser.add_argument("--batch_size",default='16',type=str)
 parser.add_argument("--epoch",default='10',type=str)
-parser.add_argument("--data_dir",default="../../data",type='str')
+parser.add_argument("--data_dir",default="/path/to/dataset/",type='str')
 args = parser.parse_args()
 arg_dict=args.__dict__

finetune/generation/REAMDE.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# Fine-tuning CPT for Text Generation

finetune/generation/run_gen.py

Lines changed: 6 additions & 3 deletions
@@ -16,16 +16,19 @@
 from transformers.trainer_utils import is_main_process
 from datasets import load_metric,Dataset
 from utils import DataTrainingArguments, ModelArguments, load_json
+
+import sys
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
 from modeling_cpt import CPTModel, CPTForConditionalGeneration
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--bert_name",default='/path/to/cpt/',type=str)
+parser.add_argument("--bert_name",default='/path/to/model',type=str)
 parser.add_argument("--dataset", default="lcsts",type=str)
 parser.add_argument("--lr",default=2e-5,type=float)
 parser.add_argument("--batch_size",default='50',type=str)
 parser.add_argument("--epoch",default='5',type=str)
-parser.add_argument("--data_dir",default="/path/to/dataset/",type='str')
+parser.add_argument("--data_dir",default="/path/to/dataset/",type=str)
 args = parser.parse_args()
 arg_dict=args.__dict__
 
@@ -275,4 +278,4 @@ def on_evaluate(self, args, state, control, **kwargs):
 test_preds = [pred.strip() for pred in test_preds]
 output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
 with open(output_test_preds_file, "w",encoding='UTF-8') as writer:
-    writer.write("\n".join(test_preds))
+    writer.write("\n".join(test_preds))

finetune/modeling_cpt.py

Lines changed: 2 additions & 22 deletions
@@ -50,24 +50,15 @@
 
 from torch.nn import LayerNorm
 
-# For cuda fused ops
-# from megatron.model import LayerNorm
-# from megatron.model.transformer import ParallelMLP
-# from megatron.model.fused_bias_gelu import bias_gelu_impl
-# from megatron import mpu
-# from megatron import get_args
-# from megatron.model.enums import AttnMaskType, LayerType, AttnType
-
-
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "fudannlp/cpt-large"
+_CHECKPOINT_FOR_DOC = "fnlp/cpt-large"
 _CONFIG_FOR_DOC = "CPTConfig"
 _TOKENIZER_FOR_DOC = "CPTTokenizer"
 
 
 CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "fudannlp/cpt-large",
+    "fnlp/cpt-large",
 ]
 
 
@@ -114,17 +105,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
 
     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
 
-# def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-#     """
-#     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-#     """
-#     bsz, src_len = mask.size()
-#     tgt_len = tgt_len if tgt_len is not None else src_len
-
-#     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-#     inverted_mask = (expanded_mask < 0.5)
-#     return inverted_mask
 def attention_mask_func(attention_scores, attention_mask):
     return attention_scores + attention_mask

finetune/mrc/REAMDE.md

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Fine-tuning CPT for Machine Reading Comprehension

## Dataset
The **CMRC2018** dataset can be downloaded [HERE](https://github.com/CLUEbenchmark/CLUE), and **DRCD** can be downloaded [HERE](https://github.com/DRCKnowledgeTeam/DRCD).

## Train and Evaluate
To train and evaluate **CPT$_u$**, **CPT$_g$** and **CPT$_{ug}$**, run `run_mrc.py` with the argument `--cls_mode` set to `1`, `2` and `3`, respectively. The following example script runs the base version of **CPT$_u$** on the **DRCD** dataset.

```bash
export MODEL_TYPE=cpt-base
export MODEL_NAME=fnlp/cpt-base
export CLUE_DATA_DIR=/path/to/mrc_data_dir
export TASK_NAME=drcd
export CLS_MODE=1
python run_mrc.py \
    --fp16 \
    --model_type $MODEL_TYPE \
    --train_epochs=5 \
    --do_train=1 \
    --do_predict=1 \
    --n_batch=16 \
    --gradient_accumulation_steps 4 \
    --lr=3e-5 \
    --dropout=0.2 \
    --CLS_MODE=$CLS_MODE \
    --warmup_rate=0.1 \
    --weight_decay_rate=0.01 \
    --max_seq_length=512 \
    --eval_steps=200 \
    --task_name=$TASK_NAME \
    --init_restore_dir=$MODEL_NAME \
    --train_dir=$CLUE_DATA_DIR/$TASK_NAME/train_features.json \
    --train_file=$CLUE_DATA_DIR/$TASK_NAME/train.json \
    --dev_dir1=$CLUE_DATA_DIR/$TASK_NAME/dev_examples.json \
    --dev_dir2=$CLUE_DATA_DIR/$TASK_NAME/dev_features.json \
    --dev_file=$CLUE_DATA_DIR/$TASK_NAME/dev.json \
    --test_file=$CLUE_DATA_DIR/$TASK_NAME/test.json \
    --test_dir1=$CLUE_DATA_DIR/$TASK_NAME/test_examples_$MODEL_TYPE.json \
    --test_dir2=$CLUE_DATA_DIR/$TASK_NAME/test_features_$MODEL_TYPE.json \
    --checkpoint_dir=output/$MODEL_TYPE/$TASK_NAME/
```
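
For span-based MRC, the external classifier predicts answer start and end positions over the passage tokens, and the answer is the highest-scoring valid span. A generic post-processing sketch of that step, independent of the actual implementation in `run_mrc.py`:

```python
import torch

def extract_best_span(start_logits: torch.Tensor,
                      end_logits: torch.Tensor,
                      max_answer_len: int = 30):
    """Pick the (start, end) pair maximizing start_logit + end_logit with start <= end."""
    best_score, best_span = float("-inf"), (0, 0)
    for start in range(start_logits.size(0)):
        end_limit = min(start + max_answer_len, end_logits.size(0))
        for end in range(start, end_limit):
            score = start_logits[start].item() + end_logits[end].item()
            if score > best_score:
                best_score, best_span = score, (start, end)
    return best_span, best_score

# Example with random logits standing in for model outputs over a 20-token passage.
start_logits, end_logits = torch.randn(20), torch.randn(20)
print(extract_best_span(start_logits, end_logits))
```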
