Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new meta w2v2-conformer BERT-like model #28165

Merged
merged 77 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
96bd9b7
first commit
ylacombe Dec 19, 2023
d45e075
correct default value non causal
ylacombe Dec 19, 2023
35193ab
update config and modeling code
ylacombe Dec 20, 2023
730ee47
update converting checkpoint
ylacombe Dec 20, 2023
ad75260
clean modeling and fix tests
ylacombe Dec 20, 2023
95c6542
make style
ylacombe Dec 20, 2023
7f52d83
Merge branch 'huggingface:main' into add-w2v2-shaw
ylacombe Dec 20, 2023
175f4d3
add new config parameters to docstring
ylacombe Dec 20, 2023
40c3357
fix copied from statements
ylacombe Dec 20, 2023
29e20ea
Apply suggestions from code review
ylacombe Dec 21, 2023
07b34ff
make position_embeddings_type docstrings clearer
ylacombe Dec 21, 2023
c1f6ac4
clean converting script
ylacombe Dec 21, 2023
9792399
remove function not used
ylacombe Dec 21, 2023
d267a7c
clean modeling file
ylacombe Dec 21, 2023
d69d850
apply suggestion for test file + add convert script to not_doctested
ylacombe Dec 21, 2023
ca23f4e
modify tests according to review - cleaner logic and more tests
ylacombe Jan 1, 2024
c7550cd
Apply nit suggestions from code review
ylacombe Jan 1, 2024
1b132c6
add checker of valid position embeddings type
ylacombe Jan 1, 2024
e179499
instantiate new layer norm layer with the right eps
ylacombe Jan 1, 2024
4848a26
fix freeze_feature_encoder since it can be None in some cases
ylacombe Jan 1, 2024
0a8e2a9
add test same output in convert script
ylacombe Jan 1, 2024
5d77490
restore wav2vec2conformer and add new model
ylacombe Jan 3, 2024
b5b0bd2
create processor and FE + clean
ylacombe Jan 3, 2024
aecc8fe
add new model code
ylacombe Jan 4, 2024
1d78100
fix convert script and set default config parameters
ylacombe Jan 4, 2024
f98e5ab
correct model id paths
ylacombe Jan 4, 2024
b8a386a
make style
ylacombe Jan 4, 2024
580fa3e
make fix-copies and cleaning files
ylacombe Jan 4, 2024
a847b08
fix copied from statements
ylacombe Jan 4, 2024
0ededce
complete .md and fixe copies
ylacombe Jan 4, 2024
c0e92e2
clean convert script argument defaults
ylacombe Jan 4, 2024
43dd379
fix config parameters docstrings
ylacombe Jan 4, 2024
61162a7
fix config docstring
ylacombe Jan 4, 2024
ae70620
Merge branch 'huggingface:main' into add-w2v2-shaw
ylacombe Jan 4, 2024
ad0b1a9
add copied from and enrich FE tests
ylacombe Jan 4, 2024
1d37cbb
fix copied from and repo-consistency
ylacombe Jan 4, 2024
73bf286
add autotokenizer
ylacombe Jan 5, 2024
ca5ec51
make test input length shorter and change docstring code
ylacombe Jan 5, 2024
dbc53ab
fix docstrings and copied from
ylacombe Jan 5, 2024
b18ad62
add add_adapter to ASR training example
ylacombe Jan 5, 2024
7d85679
make testing of adapters more robust
ylacombe Jan 8, 2024
2ea024d
adapt to multi adapter layers
ylacombe Jan 8, 2024
bcdd067
refactor input_values->input_features and remove w2v2-bert feature ex…
ylacombe Jan 10, 2024
a17d2fa
remove pretraining model
ylacombe Jan 10, 2024
2e6063d
remove depreciated features and useless lines
ylacombe Jan 10, 2024
78ef614
add copied from and ignore statements to modeling tests
ylacombe Jan 10, 2024
8a83831
remove pretraining model #2
ylacombe Jan 10, 2024
22c552f
change import in convert script
ylacombe Jan 10, 2024
9115615
change default in convert script
ylacombe Jan 10, 2024
2b57096
update readme and remove useless line
ylacombe Jan 10, 2024
7db6c40
Update tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py
ylacombe Jan 15, 2024
87876a5
refactor BERT to Bert for consistency
ylacombe Jan 15, 2024
aaf64c1
remove useless ignore copy statement
ylacombe Jan 15, 2024
39b4ef1
add persistent to buffer in rotary
ylacombe Jan 15, 2024
7540037
add eps in LayerNorm init and remove copied from
ylacombe Jan 15, 2024
14f5fad
add adapter activation parameters and add copied from statements
ylacombe Jan 15, 2024
85e8865
Fix copied statements and add unitest.skip reasons
ylacombe Jan 15, 2024
2efa94b
add copied statement in test_processor
ylacombe Jan 15, 2024
42ee5ae
refactor processor
ylacombe Jan 15, 2024
f347b14
make style
ylacombe Jan 15, 2024
632272a
replace numpy random by torch rand
ylacombe Jan 15, 2024
0876dc9
remove expected output CTC
ylacombe Jan 15, 2024
8aa7bdf
Merge branch 'huggingface:main' into add-w2v2-shaw
ylacombe Jan 15, 2024
2f4cff3
improve converting script with processor class
ylacombe Jan 15, 2024
1856710
Apply suggestions from code review
ylacombe Jan 15, 2024
a0e5ca2
remove gumbel class
ylacombe Jan 15, 2024
b5e7d70
remove tests related to previously deleted class
ylacombe Jan 15, 2024
fc6dfc2
Update src/transformers/models/wav2vec2_bert/configuration_wav2vec2_b…
ylacombe Jan 15, 2024
ef1ad91
correct typos
ylacombe Jan 15, 2024
f3d00e6
remove uused parameters
ylacombe Jan 15, 2024
849f3f5
update processor to takes both text and audio
ylacombe Jan 16, 2024
7266710
update checkpoints
ylacombe Jan 16, 2024
07de02e
update expected output and add ctc expected output
ylacombe Jan 16, 2024
37dd941
add label_attention_mask
ylacombe Jan 17, 2024
8b37745
replace pt with np in processor tests
ylacombe Jan 17, 2024
7160906
fix typo
ylacombe Jan 17, 2024
c6ec6b1
revert to behaviour with labels_attention_mask
ylacombe Jan 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,6 @@ def __init__(self, config, use_position_embeddings=True):
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))

# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention.forward
ylacombe marked this conversation as resolved.
Show resolved Hide resolved
def forward(
self,
hidden_states: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,29 @@ class Wav2Vec2ConformerConfig(PretrainedConfig):
Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
if `add_adapter is True`.
position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
`None` no relative position embedding is applied.
Can be specified to `relative`, `relative_key` or `rotary` for relative, relative as defined by Shaw, or rotary position embeddings respectively. If left
ylacombe marked this conversation as resolved.
Show resolved Hide resolved
`None` no relative position embedding is applied. For more information on `"relative_key"`, please refer to [Self-Attention
with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
rotary_embedding_base (`int`, *optional*, defaults to 10000):
If `"rotary"` position embeddings are used, defines the size of the embedding base.
max_source_positions (`int`, *optional*, defaults to 5000):
if `"relative"` position embeddings are used, defines the maximum source input positions.
left_max_position_embeddings (`int`, *optional*, defaults to 64):
if `"_key"` position embeddings are used, defines the left clipping value for relative positions.
ylacombe marked this conversation as resolved.
Show resolved Hide resolved
right_max_position_embeddings (`int`, *optional*, defaults to 8):
if `"_key"` position embeddings are used, defines the right clipping value for relative positions.
ylacombe marked this conversation as resolved.
Show resolved Hide resolved
conv_depthwise_kernel_size (`int`, defaults to 31):
Kernel size of convolutional depthwise 1D layer in Conformer blocks.
conformer_conv_dropout (`float`, defaults to 0.1):
The dropout probability for all convolutional layers in Conformer blocks.
non_causal_depth_wise_conv (`bool`, defaults to `True`):
If `False`, use causal convolutional depthwise layers in Conformer blocks.
skip_feature_encoder (`bool`, defaults to `False`):
Whether to skip the feature encoder layers. Only relevant when using a feature extractor that computes spectrogram-like inputs instead of raw waveforms.
skip_encoder_layer_norm (`bool`, defaults to `False`):
Whether to skip the input layer norm of the encoder.
skip_pos_conv_embed (`bool`, defaults to `False`):
Whether to skip the positional layer convolutional embedding layer of the encoder.

Example:

Expand Down Expand Up @@ -270,8 +283,14 @@ def __init__(
position_embeddings_type="relative",
rotary_embedding_base=10000,
max_source_positions=5000,
left_max_position_embeddings=64,
right_max_position_embeddings=8,
conv_depthwise_kernel_size=31,
conformer_conv_dropout=0.1,
non_causal_depth_wise_conv=True,
skip_feature_encoder=False,
skip_encoder_layer_norm=False,
skip_pos_conv_embed=False,
**kwargs,
):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
Expand Down Expand Up @@ -302,6 +321,8 @@ def __init__(
self.max_source_positions = max_source_positions
self.position_embeddings_type = position_embeddings_type
self.rotary_embedding_base = rotary_embedding_base
self.left_max_position_embeddings = left_max_position_embeddings
self.right_max_position_embeddings = right_max_position_embeddings

if (
(len(self.conv_stride) != self.num_feat_extract_layers)
Expand All @@ -318,6 +339,7 @@ def __init__(
# Conformer-block related
self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
self.conformer_conv_dropout = conformer_conv_dropout
self.non_causal_depth_wise_conv = non_causal_depth_wise_conv

# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
Expand Down Expand Up @@ -358,6 +380,10 @@ def __init__(
self.tdnn_dilation = list(tdnn_dilation)
self.xvector_output_dim = xvector_output_dim

self.skip_feature_encoder = skip_feature_encoder
self.skip_encoder_layer_norm = skip_encoder_layer_norm
self.skip_pos_conv_embed = skip_pos_conv_embed

@property
def inputs_to_logits_ratio(self):
return functools.reduce(operator.mul, self.conv_stride, 1)
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2Conformer BERT checkpoint."""


import argparse

import torch
from seamless_communication.models.conformer_shaw import load_conformer_shaw_model

from transformers import (
AutoProcessor,
Wav2Vec2ConformerConfig,
Wav2Vec2ConformerForPreTraining,
logging,
)


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


wav2vec_convert_list = [
("encoder", "wav2vec2_conformer.encoder"),
("encoder_frontend.model_dim_proj", "feature_projection.projection"),
("encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"),
("encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"),
("encoder.inner.layers", "encoder.layers"),
("encoder.inner_layer_norm", "encoder.layer_norm"),
("encoder.adaptor_layers", "adapter.layers"),
("inner_proj", "intermediate_dense"),
("self_attn.output_proj", "self_attn.linear_out"),
("output_proj", "output_dense"),
("self_attn.k_proj", "self_attn.linear_k"),
("self_attn.v_proj", "self_attn.linear_v"),
("self_attn.q_proj", "self_attn.linear_q"),
("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"),
("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"),
("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"),
("self_attn.sdpa.r_proj", "self_attn.linear_pos"),
("conv.pointwise_conv1", "conv_module.pointwise_conv1"),
("conv.pointwise_conv2", "conv_module.pointwise_conv2"),
("conv.depthwise_conv", "conv_module.depthwise_conv"),
("conv.layer_norm", "conv_module.depthwise_layer_norm"),
("conv_layer_norm", "conv_module.layer_norm"),
("encoder.proj1", "intermediate_ffn.intermediate_dense"),
("encoder.proj2", "intermediate_ffn.output_dense"),
("encoder.layer_norm", "inner_layer_norm"),
("quantizer.entry_proj", "quantizer.weight_proj"),
("final_proj", "project_hid"),
("final_target_proj", "project_q"),
("masker.temporal_mask_embed", "wav2vec2_conformer.masked_spec_embed"),
("quantizer.entries", "quantizer.codevectors"),
]


def param_count(model):
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])


def _convert_model(
original_model,
hf_model,
convert_list,
):
state_dict = original_model.state_dict()

for k, v in list(state_dict.items()):
new_key = k
for old_layer_name, new_layer_name in convert_list:
if old_layer_name in new_key:
new_key = new_key.replace(old_layer_name, new_layer_name)

# must do it by hand
if ".layer_norm" in new_key and new_key.split(".layer_norm")[0][-1].isnumeric():
new_key = new_key.replace("layer_norm", "final_layer_norm")

state_dict[new_key] = state_dict.pop(k)

extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
extra_keys = set({k for k in extra_keys if "num_updates" not in k}) # filter unecessary param
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
hf_model.load_state_dict(state_dict, strict=False)
n_params = param_count(hf_model)

logger.info(f"model loaded: {round(n_params/1e6,1)}M params")

hf_model.eval()
del state_dict

return hf_model


@torch.no_grad()
def convert_wav2vec2_conformer_checkpoint(
checkpoint_path,
pytorch_dump_folder_path,
config_path=None,
repo_id=None,
process_path=None,
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = Wav2Vec2ConformerConfig.from_pretrained(config_path, hidden_act="swish")
else:
config = Wav2Vec2ConformerConfig(apply_spec_augment=False)

hf_wav2vec = Wav2Vec2ConformerForPreTraining(config)

model = load_conformer_shaw_model(checkpoint_path, dtype=torch.float32)
model.eval()

hf_wav2vec = _convert_model(model, hf_wav2vec, wav2vec_convert_list)

hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

if repo_id:
hf_wav2vec.push_to_hub(repo_id)

if process_path:
processor = AutoProcessor.from_pretrained(process_path)
ylacombe marked this conversation as resolved.
Show resolved Hide resolved

ylacombe marked this conversation as resolved.
Show resolved Hide resolved
processor.feature_extractor.padding_value = 1
processor.save_pretrained(pytorch_dump_folder_path)

if repo_id:
processor.push_to_hub(repo_id)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--checkpoint_path", default="conformer_shaw", type=str, help="Path to seamless communication checkpoint"
)
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument("--repo_id", default=None, type=str, help="Push to this repo id if precised.")
parser.add_argument("--process_path", default=None, type=str, help="Push to this repo id if precised.")
ylacombe marked this conversation as resolved.
Show resolved Hide resolved

args = parser.parse_args()
convert_wav2vec2_conformer_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.repo_id, args.process_path
)
Loading
Loading