Opensource code for Deep Transformer with Latent Depth #2703
Closed
Commits (changes shown from 11 of 14 commits):
3dcc0c6  Opensource code for Deep Transformer with Latent Depth
bdfc36d  address comments and nits
80e3b97  move latent depth code to examples/latent_depth/src and added test
8b0b454  fix previous commit
5095f69  move latent depth code to examples/latent_depth/src and added test
86c9351  add set_lang_idx in valid
c7c802e  fix for test failure in generate
efcd087  move everythin to examples/
ff3d67f  fix
3673b01  fix
72ac8f2  fix test config
65c0545  add header
de0aceb  fix test import
fe8e08b  lint
@@ -0,0 +1,77 @@
# Deep Transformers with Latent Depth (Li et al., 2020)

[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).

## Introduction

We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
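Concretely, each layer's (non-residual) block is scaled by a gate sampled from the learned selection posterior of the current language pair. A rough sketch in our own notation (not taken verbatim from the paper), mirroring the `residual_connection` overrides in the encoder/decoder code added by this PR:

```latex
% Rough sketch of the per-layer gating (notation ours, not from the paper):
% layer l's block f_l is scaled by a gate z_l drawn from the learned
% layer-selection posterior q_phi of the current language pair.
x_{l+1} = x_l + z_l \cdot f_l(x_l), \qquad
z_l \sim q_{\phi}\!\left(z_l \mid \text{language pair}\right), \quad z_l \in [0, 1]
```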
## Training a multilingual model with latent depth

Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>

python fairseq_cli/train.py ${databin_dir} \
  --user-dir examples/latent_depth/src \
  --lang-pairs "${lang_pairs_str}" \
  --arch multilingual_transformer_iwslt_de_en \
  --task multilingual_translation_latent_depth \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --share-encoders \
  --share-decoders \
  --decoder-langtok \
  --share-decoder-input-output-embed \
  --dropout 0.3 --attention-dropout 0.3 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
  --max-tokens 4096 --update-freq 1 \
  --lr 0.0015 \
  --clip-norm 1.0 \
  --seed 2 \
  --ddp-backend=no_c10d \
  --encoder-layers 12 \
  --decoder-layers 24 \
  --decoder-latent-layer \
  --sparsity-weight 0.1 \
  --anneal-updates 5000 \
  --soft-update 500 \
  --target-layers 12 \
  --share-weight 0.1
```
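Of the flags above, the latent-depth-specific ones are `--decoder-latent-layer` (enable layer selection in the decoder), `--encoder-layers`/`--decoder-layers` (the pool of layers to select from), `--target-layers` (target expected number of selected decoder layers), `--sparsity-weight`, `--anneal-updates`, and `--soft-update` (weight and annealing schedule of the KL/sparsity terms), and `--share-weight` (encourages sharing selected layers across language pairs); the rest are standard fairseq multilingual-translation options. This reading follows from how these arguments are used in the loss module included in this PR.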
## Inference command

```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>

python fairseq_cli/generate.py ${databin_dir} \
  --path ${model_path} \
  --task multilingual_translation_latent_depth \
  --decoder-latent-layer \
  --lang-pairs "${lang_pairs_str}" \
  -s ${src_lang} -t ${tgt_lang} \
  --gen-subset $gen_data \
  --scoring sacrebleu \
  --remove-bpe 'sentencepiece' \
  --lenpen 1.0 \
  --beam 5 \
  --decoder-langtok \
  --max-tokens 4096
```

## Citation
```bibtex
@article{li2020deep,
  title={Deep Transformers with Latent Depth},
  author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
  journal={arXiv preprint arXiv:2009.13102},
  year={2020}
}
```
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .models import latent_multilingual_transformer  # noqa
from .modules import latent_layers  # noqa
from .loss import latent_depth  # noqa
from . import multilingual_translation_latent_depth  # noqa
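These imports are what expose the new components to fairseq: loading the package via `--user-dir examples/latent_depth/src` runs the `@register_model`/`@register_model_architecture` decorators in the model file below and, presumably, analogous registrations in the imported task and loss modules, which is what makes `--task multilingual_translation_latent_depth` and the latent-depth flags available on the command line.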
Empty file.
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import math
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (
                torch.log(samples + eps) - math.log(0.5)
            )).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (samples * (
                torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
        # anneal the KL weight linearly, capped at args.sparsity_weight
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
                update_num > (self.args.soft_update + self.args.anneal_updates)):
            # anneal sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # compute per-layer utilization averaged over languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encourage sharing layers across languages
                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
            if self.args.target_layers > 0:
                # compute the expected number of selected layers
                expected_layers = sum(layer_utilization)
                # L2 loss w.r.t. the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
        return batch_loss
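For readers skimming the diff, here is a minimal standalone sketch of how these two losses could be exercised with dummy inputs. The tensor shapes, the `argparse` fields, and the import path are our assumptions, inferred from how `forward()` indexes its arguments; the real samples come from the `LayerSelect` module in `modules/latent_layers.py`, which is not shown in this diff.

```python
# Hypothetical usage sketch of LatentLayersKLLoss / LatentLayersSparsityLoss.
# Shapes, argparse fields, and import path are assumptions, not a documented API.
import argparse

import torch

from examples.latent_depth.src.loss.latent_depth import (  # assumes the fairseq repo root is on sys.path
    LatentLayersKLLoss,
    LatentLayersSparsityLoss,
)

args = argparse.Namespace(
    prior="uniform",        # or "agged_posterior"
    sparsity_weight=0.1,
    share_weight=0.1,
    soft_update=500,
    anneal_updates=5000,
    target_layers=12,
)

num_langs, num_layers = 8, 24
# Assume one 1-D tensor of per-layer gate values in (0, 1) per language pair.
layer_samples = [torch.rand(num_layers) for _ in range(num_langs)]

kl_loss_fn = LatentLayersKLLoss(args)
sparsity_loss_fn = LatentLayersSparsityLoss(args)

update_num, sample_size = 10000, 4096
kl = kl_loss_fn(layer_samples, lang_idx=0, update_num=update_num, sample_size=sample_size)
sparsity = sparsity_loss_fn(layer_samples, update_num=update_num, sample_size=sample_size)
print(float(kl), float(sparsity))
```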
Empty file.
examples/latent_depth/src/models/latent_multilingual_transformer.py (52 additions, 0 deletions)
@@ -0,0 +1,52 @@
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    base_architecture,
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.models.multilingual_transformer import (
    MultilingualTransformerModel,
    base_multilingual_architecture,
)

from .latent_transformer import (
    LatentTransformerEncoder,
    LatentTransformerDecoder,
)


@register_model('latent_multilingual_transformer')
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """Train Transformer models for multiple language pairs simultaneously.

    TODO
    """
    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 24)
    args.share_encoders = getattr(args, 'share_encoders', True)
    args.share_decoders = getattr(args, 'share_decoders', True)
    args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
    args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)

    base_architecture(args)
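Our reading of the code above (not a statement from the PR): `_get_module_class` swaps in the latent encoder/decoder only when `encoder_latent_layer`/`decoder_latent_layer` is set on the parsed arguments, and otherwise falls back to the stock `TransformerEncoder`/`TransformerDecoder`, so the same registered model works with or without latent depth; `num_logits=len(langs)` appears to give each language pair its own set of layer-selection logits.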
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from torch import Tensor

from ..modules.latent_layers import LayerSelect


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """
    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_encoder_layer(args, idx)
            for idx in range(args.encoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        # Draw this batch's layer-selection samples for the current language pair
        # before running the standard encoder forward pass.
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer with each (non-residual) block weighted by Bernoulli or
    Gumbel-Sigmoid samples.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
    """
    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer with each (non-residual) block weighted by Bernoulli or
    Gumbel-Sigmoid samples.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(
        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)
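The classes above delegate the actual sampling to `LayerSelect` (in `modules/latent_layers.py`, not shown in this diff). As a rough illustration of the control flow they assume, here is a hypothetical stub of ours, not the real module: `sample(lang_idx)` is called once per forward pass, and each layer then multiplies its block output by its gate inside `residual_connection`.

```python
# Hypothetical stub (ours) illustrating the sample-then-gate pattern used by
# LatentTransformerEncoder/Decoder; the real LayerSelect lives in
# examples/latent_depth/src/modules/latent_layers.py and differs in detail.
import torch
import torch.nn as nn


class StubLayerSelect(nn.Module):
    def __init__(self, num_layers, num_logits):
        super().__init__()
        # One selection logit per (language pair, layer).
        self.logits = nn.Parameter(torch.zeros(num_logits, num_layers))
        self.samples = None

    def sample(self, lang_idx):
        # Draw per-layer gates for this language pair (plain sigmoid here;
        # the real module uses Bernoulli / Gumbel-Sigmoid sampling).
        self.samples = torch.sigmoid(self.logits[lang_idx])

    def forward(self, idx):
        return self.samples[idx]


layer_select = StubLayerSelect(num_layers=24, num_logits=8)
layer_select.sample(lang_idx=3)            # done once at the top of forward()
x, residual = torch.randn(10, 512), torch.randn(10, 512)
gated = residual + x * layer_select(5)     # what residual_connection() computes in layer 5
```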
Empty file.
Review comment: add copyright header