training pipeline
Xingcheng Yao committed Nov 5, 2021
1 parent 2e6c479 commit 44e18ed
Showing 14 changed files with 1,340 additions and 3 deletions.
Binary file added .DS_Store
47 changes: 44 additions & 3 deletions README.md
@@ -1,4 +1,45 @@
# TLM
Hi~ This is the github repository for the paper [NLP From Scratch Without Large-Scale Pretraining: A Simple and Efficient Framework]().
## NLP From Scratch Without Large-Scale Pretraining
This repository contains the code, pre-trained model checkpoints and curated datasets for our paper: [NLP From Scratch Without Large-Scale Pretraining: A Simple and Efficient Framework]().

We are now working day and night to get our code, model checkpoints and data ready for the public. Please watch us and stay tuned!
In our proposed framework, named ***TLM*** (task-driven language modeling), instead of training a language model over the entire general corpus and then finetuning it on task data, we first use task data as queries to retrieve a tiny subset of the general corpus, and then perform joint learning on both the task objective and the self-supervised language modeling objective.
![](./fig/framework.png)
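
For intuition, the retrieval step can be pictured as a simple lexical search: each task example is used as a query over the general corpus, and the top-scoring documents form the external data. The snippet below is only a toy sketch with in-memory strings, assuming a BM25 retriever from the `rank_bm25` package (which is not part of this repository's requirements).

```python
# Toy sketch of the data-selection step: task examples act as queries and the
# top-scoring general-corpus documents are kept as external training data.
# Assumes `pip install rank_bm25`; not part of this repository.
from rank_bm25 import BM25Okapi

general_corpus = [
    "the cell membrane regulates transport of molecules",
    "stock markets rallied after the central bank announcement",
    "the protein binds to the receptor and inhibits downstream signaling",
]
task_examples = ["compound X inhibits protein Y in vitro"]

bm25 = BM25Okapi([doc.split() for doc in general_corpus])

external_data = set()
for query in task_examples:
    # keep the top-k documents per task example
    for doc in bm25.get_top_n(query.split(), general_corpus, n=2):
        external_data.add(doc)

print(external_data)  # the tiny task-relevant subset used for joint training
```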

### Requirements
We implement our models and training loops on top of the open-source libraries from [HuggingFace](https://huggingface.co/). The core dependencies of this repository are listed in `requirements.txt` and can be installed with:
```
pip install -r requirements.txt
```
All our experiments are conducted on a node with 8 [A100 40GB SXM](https://www.nvidia.cn/data-center/a100/) GPUs. Different hardware may yield results that differ slightly from the reported ones.

### Models and Datasets

We release trained models for 8 tasks at 3 different scales, together with the task datasets and the selected external data. The released model checkpoints, the datasets, and the performance of each model on each task are listed in the following table.
| | [AGNews](https://huggingface.co/datasets/yxchar/ag-tlm) | [Hyp.](https://huggingface.co/datasets/yxchar/hyp-tlm)| [Help.](https://huggingface.co/datasets/yxchar/amazon-tlm)| [IMDB](https://huggingface.co/datasets/yxchar/imdb-tlm)| [ACL.](https://huggingface.co/datasets/yxchar/citation_intent-tlm)| [SciERC](https://huggingface.co/datasets/yxchar/sciie-tlm)| [Chem.](https://huggingface.co/datasets/yxchar/chemprot-tlm)|[RCT](https://huggingface.co/datasets/yxchar/rct-20k-tlm) |
|-------------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|*Small*| [93.74](https://huggingface.co/yxchar/tlm-ag-small-scale)| [93.53](https://huggingface.co/yxchar/tlm-hyp-small-scale)| [70.54](https://huggingface.co/yxchar/tlm-amazon-small-scale)| [93.08](https://huggingface.co/yxchar/tlm-imdb-small-scale)|[69.84](https://huggingface.co/yxchar/tlm-citation_intent-small-scale) |[80.51](https://huggingface.co/yxchar/tlm-sciie-small-scale) | [81.99](https://huggingface.co/yxchar/tlm-chemprot-small-scale)|[86.99](https://huggingface.co/yxchar/tlm-rct-20k-small-scale)|
|*Medium*|[93.96](https://huggingface.co/yxchar/tlm-ag-medium-scale)|[94.05](https://huggingface.co/yxchar/tlm-hyp-medium-scale)|[70.90](https://huggingface.co/yxchar/tlm-amazon-medium-scale)|[93.97](https://huggingface.co/yxchar/tlm-imdb-medium-scale)|[72.37](https://huggingface.co/yxchar/tlm-citation_intent-medium-scale)|[81.88](https://huggingface.co/yxchar/tlm-sciie-medium-scale)|[83.24](https://huggingface.co/yxchar/tlm-chemprot-medium-scale)|[87.28](https://huggingface.co/yxchar/tlm-rct-20k-medium-scale)|
|*Large*|[94.36](https://huggingface.co/yxchar/tlm-ag-large-scale)|[95.16](https://huggingface.co/yxchar/tlm-hyp-large-scale)|[72.49](https://huggingface.co/yxchar/tlm-amazon-large-scale)|[95.77](https://huggingface.co/yxchar/tlm-imdb-medium-scale)|[72.19](https://huggingface.co/yxchar/tlm-citation_intent-large-scale)|[83.29](https://huggingface.co/yxchar/tlm-sciie-large-scale)|[85.12](https://huggingface.co/yxchar/tlm-chemprot-large-scale)|[87.50](https://huggingface.co/yxchar/tlm-rct-20k-large-scale)|

The released models and datasets are compatible with [HuggingFace's Transformers](https://huggingface.co/transformers/) and [Datasets](https://huggingface.co/docs/datasets/index.html). We provide an example script to evaluate a model checkpoint on a given task. Run
```
bash example_scripts/evaluate.sh
```
to get the evaluation results for SciERC with a small-scale model.
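
The checkpoints and datasets in the table can also be loaded directly with the HuggingFace libraries. Below is a minimal sketch for the small-scale SciERC model; the split and column names of the dataset are assumptions here and may need adjusting.

```python
# Minimal sketch: load a released checkpoint and its task dataset from the Hub.
# Assumes the checkpoint carries a sequence-classification head and that the
# dataset exposes "text"/"label" columns and a "test" split; adjust if needed.
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "yxchar/tlm-sciie-small-scale"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

dataset = load_dataset("yxchar/sciie-tlm")
print(dataset)  # inspect the actual splits and columns first

example = dataset["test"][0]
inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, max_length=128)
prediction = model(**inputs).logits.argmax(dim=-1).item()
print(prediction, example["label"])
```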

### Training

We provide two example scripts to train a model from scratch. Run
```
bash example_scripts/train.sh && bash example_scripts/finetune.sh
```
to train a small-scale model for SciERC. Here `example_scripts/train.sh` corresponds to the first training stage, where the external data ratio and the MLM weight are non-zero, and `example_scripts/finetune.sh` corresponds to the second training stage, where the model sees neither external data nor the self-supervised loss. A schematic of how these two hyperparameters combine is sketched below.
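
The exact batching and loss logic lives in `src/run.py`; the following is only an illustrative sketch, with dummy loss values, of how `--mlm_weight` and `--external_ratio` differ between the two stages, based on our reading of the example scripts.

```python
# Illustrative only: how the two training stages differ in their objective.
# Dummy loss values; the real losses are computed in src/run.py.
import torch

def joint_loss(task_loss, mlm_loss, mlm_weight):
    # Stage 1 (train.sh): mlm_weight=20 and external_ratio=999, so training is
    # dominated by the MLM objective on retrieved external data.
    # Stage 2 (finetune.sh): mlm_weight=0 and external_ratio=0, so the objective
    # reduces to plain supervised finetuning on task data.
    return task_loss + mlm_weight * mlm_loss

task_loss = torch.tensor(0.7)  # e.g. classification loss on SciERC labels
mlm_loss = torch.tensor(2.3)   # e.g. masked-language-modeling loss

print(joint_loss(task_loss, mlm_loss, mlm_weight=20))  # stage-1 objective
print(joint_loss(task_loss, mlm_loss, mlm_weight=0))   # stage-2 objective
```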

### Citation
Please cite our paper if you use TLM in your work:
```bibtex
@misc{yao2021tlm,
  title={NLP From Scratch Without Large-Scale Pretraining: A Simple and Efficient Framework},
  author={Yao, Xingcheng and Zheng, Yanan and Yang, Xiaocong and Yang, Zhilin},
  year={2021}
}
```
9 changes: 9 additions & 0 deletions accelerate_config/example_config.yaml
@@ -0,0 +1,9 @@
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
fp16: true
machine_rank: 0
main_process_ip: null
main_process_port: 1234
main_training_function: main
num_machines: 1
num_processes: 1
9 changes: 9 additions & 0 deletions accelerate_config/example_dist_config.yaml
@@ -0,0 +1,9 @@
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
fp16: true
machine_rank: 0
main_process_ip: null
main_process_port: 1234
main_training_function: main
num_machines: 1
num_processes: 8
19 changes: 19 additions & 0 deletions example_scripts/evaluate.sh
@@ -0,0 +1,19 @@
TASK=sciie
SCALE=small

if [[ $TASK == "imdb" ]]
then
    MAXLEN=512
else
    MAXLEN=128
fi

accelerate launch --config_file ./accelerate_config/example_config.yaml src/run.py \
    --max_train_steps 0 \
    --preprocessing_num_workers 32 \
    --max_length $MAXLEN \
    --pad_to_max_length \
    --model_name_or_path yxchar/tlm-${TASK}-${SCALE}-scale \
    --config_dir yxchar/tlm-${TASK}-${SCALE}-scale \
    --per_device_eval_batch_size 16 \
    --task_name $TASK
42 changes: 42 additions & 0 deletions example_scripts/finetune.sh
@@ -0,0 +1,42 @@
TASK=sciie
SCALE=small
OUTPUT_DIR=./results
SAVENAME=$TASK-$SCALE-scale
MLM_WEIGHT=0
EXTERNAL_RATIO=0
LR=2e-5
WD=0.00
WARMUP=3000

if [[ $TASK == "imdb" ]]
then
    MAXLEN=512
else
    MAXLEN=128
fi

accelerate launch --config_file ./accelerate_config/example_dist_config.yaml src/run.py \
    --max_train_steps 30000 \
    --steps_to_eval 1000 \
    --steps_to_save 100000 \
    --steps_to_log 100 \
    --external_dataset_name small_external.csv \
    --preprocessing_num_workers 32 \
    --max_ckpts_to_keep 3 \
    --max_length $MAXLEN \
    --pad_to_max_length \
    --config_dir ./models/$TASK-$SCALE-scale \
    --model_name_or_path $OUTPUT_DIR/$SAVENAME/final \
    --output_dir $OUTPUT_DIR/ft-$SAVENAME \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 16 \
    --cuda_devices 0,1,2,3,4,5,6,7 \
    --task_name $TASK \
    --save_final \
    --mlm_weight $MLM_WEIGHT \
    --external_ratio $EXTERNAL_RATIO \
    --weight_decay $WD \
    --learning_rate $LR \
    --num_warmup_steps $WARMUP \
    --seed 0 \
    --reset_cls
43 changes: 43 additions & 0 deletions example_scripts/train.sh
@@ -0,0 +1,43 @@
TASK=sciie
SCALE=small
OUTPUT_DIR=./results
SAVENAME=$TASK-$SCALE-scale
MLM_WEIGHT=20
EXTERNAL_RATIO=999
LR=1e-4
WD=0.01
WARMUP=10000

if [[ $TASK == "imdb" ]]
then
    MAXLEN=512
else
    MAXLEN=128
fi

mkdir -p $OUTPUT_DIR

accelerate launch --config_file ./accelerate_config/example_dist_config.yaml src/run.py \
    --max_train_steps 150000 \
    --steps_to_eval 100000 \
    --steps_to_save 50000 \
    --steps_to_log 100 \
    --external_dataset_name small_external.csv \
    --preprocessing_num_workers 32 \
    --max_length $MAXLEN \
    --max_ckpts_to_keep 3 \
    --pad_to_max_length \
    --config_dir yxchar/tlm-${TASK}-${SCALE}-scale \
    --from_scratch \
    --output_dir $OUTPUT_DIR/$SAVENAME \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 16 \
    --cuda_devices 0,1,2,3,4,5,6,7 \
    --task_name $TASK \
    --save_final \
    --mlm_weight $MLM_WEIGHT \
    --external_ratio $EXTERNAL_RATIO \
    --mask_task \
    --weight_decay $WD \
    --learning_rate $LR \
    --num_warmup_steps $WARMUP
Binary file added fig/framework.png
5 changes: 5 additions & 0 deletions requirements.txt
@@ -0,0 +1,5 @@
accelerate==0.5.1
datasets==1.8.0
pandas==1.1.5
transformers==4.8.1
torch==1.9.0
80 changes: 80 additions & 0 deletions src/collator.py
@@ -0,0 +1,80 @@
import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import torch

from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

@dataclass
class DataCollatorForLanguageModeling:

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        batch.pop("id", None)

        init_labels = batch.pop("labels", None)
        input_ids = batch["input_ids"].clone()
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.mask_tokens(
                input_ids, special_tokens_mask=special_tokens_mask,
                init_labels=init_labels,
            )
        else:
            batch.pop("labels", None)
        return batch

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None,
        init_labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        if init_labels is not None:
            labels[inputs == self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)] = init_labels

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
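
A minimal usage sketch of this collator (the tokenizer name below is only an example; any tokenizer with a mask token works):

```python
# Minimal sketch: run tokenized examples through the collator defined above.
# "bert-base-uncased" is only an example tokenizer with a [MASK] token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

examples = [
    tokenizer("the protein binds to the receptor"),
    tokenizer("stock markets rallied after the announcement"),
]
batch = collator(examples)  # pads the batch, then applies 80/10/10 masking

print(batch["input_ids"].shape, batch["labels"].shape)
print((batch["labels"] != -100).sum().item(), "positions contribute to the MLM loss")
```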