attn head selection
Summary:
Add scripts for multihead attention selection in multilingual and multi-domain training from the following paper:
"Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling", NeurIPS 2021.

Reviewed By: yuntang

Differential Revision: D31802221

fbshipit-source-id: 8c69b89bda29e6857bd3af02979c07e1b5cf49f1
hygong-fb authored and facebook-github-bot committed Jan 19, 2022
1 parent a075481 commit a59cea5
Showing 8 changed files with 16 additions and 33 deletions.
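For readers who haven't seen the paper, the core idea is that each task (an input/output language or a domain) learns its own weights over the available attention heads, so heads can be shared or specialized across tasks. Below is a minimal, illustrative PyTorch sketch of that idea; the module and parameter names are made up for illustration and are not the code added in this commit.

```python
import torch
import torch.nn as nn


class TaskHeadGate(nn.Module):
    """Hypothetical per-task gate over attention heads (illustration only)."""

    def __init__(self, num_tasks: int, num_heads: int):
        super().__init__()
        # one learnable logit per (task, head) pair
        self.head_logits = nn.Parameter(torch.zeros(num_tasks, num_heads))

    def forward(self, task_id: torch.Tensor) -> torch.Tensor:
        # task_id: (batch,) long tensor of language/domain indices
        # returns gates in (0, 1) with shape (batch, num_heads)
        return torch.sigmoid(self.head_logits[task_id])


# usage: scale each head's attention output by its task-specific gate
gate = TaskHeadGate(num_tasks=4, num_heads=8)
attn_out = torch.randn(2, 8, 10, 64)          # (batch, heads, time, head_dim)
gates = gate(torch.tensor([0, 3]))            # (batch, heads)
attn_out = attn_out * gates[:, :, None, None]  # broadcast over time and head_dim
```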
13 changes: 5 additions & 8 deletions examples/attention_head_selection/README.md
@@ -152,13 +152,10 @@ fairseq-generate ${data_dir} \
 
 ## Citation
 ```bibtex
-@article{gong2021attn,
-  title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
-  author={Hongyu Gong and
-          Yun Tang and
-          Juan Miguel Pino and
-          Xian Li},
-  journal={arXiv preprint arXiv:2106.10840},
-  year={2021}
+@article{gong2021pay,
+  title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
+  author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
+  journal={arXiv preprint arXiv:2106.10840},
+  year={2021}
 }
 ```
@@ -12,6 +12,7 @@
 from fairseq.data import (
     ConcatDataset,
     Dictionary,
+    FairseqDataset,
     ResamplingDataset
 )
 from fairseq.data.audio.data_cfg import S2TDataConfig
@@ -190,9 +191,9 @@ def _from_tsv(
         bpe_tokenizer,
         n_frames_per_step,
         speaker_to_id,
-        src_lang_map,
-        tgt_lang_map,
-        domain_map
+        src_lang_map: Dict[str, int],
+        tgt_lang_map: Dict[str, int],
+        domain_map: Dict[str, int]
     ) -> SpeechToTextDatasetItemWithDomain:
         samples = cls._load_samples_from_tsv(
             root, split, src_lang_map,
@@ -215,11 +216,11 @@ def from_tsv(
         is_train_split: bool,
         epoch: int,
         seed: int,
-        n_frames_per_step: int,
-        speaker_to_id,
         src_lang_map: Dict[str, int],
         tgt_lang_map: Dict[str, int],
-        domain_map: Dict[str, int]
+        domain_map: Dict[str, int],
+        n_frames_per_step: int = 1,
+        speaker_to_id=None
     ) -> SpeechToTextDatasetWithDomain:
         datasets = [
             cls._from_tsv(
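One thing the reordering above buys: in Python, parameters with defaults must come after parameters without them, so giving n_frames_per_step and speaker_to_id defaults means moving them behind the required map arguments. A tiny standalone sketch of that rule (names reused for illustration only, not the real from_tsv):

```python
# Illustration only: mirrors the reordered signature, not the real from_tsv.
def from_tsv_sketch(domain_map, n_frames_per_step=1, speaker_to_id=None):
    return domain_map, n_frames_per_step, speaker_to_id


# callers can now omit the trailing arguments entirely
print(from_tsv_sketch({"talks": 0, "news": 1}))
# ({'talks': 0, 'news': 1}, 1, None)
```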
@@ -1,5 +1,3 @@
-
-
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
@@ -50,19 +50,14 @@ def add_args(parser):
             type=int,
             help="total number of encoder attention heads"
         )
-        # parser.add_argument(
-        #     "--encoder-tasks",
-        #     type=int,
-        #     help="the number of encoder tasks (input languages or input domains)"
-        # )
-        # decoder self attention
+        # decoder self attention selection
         parser.add_argument(
             "--decoder-self-attn-head-select",
             action="store_true",
             default=False,
             help="decoder self-attention head selection"
         )
-        # decoder-encoder attention
+        # decoder-encoder attention selection
         parser.add_argument(
             "--dec-enc-attn-head-select",
             action="store_true",
@@ -74,11 +69,6 @@ def add_args(parser):
             type=int,
             help="total number of decoder attention heads"
         )
-        # parser.add_argument(
-        #     "--decoder-tasks",
-        #     type=int,
-        #     help="the number of decoder tasks (output languages or output domains)"
-        # )
         # selection strategy
         parser.add_argument(
             "--attn-head-select-strategy",
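As a rough illustration of how these options surface downstream, argparse exposes each flag on the parsed namespace with dashes converted to underscores. The snippet below registers only the three flags visible in this hunk; the type of --attn-head-select-strategy is an assumption, and the real parser defines other options as well.

```python
import argparse

parser = argparse.ArgumentParser()
# flag names taken from the diff above; other options are omitted
parser.add_argument("--decoder-self-attn-head-select", action="store_true", default=False)
parser.add_argument("--dec-enc-attn-head-select", action="store_true", default=False)
parser.add_argument("--attn-head-select-strategy", type=str, default=None)  # type assumed

args = parser.parse_args(["--decoder-self-attn-head-select"])
assert args.decoder_self_attn_head_select is True   # dash -> underscore on the namespace
assert args.dec_enc_attn_head_select is False
```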
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from fairseq.utils import safe_getattr
-
 from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
 from ..modules.multihead_attention_selection import MultiheadAttentionSelection
 
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Dict, Optional, Tuple
-
 import torch
 from fairseq import utils
 from fairseq.modules.quant_noise import quant_noise
@@ -142,7 +142,6 @@ def multi_head_attention_forward(
     else:
         head_dim = embed_dim // num_heads
         assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
-
     if use_separate_proj_weight:
         # allow MHA to have different embedding dimensions when separate projection weights are used
         assert key.shape[:2] == value.shape[:2], \
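The assertion kept above simply enforces that embed_dim splits evenly across heads; a quick worked check:

```python
embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads           # 64
assert head_dim * num_heads == embed_dim    # passes: 64 * 8 == 512
# with num_heads = 7, head_dim would be 73 and 73 * 7 == 511 != 512,
# tripping the "not divisible" assertion shown in the diff
```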
@@ -76,10 +76,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
             is_train_split=is_train_split,
             epoch=epoch,
             seed=self.args.seed,
-            speaker_to_id=self.speaker_to_id,
             src_lang_map=self.src_lang_map,
             tgt_lang_map=self.tgt_lang_map,
-            domain_map=self.domain_map
+            domain_map=self.domain_map,
+            speaker_to_id=self.speaker_to_id
         )
 
     def build_model(self, args):
