Opensource code for Deep Transformer with Latent Depth #2703

Closed
wants to merge 14 commits
2 changes: 2 additions & 0 deletions README.md
@@ -43,6 +43,7 @@ We provide reference implementations of various sequence modeling papers:
- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
- [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
- [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
- [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
@@ -54,6 +55,7 @@ We provide reference implementations of various sequence modeling papers:

### What's New:

- October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
- September 2020: [Added Linformer code](examples/linformer/README.md)
- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
- August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
77 changes: 77 additions & 0 deletions examples/latent_depth/README.md
@@ -0,0 +1,77 @@
# Deep Transformers with Latent Depth (Li et al., 2020)

[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102)

## Introduction

We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
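
At a high level, every Transformer layer's residual block is gated by a per-language, per-layer sample from a relaxed Bernoulli (Gumbel-Sigmoid) distribution whose logits are learned jointly with the model, so layers whose selection posterior collapses toward zero can effectively be skipped. Below is a minimal sketch of that gating idea only; it is not the `LayerSelect` module in this PR, and all names and shapes are illustrative:

```python
import torch

def gumbel_sigmoid(logits, tau=1.0):
    # Relaxed Bernoulli sample in (0, 1): add logistic noise, then apply a
    # temperature-scaled sigmoid. Lower tau pushes samples toward {0, 1}.
    u = torch.rand_like(logits)
    logistic_noise = torch.log(u) - torch.log(1.0 - u)
    return torch.sigmoid((logits + logistic_noise) / tau)

# One learnable logit per (language pair, layer); e.g. 8 pairs, 24 decoder layers.
layer_logits = torch.nn.Parameter(torch.zeros(8, 24))

lang_idx = 3                                 # language pair of the current batch
z = gumbel_sigmoid(layer_logits[lang_idx])   # soft layer-selection samples, shape (24,)

# Inside each layer, the block output is scaled by its sample before the
# residual connection, i.e. residual + z[layer_idx] * x
# (cf. residual_connection() in latent_transformer.py below).
```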

## Training a multilingual model with latent depth

Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>

python fairseq_cli/train.py ${databin_dir} \
--user-dir examples/latent_depth/src \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation_latent_depth \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=no_c10d \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
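
A note on the regularization flags above: in the loss code added by this PR, the KL term's weight ramps up linearly starting at `--soft-update` and is capped at `--sparsity-weight` after `--anneal-updates` steps, while the sharing and target-layer penalties in `LatentLayersSparsityLoss` only become active once the update count exceeds `soft_update + anneal_updates`. A small sketch of that schedule with the values above (the helper name is illustrative):

```python
# Mirrors the kl_weight computation in latent_depth.py, using the flag
# values from the command above.
sparsity_weight, soft_update, anneal_updates = 0.1, 500, 5000

def kl_weight(update_num):
    # ramps linearly after soft_update, capped at sparsity_weight
    return min(sparsity_weight,
               (update_num - soft_update) * sparsity_weight / anneal_updates)

print(kl_weight(500))    # 0.0
print(kl_weight(3000))   # 0.05
print(kl_weight(6000))   # 0.1 (capped)
```
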
## Inference command

```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>

python fairseq_cli/generate.py ${databin_dir} \
--path ${model_path} \
--task multilingual_translation_latent_depth \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```
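Note that `--task multilingual_translation_latent_depth`, `--decoder-latent-layer`, `--lang-pairs`, and `--decoder-langtok` should match the settings used during training, so that the checkpoint's latent-depth decoder is rebuilt correctly for generation.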


## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```
9 changes: 9 additions & 0 deletions examples/latent_depth/src/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
from .loss import latent_depth # noqa
from . import multilingual_translation_latent_depth # noqa
Empty file.
86 changes: 86 additions & 0 deletions examples/latent_depth/src/loss/latent_depth.py
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (
                torch.log(samples + eps) - math.log(0.5)
            )).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (samples * (
                torch.log(samples + eps) - torch.log(normed_agg_q + eps)
            )).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
        # anneal the KL weight linearly after soft_update, capped at sparsity_weight
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update)
            * self.args.sparsity_weight
            / self.args.anneal_updates,
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
                update_num > (self.args.soft_update + self.args.anneal_updates)):
            # anneal the sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight
                    / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # compute layer-utilization ratios across languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encourage sharing layers across languages
                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
            if self.args.target_layers > 0:
                # compute the expected number of selected layers
                expected_layers = sum(layer_utilization)
                # L2 loss w.r.t. the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
        return batch_loss
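
As a rough, self-contained illustration of how these two losses are invoked (the argparse namespace and sample tensors below are made up for the example; the real values come from the training task and the model's `LayerSelect` samples):

```python
from argparse import Namespace
import torch

args = Namespace(prior="uniform", sparsity_weight=0.1, share_weight=0.1,
                 soft_update=500, anneal_updates=5000, target_layers=12)

# One selection-probability vector per language pair (here: 2 pairs, 24 layers).
layer_samples = [torch.rand(24), torch.rand(24)]

kl_loss_fn = LatentLayersKLLoss(args)
sparsity_loss_fn = LatentLayersSparsityLoss(args)

kl_term = kl_loss_fn(layer_samples, lang_idx=0, update_num=6000, sample_size=4096)
if sparsity_loss_fn.is_valid(update_num=6000):
    sparsity_term = sparsity_loss_fn(layer_samples, update_num=6000, sample_size=4096)
```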


Empty file.
@@ -0,0 +1,52 @@
Contributor (review comment): add copyright header

from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    base_architecture,
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.models.multilingual_transformer import (
    MultilingualTransformerModel,
    base_multilingual_architecture,
)

from .latent_transformer import (
    LatentTransformerEncoder,
    LatentTransformerDecoder,
)

@register_model('latent_multilingual_transformer')
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """Train Transformer models for multiple language pairs simultaneously.

    TODO
    """

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 24)
    args.share_encoders = getattr(args, 'share_encoders', True)
    args.share_decoders = getattr(args, 'share_decoders', True)
    args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
    args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)

    base_architecture(args)
130 changes: 130 additions & 0 deletions examples/latent_depth/src/models/latent_transformer.py
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from ..modules.latent_layers import LayerSelect
from torch import Tensor


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """

    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_encoder_layer(args, idx)
            for idx in range(args.encoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer with each (non-residual) block weighted by a Bernoulli or
    Gumbel-Sigmoid sample drawn by ``layer_select``.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
    """

    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer with each (non-residual) block weighted by a Bernoulli or
    Gumbel-Sigmoid sample drawn by ``layer_select``.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)
Empty file.