This repository contains data and code for our EMNLP 2021 paper: STraTA: Self-Training with Task Augmentation for Better Few-shot Learning. Our new implementation of STraTA typically yields better results than those reported in our paper.
Note: Our code can be used as a tool for automatic data labeling.
This repository is tested on Python 3.8+, PyTorch 1.10+, and 🤗 Transformers 4.16+.
You should install all necessary Python packages in a virtual environment. If you are unfamiliar with Python virtual environments, please check out the user guide.
Below, we create a virtual environment with the Anaconda Python distribution and activate it.
conda create -n strata python=3.9
conda activate strata
Next, you need to install 🤗 Transformers. Please refer to the 🤗 Transformers installation page for a detailed guide.
pip install transformers
Finally, install all necessary Python packages for our self-training algorithm.
pip install -r requirements.txt
This will install PyTorch as a backend.
The following example code shows how to run our self-training algorithm with a base model (e.g., BERT, BERT fine-tuned on MNLI, or BERT produced by task augmentation) on the SciTail science entailment dataset, which has two classes ['entails', 'neutral']. We assume that you have a data directory that includes some training data (e.g., train.csv), evaluation data (e.g., eval.csv), and unlabeled data (e.g., infer.csv).
import os
from selftraining import selftrain

data_dir = '/path/to/your/data/dir'

parameters_dict = {
    'max_selftrain_iterations': 100,
    'model_name_or_path': '/path/to/your/base/model',  # could be the id of a model hosted by 🤗 Transformers
    'output_dir': '/path/to/your/output/dir',
    'train_file': os.path.join(data_dir, 'train.csv'),
    'infer_file': os.path.join(data_dir, 'infer.csv'),
    'eval_file': os.path.join(data_dir, 'eval.csv'),
    'evaluation_strategy': 'steps',
    'task_name': 'scitail',
    'label_list': ['entails', 'neutral'],
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 8,
    'max_length': 128,
    'learning_rate': 2e-5,
    'max_steps': 100000,
    'eval_steps': 1,
    'early_stopping_patience': 50,
    'overwrite_output_dir': True,
    'do_filter_by_confidence': False,
    # 'confidence_threshold': 0.3,
    'do_filter_by_val_performance': True,
    'finetune_on_labeled_data': False,
    'seed': 42,
}

selftrain(**parameters_dict)
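For reference, here is a minimal sketch of how the three CSV files above might be prepared with pandas. The column names used here (premise, hypothesis, label) are an assumption for illustration only; please match the format expected by selftraining.py in this repository.

import pandas as pd

# Hypothetical column layout for illustration; check the repository's expected data format.
train = pd.DataFrame({
    'premise': ['A mirror reflects light.', 'Plants absorb water through their roots.'],
    'hypothesis': ['Light bounces off a mirror.', 'Plants need sunlight to grow.'],
    'label': ['entails', 'neutral'],
})
train.to_csv('train.csv', index=False)

# eval.csv follows the same layout as train.csv; infer.csv holds the unlabeled
# examples to be pseudo-labeled during self-training.
infer = pd.DataFrame({
    'premise': ['Sound travels faster in water than in air.'],
    'hypothesis': ['Water transmits sound more quickly than air.'],
})
infer.to_csv('infer.csv', index=False)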
Note: We checkpoint periodically during self-training. In case of preemptions, just re-run the above script and self-training will resume from the latest iteration.
If you have development data, you might want to tune some hyperparameters for self-training. Below are hyperparameters that could provide additional gains for your task.
- finetune_on_labeled_data: If set to True, the resulting model from each self-training iteration is further fine-tuned on the original labeled data before the next self-training iteration. Intuitively, this gives the model a chance to "correct" itself after being trained on pseudo-labeled data.
- do_filter_by_confidence: If set to True, the pseudo-labeled data in each self-training iteration is filtered based on the model confidence. For instance, if confidence_threshold is set to 0.3, pseudo-labeled examples with a confidence score less than or equal to 0.3 will be discarded. Note that confidence_threshold should be greater than or equal to 1/num_labels, where num_labels is the number of class labels. Filtering out the lowest-confidence pseudo-labeled examples can be helpful in some cases (see the sketch after this list).
- do_filter_by_val_performance: If set to True, the pseudo-labeled data in each self-training iteration is filtered based on the current validation performance. For instance, if your validation performance is 80% accuracy, you might want to discard the 20% of pseudo-labeled examples with the lowest confidence scores.
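To make the two filtering strategies concrete, here is a minimal sketch in plain Python. The function and the way confidence scores are computed are illustrative assumptions; the actual logic lives in selftraining.py.

import numpy as np

def select_pseudo_labeled_examples(probs, confidence_threshold=None, val_accuracy=None):
    # probs: array of shape (num_examples, num_labels) with predicted probabilities.
    confidences = probs.max(axis=-1)  # confidence = highest predicted class probability
    keep = np.ones(len(probs), dtype=bool)

    if confidence_threshold is not None:
        # do_filter_by_confidence: drop examples with confidence <= confidence_threshold
        # (the threshold should be >= 1 / num_labels).
        keep &= confidences > confidence_threshold

    if val_accuracy is not None:
        # do_filter_by_val_performance: with, e.g., 80% validation accuracy, discard the
        # 20% of pseudo-labeled examples with the lowest confidence scores.
        cutoff = np.quantile(confidences, 1.0 - val_accuracy)
        keep &= confidences >= cutoff

    return keep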
We strongly recommend distributed training with multiple accelerators. To activate distributed training, please try one of the following methods:
- Run accelerate config and answer the questions asked. This will save a default_config.yaml file in your cache folder for 🤗 Accelerate. Now, you can run your script with the following command:

accelerate launch your_script.py --args_to_your_script

- Run your script with the following command:

python -m torch.distributed.launch --nnodes="${NUM_NODES}" --nproc_per_node="${NUM_TRAINERS}" your_script.py --args_to_your_script

- Run your script with the following command:

torchrun --nnodes="${NUM_NODES}" --nproc_per_node="${NUM_TRAINERS}" your_script.py --args_to_your_script
We recommend starting with a pre-trained BERT model to see how it performs on your task. Next, you might want to try self-training with a BERT model fine-tuned on MNLI (you could use our pre-trained models), i.e., fine-tuning BERT on MNLI before self-training it on your task. If MNLI turns out to be helpful for your task, you could possibly achieve better performance by applying task augmentation to obtain a stronger base model for self-training.
We release the following T5 NLI data generation model checkpoints used in our paper:

- T5-3B-NLI-entailment (3 billion parameters)
- T5-3B-NLI-neutral (3 billion parameters)
- T5-3B-NLI-contradiction (3 billion parameters)
- T5-3B-NLI-entailment_reversed (3 billion parameters)
- T5-3B-NLI-neutral_reversed (3 billion parameters)
- T5-3B-NLI-contradiction_reversed (3 billion parameters)
Note that our models were trained using a maximum sequence length of 128 for both the input and target sequences.
To obtain these models, we fine-tune the original T5-3B model on MNLI in a text-to-text format. Specifically, each MNLI training example (sentA, sentB) → label is cast as label: sentA → sentB. The "reversed" models (with the suffix "-reversed") were trained on reversed examples label: sentB → sentA. During inference, each model is fed an input in the format label: input_text (e.g., entailment: the facts are accessible to you), and it generates some target_text as output (e.g., you have access to the facts).
Once inference is done, you need to create NLI examples as (input_text, target_text) → label, or (target_text, input_text) → label if you use a "reversed" model.
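As a rough sketch of this step (the helper below and its field names are illustrative, not part of the released code):

def make_nli_example(label, input_text, target_text, reversed_model=False):
    # Regular model:    (input_text, target_text) -> label
    # "Reversed" model: (target_text, input_text) -> label
    premise, hypothesis = (target_text, input_text) if reversed_model else (input_text, target_text)
    return {'premise': premise, 'hypothesis': hypothesis, 'label': label}

# Using the generation example above:
example = make_nli_example(
    label='entailment',
    input_text='the facts are accessible to you',
    target_text='you have access to the facts',
)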
Please follow the T5 installation instructions to install T5 and set up accelerators on Google Cloud Platform. Then, take a look at the T5 decoding instructions to get an idea of how to produce predictions from one of our model checkpoints. You need to prepare a text file inputs.txt with one example per line, in the format label: input_text (e.g., contradiction: his acting was really awful).
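For instance, a minimal sketch of writing such a file from a list of unlabeled sentences (pairing every sentence with every NLI label here is an illustrative choice, not a requirement):

labels = ['entailment', 'neutral', 'contradiction']
unlabeled_sentences = ['his acting was really awful', 'the facts are accessible to you']

# Write one "label: input_text" line per (label, sentence) pair.
with open('inputs.txt', 'w') as f:
    for sentence in unlabeled_sentences:
        for label in labels:
            f.write(f'{label}: {sentence}\n')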
The following example command generates 3 output samples per input using top-k sampling with k=5:
t5_mesh_transformer \
  --tpu="${TPU_NAME}" \
  --gcp_project="${PROJECT}" \
  --tpu_zone="${ZONE}" \
  --model_dir="${MODEL_DIR}" \
  --gin_file="${MODEL_DIR}/operative_config.gin" \
  --gin_file="infer.gin" \
  --gin_file="sample_decode.gin" \
  --gin_param="input_filename = '/path/to/inputs.txt'" \
  --gin_param="output_filename = '/path/to/outputs.txt'" \
  --gin_param="utils.decode_from_file.repeats = 3" \
  --gin_param="utils.run.sequence_length = {'inputs': 128, 'targets': 128}" \
  --gin_param="infer_checkpoint_step = 1065536" \
  --gin_param="utils.run.batch_size = ('sequences_per_batch', 64)" \
  --gin_param="Bitransformer.decode.temperature = 1.0" \
  --gin_param="Unitransformer.sample_autoregressive.temperature = 1.0" \
  --gin_param="Unitransformer.sample_autoregressive.sampling_keep_top_k = 5" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = '${TPU_SIZE}'"

Here, utils.decode_from_file.repeats controls the number of output samples per input, sampling_keep_top_k sets the top-k value, and infer_checkpoint_step = 1065536 corresponds to 1,000,000 pre-training steps plus 65,536 fine-tuning steps.
Assuming that the input file inputs.txt has 10 examples, you should get an output file outputs.txt with 30 output samples, where the (3i-2)-th, (3i-1)-th, and (3i)-th output samples correspond to the i-th input example (i = 1, 2, ..., 10).
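A minimal sketch of pairing the generated samples back with their inputs under this layout (assuming the file names above and repeats = 3):

repeats = 3  # number of output samples generated per input

with open('inputs.txt') as f:
    inputs = [line.strip() for line in f]
with open('outputs.txt') as f:
    outputs = [line.strip() for line in f]

assert len(outputs) == repeats * len(inputs)

# Consecutive chunks of `repeats` output lines belong to the same input.
samples_per_input = {
    inputs[i]: outputs[i * repeats:(i + 1) * repeats] for i in range(len(inputs))
}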
We recommend the following practices for task augmentation:
- Overgeneration. In our experiments, we perform overgeneration to get a large amount of synthetic NLI training data. We generate 100 output samples per input with top-k (k = 40) sampling. This can be expensive when you have a large amount of unlabeled data, though.
- Filtering. This is an important step to improve the quality of the synthetic NLI training data. We use a BERT model fine-tuned on MNLI in the original format as an NLI classifier to filter synthetic training examples (you could use our pre-trained models or any reliable NLI models available on 🤗 Models). We only keep an example if the NLI classifier's predicted probability exceeds a certain threshold; see the sketch after this list.
- Combining synthetic and realistic data. In our experiments, we use a two-stage training procedure where the model is first trained on the synthetic NLI data before being fine-tuned on the realistic MNLI data.
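Here is a minimal sketch of this filtering step with 🤗 Transformers. The model id (roberta-large-mnli) and the 0.9 threshold are illustrative choices, not the exact setup used in the paper.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = 'roberta-large-mnli'  # any reliable NLI classifier works; illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()

def keep_synthetic_example(premise, hypothesis, target_label, threshold=0.9):
    # Keep a synthetic (premise, hypothesis) -> target_label example only if the
    # NLI classifier's predicted probability for target_label exceeds the threshold.
    inputs = tokenizer(premise, hypothesis, return_tensors='pt', truncation=True)
    with torch.no_grad():
        probs = torch.softmax(model(**inputs).logits, dim=-1)[0]
    label_id = model.config.label2id[target_label]  # e.g., 'ENTAILMENT' for roberta-large-mnli
    return probs[label_id].item() > threshold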
To facilitate your evaluation, we release the BERT model checkpoints produced by task augmentation (TA) across the datasets used in our few-shot experiments. Note that these models were trained on synthetic NLI data created using unlabeled texts from a target dataset. To avoid differences in evaluation methodology (e.g., training/development data subsets, number of random restarts, etc.), which can have a high impact on model performance in a low-data setting, you might want to perform our self-training algorithm on top of these model checkpoints using your own evaluation setup (e.g., data splits).
Please check out run.sh to see how to perform our self-training algorithm with a BERT-Base model produced by task augmentation on the SciTail science entailment dataset using 8 labeled examples per class. Please turn off the debug mode by setting DEBUG_MODE_ON=False. You can configure your training environment by specifying NUM_NODES and NUM_TRAINERS (the number of processes per node). To launch the script, simply run source run.sh. For your reference, below are the results we got with different development sets using distributed training on a single compute node with 4 NVIDIA GeForce GTX 1080 Ti GPUs.
| Development file | # Development examples | Accuracy (%) |
|---|---|---|
| eval_16.csv | 16 | 87.50 |
| eval_256.csv | 256 | 92.97 |
| eval.csv | 1304 | 92.15 |
What should I do if I do not have enough computational resources to run T5 to produce synthetic data?

In this case, you could fine-tune a model on an intermediate task (e.g., MNLI, or a task closely related to yours) before using it for self-training on your task. In our experiments, self-training on top of BERT fine-tuned on MNLI performs competitively with STraTA in many cases.
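A minimal sketch of such intermediate-task fine-tuning on MNLI with the 🤗 Trainer API (the bert-base-uncased starting point and the hyperparameters are illustrative, not the exact setup from the paper):

from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

model_id = 'bert-base-uncased'  # illustrative starting checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=3)

# MNLI from the GLUE benchmark: premise/hypothesis pairs with three labels.
mnli = load_dataset('glue', 'mnli')

def tokenize(batch):
    return tokenizer(batch['premise'], batch['hypothesis'], truncation=True, max_length=128)

mnli = mnli.map(tokenize, batched=True)

args = TrainingArguments(
    output_dir='bert-base-uncased-mnli',
    per_device_train_batch_size=32,
    learning_rate=2e-5,
    num_train_epochs=3,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=mnli['train'],
    eval_dataset=mnli['validation_matched'],
    tokenizer=tokenizer,  # enables dynamic padding via the default data collator
)
trainer.train()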
If you extend or use this work, please cite the paper where it was introduced:
@inproceedings{vu-etal-2021-strata,
title = "{ST}ra{TA}: Self-Training with Task Augmentation for Better Few-shot Learning",
author = "Vu, Tu and
Luong, Minh-Thang and
Le, Quoc and
Simon, Grady and
Iyyer, Mohit",
booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
month = nov,
year = "2021",
address = "Online and Punta Cana, Dominican Republic",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.emnlp-main.462",
doi = "10.18653/v1/2021.emnlp-main.462",
pages = "5715--5731",
}
This is not an officially supported Google product.