Open-source code for Deep Transformer with Latent Depth #2703

Closed
wants to merge 14 commits into from
2 changes: 2 additions & 0 deletions README.md
@@ -43,6 +43,7 @@ We provide reference implementations of various sequence modeling papers:
- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
- [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
- [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
- [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
@@ -54,6 +55,7 @@ We provide reference implementations of various sequence modeling papers:

### What's New:

- October 2020: [Added Deep Transformers with Latent Depth code](examples/latent_depth/README.md)
- September 2020: [Added Linformer code](examples/linformer/README.md)
- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
- August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
76 changes: 76 additions & 0 deletions examples/latent_depth/README.md
@@ -0,0 +1,76 @@
# Deep Transformers with Latent Depth (Li et al., 2020)

[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102)

## Introduction

We present a probabilistic framework that automatically learns which layer(s) to use by learning posterior distributions over layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation, with a different layer-selection posterior for each language pair.
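
In a nutshell, each layer's residual branch is gated by a sample from a Gumbel-Sigmoid whose logit is specific to that layer and to the current language pair, so the sigmoid of the logits is the learned layer-selection posterior. Below is a condensed sketch of that idea; the names are illustrative, and the actual implementation is in `fairseq/modules/latent_layers.py` and `fairseq/models/latent_transformer.py` in this PR.
```python
import torch

def gumbel_sigmoid(logits, tau=5.0, hard=True):
    # z ~ Gumbel-Sigmoid(logits): the difference of two Gumbel(0, 1) noises pushed
    # through a sigmoid; with hard=True the sample is binarized but gradients
    # still flow through the soft value (straight-through estimator).
    g1 = -torch.empty_like(logits).exponential_().log()
    g2 = -torch.empty_like(logits).exponential_().log()
    z_soft = torch.sigmoid((logits + g1 - g2) / tau)
    if not hard:
        return z_soft
    z_hard = (z_soft > 0.5).float()
    return z_hard - z_soft.detach() + z_soft

# One logit per (language pair, layer): sigmoid(layer_logits[k, l]) is the learned
# posterior probability that language pair k uses decoder layer l.
num_pairs, num_layers = 8, 24
layer_logits = torch.nn.Parameter(torch.zeros(num_pairs, num_layers))

def gated_residual(x, residual, sublayer, pair_idx, layer_idx):
    # The layer's residual branch is weighted by the sampled gate; a hard zero
    # skips the sublayer entirely for this language pair.
    z = gumbel_sigmoid(layer_logits[pair_idx, layer_idx])
    return residual + z * sublayer(x)
```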

## Training a multilingual model with latent depth

We use the same (preprocessed) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS). Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>

fairseq-train ${databin_dir} \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=no_c10d \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
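
Of the flags above, `--decoder-latent-layer` is what activates `LatentTransformerDecoder` for the shared decoder (see the change to `fairseq/models/multilingual_transformer.py` in this PR); the module's `add_args` also defines `--soft-select` and `--sampling-tau` to use soft rather than hard (straight-through) samples and to set the Gumbel-Sigmoid temperature. The remaining latent-depth flags (`--sparsity-weight`, `--anneal-updates`, `--soft-update`, `--target-layers`, `--share-weight`) control the sparsity and weight-sharing regularization terms from the paper and their annealing schedule.
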
## Inference command

```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>

fairseq-generate ${databin_dir} \
--path ${model_path} \
--task multilingual_translation \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```
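
After training, the learned layer-selection posterior can be read directly off the checkpoint. A minimal sketch, assuming the usual fairseq checkpoint layout (model state dict under the `"model"` key) and a placeholder checkpoint path; the parameter-name suffix comes from `LayerSelect` in this PR, but the exact prefixes depend on how the multilingual model wraps its decoders, so we simply search for them:
```python
import torch

ckpt = torch.load("checkpoint_best.pt", map_location="cpu")
state = ckpt["model"]

for name, logits in state.items():
    if name.endswith("layer_select.layer_logits"):
        # Each row holds the logits for one language pair; the sigmoid gives the
        # posterior probability of keeping each layer.
        probs = torch.sigmoid(logits)
        for pair_idx, row in enumerate(probs):
            kept = int((row > 0.5).sum())
            print(f"{name} [logit {pair_idx}]: {kept}/{row.numel()} layers selected")
```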


## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```
132 changes: 132 additions & 0 deletions fairseq/models/latent_transformer.py
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder, TransformerModel
from fairseq.modules import LayerSelect, TransformerEncoderLayer, TransformerDecoderLayer
from torch import Tensor


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """
    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_encoder_layer(args, idx)
            for idx in range(args.encoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer with each (non-residual) block weighted by samples from a
    Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
    """
    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer with each (non-residual) block weighted by samples from a
    Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).

    """
    def __init__(
        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)

14 changes: 12 additions & 2 deletions fairseq/models/multilingual_transformer.py
@@ -18,6 +18,10 @@
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.models.latent_transformer import (
    LatentTransformerEncoder,
    LatentTransformerDecoder,
)


@register_model('multilingual_transformer')
@@ -136,7 +140,10 @@ def get_encoder(lang):
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
                    )
                lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
                if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                    lang_encoders[lang] = LatentTransformerEncoder(args, task.dicts[lang], encoder_embed_tokens, num_logits=len(src_langs))
                else:
                    lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
            return lang_encoders[lang]

def get_decoder(lang):
@@ -147,7 +154,10 @@ def get_decoder(lang):
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path
                    )
                lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens)
                if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                    lang_decoders[lang] = LatentTransformerDecoder(args, task.dicts[lang], decoder_embed_tokens, num_logits=len(tgt_langs))
                else:
                    lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens)
            return lang_decoders[lang]

# shared encoders/decoders (if applicable)
2 changes: 2 additions & 0 deletions fairseq/modules/__init__.py
@@ -18,6 +18,7 @@
from .grad_multiply import GradMultiply
from .gumbel_vector_quantizer import GumbelVectorQuantizer
from .kmeans_vector_quantizer import KmeansVectorQuantizer
from .latent_layers import LayerSelect
from .layer_drop import LayerDropModuleList
from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
@@ -56,6 +57,7 @@
    'KmeansVectorQuantizer',
    'LayerDropModuleList',
    'LayerNorm',
    'LayerSelect',
    'LearnedPositionalEmbedding',
    'LightweightConv1dTBC',
    'LightweightConv',
73 changes: 73 additions & 0 deletions fairseq/modules/latent_layers.py
@@ -0,0 +1,73 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class LayerSelect(nn.Module):
    """Compute samples from a Gumbel-Sigmoid distribution, used either as (soft)
    weights or as (hard) selections of the residual connections.

    https://arxiv.org/abs/2009.13102
    """
    def __init__(self, num_layers, num_logits, args):
        super(LayerSelect, self).__init__()
        self.args = args
        self.layer_logits = torch.nn.Parameter(
            torch.Tensor(num_logits, num_layers),
            requires_grad=True,
        )
        self.hard_select = not (hasattr(args, "soft_select") and args.soft_select)
        self.tau = getattr(args, "sampling_tau", 5)
        self.detach_grad = False
        self.layer_samples = [None] * num_logits

    @staticmethod
    def add_args(parser):
        parser.add_argument(
            '--soft-select',
            action='store_true',
            help='use soft samples in training and inference'
        )
        parser.add_argument('--sampling-tau', type=float, help='sampling temperature')

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. The logits are parameters
        learnt independently of each other.

        Args:
            logit_idx: The index of the logit parameters used for sampling.
        """
        assert logit_idx is not None
        self.samples = self._gumbel_sigmoid(
            self.layer_logits[logit_idx, :].detach()
            if self.detach_grad
            else self.layer_logits[logit_idx, :],
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
        )
        self.layer_samples[logit_idx] = self.samples

    def forward(self, i):
        sample = self.samples[i]
        return sample

    def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5):
        # ~Gumbel(0,1)
        gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        # Difference of two gumbels because we apply a sigmoid
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
        if hard:
            # Straight through.
            y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).masked_fill(y_soft > threshold, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            # Reparametrization trick.
            ret = y_soft
        return ret
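
For reference, a minimal, self-contained usage sketch of `LayerSelect` as defined above (it assumes fairseq is installed with this change; the `Namespace` stands in for fairseq's parsed arguments, and the zero initialization is only illustrative since `__init__` leaves `layer_logits` uninitialized):
```python
import torch
from argparse import Namespace

from fairseq.modules import LayerSelect  # exported by this PR

args = Namespace(soft_select=False, sampling_tau=5.0)
layer_select = LayerSelect(num_layers=24, num_logits=8, args=args)
torch.nn.init.zeros_(layer_select.layer_logits)  # illustrative init only

layer_select.sample(3)  # draw gates for language-pair index 3
gates = [layer_select(i) for i in range(24)]  # per-layer 0/1 straight-through samples
print(sum(g.item() for g in gates), "of 24 layers selected in this draw")
```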