differentiable-subset-pruning

This repository accompanies the paper Differentiable Subset Pruning of Transformer Heads.

Dependencies

  • python 3.7.4
  • perl 5.18.4
  • pytorch 1.7.1+cu101
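
A minimal environment sketch, assuming conda is available (the environment name dsp and the CUDA 10.1 wheel index are illustrative; adjust them to your hardware). Perl is typically only needed by the Moses scripts invoked during the IWSLT preparation step below.

# assumed environment name and CUDA variant; pick whatever matches your setup
conda create -n dsp python=3.7.4
conda activate dsp
pip install torch==1.7.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html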

BERT on MNLI

Install the Transformers library adapted from Hugging Face:

cd transformers
pip install -e .

Then install the other dependencies:

cd examples
pip install -r requirements.txt

Download the MNLI dataset:

cd pruning
python ../../utils/download_glue_data.py --data_dir data/ --task MNLI

Joint Pruning

Joint DSP

export TASK_NAME=MNLI
python run_dsp.py \
	--data_dir data/$TASK_NAME/ \
	--model_name_or_path bert-base-uncased \
	--task_name $TASK_NAME \
	--do_train \
	--do_eval \
	--max_seq_length 128 \
	--per_device_train_batch_size 32 \
	--learning_rate 2e-5 \
	--num_train_epochs 3.0 \
	--output_dir dsp_output/ \
	--cache_dir cache/ \
	--save_steps 10000 \
	--num_of_heads 12 \
	--pruning_lr 0.5 \
	--annealing \
	--initial_temperature 1000 \
	--final_temperature 1e-8 \
	--cooldown_steps 25000 \
	--joint_pruning 

where num_of_heads is the number of attention heads to keep unpruned; pruning_lr, initial_temperature, final_temperature, and cooldown_steps are the DSP hyperparameters.
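
To trace accuracy against sparsity, one option is a sweep over head budgets, as in the hypothetical loop below (the budgets and per-run output directories are assumed; all other flags match the command above):

export TASK_NAME=MNLI
# sweep the head budget; only --num_of_heads and --output_dir change per run
for K in 12 24 36; do
    python run_dsp.py \
        --data_dir data/$TASK_NAME/ \
        --model_name_or_path bert-base-uncased \
        --task_name $TASK_NAME \
        --do_train \
        --do_eval \
        --max_seq_length 128 \
        --per_device_train_batch_size 32 \
        --learning_rate 2e-5 \
        --num_train_epochs 3.0 \
        --output_dir dsp_output/heads_${K}/ \
        --cache_dir cache/ \
        --save_steps 10000 \
        --num_of_heads $K \
        --pruning_lr 0.5 \
        --annealing \
        --initial_temperature 1000 \
        --final_temperature 1e-8 \
        --cooldown_steps 25000 \
        --joint_pruning
done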

You can also use the straight-through estimator:

export TASK_NAME=MNLI
python run_dsp.py \
	--data_dir data/$TASK_NAME/ \
	--model_name_or_path bert-base-uncased \
	--task_name $TASK_NAME \
	--do_train \
	--do_eval \
	--max_seq_length 128 \
	--per_device_train_batch_size 32 \
	--learning_rate 2e-5 \
	--num_train_epochs 3.0 \
	--output_dir dsp_output/ \
	--cache_dir cache/ \
	--save_steps 10000 \
	--num_of_heads 12 \
	--pruning_lr 0.5 \
	--use_ste \
	--joint_pruning 

Voita et al.

The baseline of Voita et al. (2019) can be reproduced with:

export TASK_NAME=MNLI
python run_voita.py \
	--model_name_or_path bert-base-uncased \
	--task_name $TASK_NAME \
	--do_train \
	--do_eval \
	--data_dir data/$TASK_NAME/ \
	--max_seq_length 128 \
	--per_device_train_batch_size 32 \
	--learning_rate 2e-5 \
	--num_train_epochs 3.0 \
	--output_dir voita_output/ \
	--cache_dir cache/ \
	--save_steps 10000 \
	--l0_penalty 0.0015 \
	--joint_pruning \
	--pruning_lr 0.1 

where l0_penalty indirectly controls the number of remaining heads.
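
Because the penalty only controls sparsity indirectly, a small sweep is one way to reach a target head count. The hypothetical loop below varies only --l0_penalty and --output_dir (the penalty values and directory names are assumed):

export TASK_NAME=MNLI
# sweep the l0 penalty; stronger penalties prune more heads
for P in 0.0005 0.0015 0.005; do
    python run_voita.py \
        --model_name_or_path bert-base-uncased \
        --task_name $TASK_NAME \
        --do_train \
        --do_eval \
        --data_dir data/$TASK_NAME/ \
        --max_seq_length 128 \
        --per_device_train_batch_size 32 \
        --learning_rate 2e-5 \
        --num_train_epochs 3.0 \
        --output_dir voita_output/l0_${P}/ \
        --cache_dir cache/ \
        --save_steps 10000 \
        --l0_penalty $P \
        --joint_pruning \
        --pruning_lr 0.1
done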

Pipelined Pruning

We first need to fine-tune BERT on MNLI:

export TASK_NAME=MNLI
python ../text-classification/run_glue.py \
	--model_name_or_path bert-base-uncased \
	--task_name $TASK_NAME \
	--do_train \
	--do_eval \
	--data_dir data/$TASK_NAME/ \
	--max_seq_length 128 \
	--per_device_train_batch_size 32 \
	--learning_rate 2e-5 \
	--num_train_epochs 3.0 \
	--output_dir output/ \
	--cache_dir cache/ \
	--save_steps 10000 

Pipelined DSP

export TASK_NAME=MNLI
python run_dsp.py \
	--data_dir data/$TASK_NAME/ \
	--model_name_or_path output/ \
	--task_name $TASK_NAME \
	--do_train \
	--do_eval \
	--max_seq_length 128 \
	--per_device_train_batch_size 32 \
	--learning_rate 2e-5 \
	--num_train_epochs 1.0 \
	--output_dir dsp_output/pipelined/ \
	--cache_dir cache/ \
	--save_steps 10000 \
	--num_of_heads 12 \
	--pruning_lr 0.5 \
	--annealing \
	--initial_temperature 1000 \
	--final_temperature 1e-8 \
	--cooldown_steps 8333 

Michel et al.

The baseline of Michel et al. (2019) can be reproduced with:

export TASK_NAME=MNLI
python run_michel.py \
	--data_dir data/$TASK_NAME/ \
	--model_name_or_path output/ \
	--task_name $TASK_NAME \
	--output_dir michel_output/ \
	--cache_dir cache/ \
	--tokenizer_name bert-base-uncased \
	--exact_pruning \
	--num_of_heads 12

where num_of_heads is the number of heads to prune (mask) at each step.
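
A smaller step yields a finer accuracy-vs-sparsity curve at the cost of more evaluation passes. A hypothetical variant (only --num_of_heads and --output_dir differ from the command above; michel_output_fine/ is an assumed directory name):

export TASK_NAME=MNLI
# prune 6 heads per step instead of 12 for a finer-grained curve
python run_michel.py \
    --data_dir data/$TASK_NAME/ \
    --model_name_or_path output/ \
    --task_name $TASK_NAME \
    --output_dir michel_output_fine/ \
    --cache_dir cache/ \
    --tokenizer_name bert-base-uncased \
    --exact_pruning \
    --num_of_heads 6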

Enc-Dec on IWSLT

Install the fairseq library (adapted from the original Fairseq):

cd fairseq
pip install -e .

Download and prepare the IWSLT dataset:

cd examples/pruning
./prepare-iwslt.sh

Preprocess the dataset:

export TEXT=iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/ \
    --workers 20

Joint Pruning

Joint DSP

export SAVE_DIR=dsp_checkpoints/

python run_dsp.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --save-dir $SAVE_DIR/ \
    --num-of-heads 4 \
    --pruning-lr 0.2 \
    --joint-pruning \
    --annealing \
    --initial-temperature 0.1 \
    --final-temperature 1e-8 \
    --cooldown-steps 15000 

python generate_dsp.py \
	data-bin \
	-s de -t en \
	--path $SAVE_DIR/checkpoint_best.pt \
	--quiet \
	--batch-size 32 --beam 5 --remove-bpe

where num-of-heads is the number of attention heads to keep unpruned; pruning-lr, initial-temperature, final-temperature, and cooldown-steps are the DSP hyperparameters.
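
As with BERT, you can sweep the head budget to trace BLEU against sparsity. The hypothetical loop below varies only --num-of-heads and --save-dir (the budgets and directory names are assumed); generate from each $SAVE_DIR/checkpoint_best.pt as shown above.

# sweep the head budget; all other flags match the joint DSP command above
for K in 2 4 8; do
    SAVE_DIR=dsp_checkpoints/heads_${K}/
    python run_dsp.py \
        data-bin/ \
        --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
        --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
        --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
        --dropout 0.3 --weight-decay 0.0001 \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
        --max-tokens 4096 \
        --max-epoch 60 \
        --eval-bleu \
        --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
        --eval-bleu-detok moses \
        --eval-bleu-remove-bpe \
        --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
        --save-dir $SAVE_DIR \
        --num-of-heads $K \
        --pruning-lr 0.2 \
        --joint-pruning \
        --annealing \
        --initial-temperature 0.1 \
        --final-temperature 1e-8 \
        --cooldown-steps 15000
done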

You can also use the straight-through estimator:

export SAVE_DIR=dsp_checkpoints/

python run_dsp.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --save-dir $SAVE_DIR/ \
    --num-of-heads 4 \
    --pruning-lr 0.2 \
    --joint-pruning \
    --use-ste 

python generate_dsp.py \
	data-bin \
	-s de -t en \
	--path $SAVE_DIR/checkpoint_best.pt \
	--quiet \
	--batch-size 32 --beam 5 --remove-bpe

Voita et al.

The baseline of Voita et al. (2019) can be reproduced with:

export SAVE_DIR=voita_checkpoints/

python run_voita.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --keep-best-checkpoints 1 \
    --l0-penalty 200 \
    --save-dir $SAVE_DIR/

python generate_voita.py \
	data-bin \
	-s de -t en \
	--path $SAVE_DIR/checkpoint_last.pt \
	--quiet \
	--batch-size 32 --beam 5 --remove-bpe

where l0-penalty indirectly controls the number of remaining heads.

Pipelined Pruning

We first need to fine-tune the model on IWSLT:

fairseq-train \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --save-dir checkpoints/ 

Pipelined DSP

export SAVE_DIR=dsp_checkpoints/pipelined/

python run_dsp.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --reset-optimizer \
    --restore-file checkpoints/checkpoint_last.pt \
    --max-epoch 61 \
    --save-dir $SAVE_DIR/ \
    --num-of-heads 4 \
    --pruning-lr 0.2 \
    --annealing \
    --initial-temperature 0.1 \
    --final-temperature 1e-8 \
    --cooldown-steps 250 

python generate_dsp.py \
	data-bin \
	-s de -t en \
	--path $SAVE_DIR/checkpoint_last.pt \
	--quiet \
	--batch-size 32 --beam 5 --remove-bpe

Michel et al.

The baseline of Michel et al. (2019) can be reproduced with:

python run_michel.py \
	data-bin \
	-s de -t en \
	-a transformer_iwslt_de_en \
	--restore-file checkpoints/checkpoint_last.pt \
	--optimizer adam --adam-betas '(0.9,0.98)' \
	--reset-optimizer \
	--batch-size 32 --beam 5 --remove-bpe \
	--num-of-heads 4 \
	--exact-pruning 

where num-of-heads is the number of heads to prune (mask) at each step.
