Add the SEW and SEW-D speech models #13962

Merged · 19 commits · Oct 15, 2021
Changes from 1 commit
Model conversion and updated no-mask tests
anton-l committed Oct 15, 2021
commit 08307e32f3bb974c53526b0452d637e208f1da34
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -454,7 +454,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SEWD | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
2 changes: 0 additions & 2 deletions src/transformers/models/sew/__init__.py
@@ -21,7 +21,6 @@


_import_structure = {
".wav2vec2.feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
"configuration_sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"],
}

@@ -34,7 +33,6 @@
]

if TYPE_CHECKING:
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig

if is_torch_available():
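Note: `Wav2Vec2FeatureExtractor` is no longer re-exported from the SEW (and, further down, SEW-D) init modules. A minimal sketch of how downstream code would import it after this change, assuming the extractor remains exposed through the top-level `transformers` package as it is for wav2vec2:

```python
# Wav2Vec2FeatureExtractor now comes from its wav2vec2 home (or the top-level
# package) rather than transformers.models.sew.
from transformers import SEWConfig, Wav2Vec2FeatureExtractor

config = SEWConfig()  # default SEW configuration
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000)
```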
@@ -18,9 +18,9 @@
import argparse
import json
import os
import numpy as np

import fairseq
import numpy as np
import torch
from fairseq.data import Dictionary

@@ -171,7 +171,7 @@ def convert_config(model):
fs_config = model.cfg

config.activation_dropout = fs_config.activation_dropout
config.apply_spec_augment = (fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0)
config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
config.attention_dropout = fs_config.attention_dropout
config.conv_bias = fs_config.conv_bias
conv_layers = eval(fs_config.conv_feature_layers)
@@ -224,6 +224,15 @@ def convert_sew_checkpoint(
config = convert_config(model[0])
model = model[0].eval()

return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)

if is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
@@ -250,20 +259,13 @@
word_delimiter_token="|",
do_lower_case=False,
)
return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained(pytorch_dump_folder_path)

hf_model = SEWForCTC(config)
else:
hf_model = SEWModel(config)
feature_extractor.save_pretrained(pytorch_dump_folder_path)

recursively_load_weights(model, hf_model, is_finetuned)

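The net effect of the two hunks above in the SEW checkpoint conversion script: the feature extractor is now built once, before the `is_finetuned` branch, and saved even for checkpoints that are not fine-tuned (previously only the fine-tuned path saved a processor). A condensed sketch of the resulting control flow, reusing the names from the diff (argument parsing and weight loading omitted):

```python
from transformers import SEWForCTC, SEWModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

# config, tokenizer, model, is_finetuned, pytorch_dump_folder_path and
# recursively_load_weights all come from the surrounding convert_sew_checkpoint
# function; this is a sketch, not the full script.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0,
    do_normalize=True,
    return_attention_mask=(config.feat_extract_norm == "layer"),
)

if is_finetuned:
    # fine-tuned checkpoints get a full processor: feature extractor + CTC tokenizer
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
    processor.save_pretrained(pytorch_dump_folder_path)
    hf_model = SEWForCTC(config)
else:
    # pretrained-only checkpoints now save the bare feature extractor alongside the model
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
    hf_model = SEWModel(config)

recursively_load_weights(model, hf_model, is_finetuned)
```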
13 changes: 1 addition & 12 deletions src/transformers/models/sew/modeling_sew.py
@@ -768,11 +768,6 @@ def __init__(self, config: SEWConfig):
self.config = config
self.feature_extractor = SEWFeatureExtractor(config)
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
if config.conv_dim[-1] != config.hidden_size:
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
else:
self.feature_projection = None

self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

@@ -876,13 +871,7 @@ def forward(
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

if self.feature_projection is not None:
hidden_states = self.feature_projection(extract_features)
hidden_states = self.feature_dropout(hidden_states)
else:
hidden_states = extract_features

hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
hidden_states = self._mask_hidden_states(extract_features, mask_time_indices=mask_time_indices)

encoder_outputs = self.encoder(
hidden_states,
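With the conditional projection gone, `SEWModel` implicitly assumes `config.conv_dim[-1] == config.hidden_size` and feeds the convolutional features straight into SpecAugment-style masking and the encoder. A minimal sketch of the affected part of `forward`, following the diff and assuming the usual wav2vec2-style front end (conv features transposed and layer-normalized first):

```python
# Sketch of SEWModel.forward after this commit (outputs and edge cases omitted).
extract_features = self.feature_extractor(input_values)        # (batch, conv_dim[-1], time)
extract_features = self.layer_norm(extract_features.transpose(1, 2))

if attention_mask is not None:
    # compute reduced attention_mask corresponding to feature vectors
    attention_mask = self._get_feature_vector_attention_mask(
        extract_features.shape[1], attention_mask
    )

# no nn.Linear projection any more: conv_dim[-1] is expected to equal hidden_size,
# so the normalized conv features are masked and encoded directly
hidden_states = self._mask_hidden_states(extract_features, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask)
```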
2 changes: 0 additions & 2 deletions src/transformers/models/sew_d/__init__.py
@@ -21,7 +21,6 @@


_import_structure = {
".wav2vec2.feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
"configuration_sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"],
}

@@ -34,7 +33,6 @@
]

if TYPE_CHECKING:
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig

if is_torch_available():
@@ -174,7 +174,7 @@ def convert_config(model):
fs_config = model.cfg

config.activation_dropout = fs_config.activation_dropout
config.apply_spec_augment = (fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0)
config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0
config.attention_dropout = fs_config.attention_dropout
config.conv_bias = fs_config.conv_bias
conv_layers = eval(fs_config.conv_feature_layers)
@@ -235,6 +235,15 @@ def convert_sew_checkpoint(
config = convert_config(model[0])
model = model[0].eval()

return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)

if is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
@@ -261,20 +270,13 @@
word_delimiter_token="|",
do_lower_case=False,
)
return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained(pytorch_dump_folder_path)

hf_model = SEWDForCTC(config)
else:
hf_model = SEWDModel(config)
feature_extractor.save_pretrained(pytorch_dump_folder_path)

recursively_load_weights(model, hf_model, is_finetuned)

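The SEW-D conversion script receives the same restructuring, so a bare feature extractor is written next to `SEWDModel` even when no tokenizer is available. A quick, hypothetical sanity check that the saved artifacts load back (the local folder path is a placeholder):

```python
from transformers import SEWDModel, Wav2Vec2FeatureExtractor

# reload whatever convert_sew_checkpoint wrote to pytorch_dump_folder_path
model = SEWDModel.from_pretrained("path/to/converted-sew-d")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("path/to/converted-sew-d")
```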
48 changes: 21 additions & 27 deletions src/transformers/models/sew_d/modeling_sew_d.py
@@ -557,8 +557,8 @@ def get_context(self):
return self.drop_prob


# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm, hidden_dropout_prob->activation_dropout
class DebertaV2SelfOutput(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaV2->SEWD, DebertaLayerNorm->LayerNorm, hidden_dropout_prob->activation_dropout
class SEWDSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -807,12 +807,12 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
return score


# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
class DebertaV2Attention(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->SEWD
class SEWDAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = DisentangledSelfAttention(config)
self.output = DebertaV2SelfOutput(config)
self.output = SEWDSelfOutput(config)
self.config = config

def forward(
@@ -844,8 +844,8 @@ def forward(
return attention_output


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
class DebertaV2Intermediate(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->SEWD
class SEWDIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -861,7 +861,7 @@ def forward(self, hidden_states):


# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm, hidden_dropout_prob->activation_dropout
class DebertaV2Output(nn.Module):
class SEWDOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -876,13 +876,13 @@ def forward(self, hidden_states, input_tensor):
return hidden_states


# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
class DebertaV2Layer(nn.Module):
# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->SEWD
class SEWDLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = DebertaV2Attention(config)
self.intermediate = DebertaV2Intermediate(config)
self.output = DebertaV2Output(config)
self.attention = SEWDAttention(config)
self.intermediate = SEWDIntermediate(config)
self.output = SEWDOutput(config)

def forward(
self,
@@ -948,14 +948,14 @@ def forward(self, hidden_states, residual_states, input_mask):
return output_states


# Copied from transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Encoder
class DebertaV2Encoder(nn.Module):
# Copied from transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Encoder with DebertaV2->SEWD
class SEWDTransformerEncoder(nn.Module):
"""Modified BertEncoder with relative position bias support"""

def __init__(self, config):
super().__init__()

self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([SEWDLayer(config) for _ in range(config.num_hidden_layers)])
self.relative_attention = getattr(config, "relative_attention", False)

if self.relative_attention:
@@ -1073,7 +1073,7 @@ def __init__(self, config):
self.config = config
self.pos_conv_embed = SEWDPositionalConvEmbedding(config)
self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
self.encoder = DebertaV2Encoder(config)
self.encoder = SEWDTransformerEncoder(config)
self.upsample = SEWDUpsampling(config)
self.gradient_checkpointing = False

@@ -1263,11 +1263,8 @@ def __init__(self, config: SEWDConfig):
self.config = config
self.feature_extractor = SEWDFeatureExtractor(config)
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
if config.conv_dim[-1] != config.hidden_size:
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
else:
self.feature_projection = None
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)

self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

@@ -1371,11 +1368,8 @@ def forward(
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

if self.feature_projection is not None:
hidden_states = self.feature_projection(extract_features)
hidden_states = self.feature_dropout(hidden_states)
else:
hidden_states = extract_features
hidden_states = self.feature_projection(extract_features)
hidden_states = self.feature_dropout(hidden_states)

hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

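Beyond renaming the `# Copied from` DeBERTa classes to their SEWD-prefixed counterparts, this file drops the conditional and always projects the conv features, since `conv_dim[-1]` need not match `hidden_size` for SEW-D. A short sketch of that path in `SEWDModel.forward`, contrasting with the SEW change above (names as in the diff):

```python
# SEW-D keeps the projection unconditionally: conv features -> hidden_size,
# then dropout, masking, and the DeBERTa-v2-style encoder.
hidden_states = self.feature_projection(extract_features)   # (batch, time, hidden_size)
hidden_states = self.feature_dropout(hidden_states)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask)
```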
33 changes: 16 additions & 17 deletions tests/test_modeling_sew.py
@@ -368,7 +368,7 @@ def _mock_init_weights(self, module):

@slow
def test_model_from_pretrained(self):
model = SEWModel.from_pretrained("../sew/checkpoints/sew-tiny-converted")
model = SEWModel.from_pretrained("anton-l/sew-tiny-100k")
self.assertIsNotNone(model)


@@ -424,27 +424,26 @@ def map_to_array(batch):
return ds["speech"][:num_samples]

def test_inference_pretrained_batched(self):
model = SEWModel.from_pretrained("../sew/checkpoints/sew-tiny-converted").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("../sew/checkpoints/sew-tiny-converted")
model = SEWModel.from_pretrained("anton-l/sew-tiny-100k").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/sew-tiny-100k")

input_speech = self._load_datasamples(2)

inputs = processor(input_speech, return_tensors="pt", padding=True)

input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)

with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
outputs = model(input_values).last_hidden_state

# expected outputs taken from the original SEW implementation
expected_outputs_first = torch.tensor(
[
[
[0.0373, 0.5522, 0.6186, -0.1686],
[-0.1658, 0.6962, 0.6115, -0.0962],
[0.0499, 0.8240, 0.4881, -0.1585],
[-0.1697, 0.8095, 0.4967, -0.1022],
[0.1509, 0.5372, 0.3061, -0.1694],
[-0.1700, 0.5764, 0.2753, -0.1299],
[0.1281, 0.7949, 0.2342, -0.1624],
[-0.1627, 0.6710, 0.2215, -0.1317],
],
[
[0.0408, 1.4355, 0.8605, -0.0968],
@@ -458,10 +457,10 @@
expected_outputs_last = torch.tensor(
[
[
[0.5935, -0.0649, -0.0974, -0.0709],
[0.5365, -0.1670, -0.1518, -0.1130],
[0.5935, -0.0649, -0.0974, -0.0709],
[0.5365, -0.1670, -0.1518, -0.1130],
[1.3379, -0.1450, -0.1500, -0.0515],
[0.8364, -0.1680, -0.1248, -0.0689],
[1.2791, -0.1507, -0.1523, -0.0564],
[0.8208, -0.1690, -0.1199, -0.0751],
],
[
[0.6959, -0.0861, -0.1235, -0.0861],
@@ -472,17 +471,17 @@
],
device=torch_device,
)
expected_output_sum = 66396.2656
expected_output_sum = 62146.7422

self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 1)
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 2)

@tooslow
def test_inference_ctc_batched(self):
# TODO: enable this test once the finetuned models are available
model = SEWForCTC.from_pretrained("../sew/checkpoints/sew-tiny-converted").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("../sew/checkpoints/sew-tiny-converted", do_lower_case=True)
model = SEWForCTC.from_pretrained("anton-l/sew-tiny-100k-ft-100h").to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("anton-l/sew-tiny-100k-ft-100h", do_lower_case=True)

input_speech = self._load_datasamples(2)

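The integration tests now pull the hub checkpoints (`anton-l/sew-tiny-100k`, and `anton-l/sew-tiny-100k-ft-100h` once the fine-tuned models are published) and run the model without an attention mask, matching the commit title. A rough standalone reproduction of the pretrained inference check, with random audio standing in for the LibriSpeech samples used by `_load_datasamples`:

```python
import numpy as np
import torch
from transformers import SEWModel, Wav2Vec2FeatureExtractor

model = SEWModel.from_pretrained("anton-l/sew-tiny-100k").eval()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/sew-tiny-100k")

# two fake one-second clips at 16 kHz instead of the test's LibriSpeech samples
speech = [np.random.randn(16000).astype(np.float32) for _ in range(2)]
inputs = feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)

with torch.no_grad():
    # called without attention_mask, as in the updated "no-mask" test
    last_hidden_state = model(inputs.input_values).last_hidden_state

print(last_hidden_state.shape)  # (batch, reduced_time, hidden_size)
```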