Skip to content

Commit

Permalink
Add Wav2Vec2 Adapter Weights to Flax (huggingface#15566)
Browse files Browse the repository at this point in the history
* Add Wav2Vec2 Adapter Weights to Flax

* Suggested changes
  • Loading branch information
sanchit-gandhi authored Feb 9, 2022
1 parent 1f60bc4 commit 9e00566
Showing 1 changed file with 95 additions and 3 deletions.
98 changes: 95 additions & 3 deletions src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,73 @@ def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, te
return codevectors, perplexity


class FlaxWav2Vec2Adapter(nn.Module):
    """Adapter head: an optional projection followed by the adapter layer stack.

    When ``config.output_hidden_size`` differs from ``config.hidden_size`` the
    input is passed through a dense projection to ``output_hidden_size`` and a
    layer norm before it reaches the adapter layers; otherwise it is fed to the
    adapter layers directly.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config

        if cfg.output_hidden_size == cfg.hidden_size:
            # Feature dimensions already agree, so no projection is created.
            self.proj = None
            self.proj_layer_norm = None
        else:
            self.proj = nn.Dense(
                cfg.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)

        self.layers = FlaxWav2Vec2AdapterLayersCollection(cfg, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the optional projection, then the adapter layers.

        ``deterministic`` is accepted but not used by this module.
        """
        has_projection = not (self.proj is None or self.proj_layer_norm is None)
        if has_projection:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        return self.layers(hidden_states)


class FlaxWav2Vec2AdapterLayer(nn.Module):
    """Single adapter layer: a strided 1D convolution followed by a GLU.

    The convolution emits ``2 * config.output_hidden_size`` features, uses
    ``config.adapter_kernel_size`` / ``config.adapter_stride`` and pads one
    position on each side; the gated linear unit is then applied over axis 2.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        weight_init = jax.nn.initializers.normal(cfg.initializer_range)

        self.conv = nn.Conv(
            features=2 * cfg.output_hidden_size,
            kernel_size=(cfg.adapter_kernel_size,),
            strides=(cfg.adapter_stride,),
            padding=((1, 1),),
            kernel_init=weight_init,
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        """Convolve ``hidden_states`` and gate the result with a GLU over axis 2."""
        conv_output = self.conv(hidden_states)
        return nn.glu(conv_output, axis=2)


class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    """Sequential stack of ``config.num_adapter_layers`` adapter layers.

    Each layer is named by its index ("0", "1", ...) and the layers are applied
    in that order.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        adapter_layers = []
        for index in range(self.config.num_adapter_layers):
            adapter_layers.append(
                FlaxWav2Vec2AdapterLayer(self.config, name=str(index), dtype=self.dtype)
            )
        self.layers = adapter_layers

    def __call__(self, hidden_states):
        """Feed ``hidden_states`` through every adapter layer in turn."""
        output = hidden_states
        for adapter_layer in self.layers:
            output = adapter_layer(output)

        return output


class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
Expand Down Expand Up @@ -840,7 +907,9 @@ def __call__(
rngs=rngs,
)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
    self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
    """Return the output length(s) computed by the wrapped module.

    Thin wrapper that delegates to
    ``self.module._get_feat_extract_output_lengths``.

    Args:
        input_lengths: Input length(s), as an int or an array of ints.
        add_adapter: Whether the adapter layers should be included in the
            length computation. ``None`` is passed through unchanged so the
            module can apply its own default.

    Returns:
        Whatever the module's ``_get_feat_extract_output_lengths`` returns for
        the given arguments.
    """
    # Forward `add_adapter` explicitly: without it an explicit caller override
    # would be dropped and the module would always use its own default.
    return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)


Expand All @@ -860,6 +929,8 @@ def setup(self):
else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

def __call__(
self,
input_values,
Expand Down Expand Up @@ -905,6 +976,9 @@ def __call__(

hidden_states = encoder_outputs[0]

if self.adapter is not None:
hidden_states = self.adapter(hidden_states)

if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]

Expand All @@ -915,11 +989,15 @@ def __call__(
attentions=encoder_outputs.attentions,
)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
Expand All @@ -928,6 +1006,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

return input_lengths


Expand Down Expand Up @@ -1021,11 +1103,17 @@ def __call__(

return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self,
input_lengths: Union[jnp.ndarray, int],
add_adapter: Optional[bool] = None,
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
Expand All @@ -1034,6 +1122,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

return input_lengths


Expand Down

0 comments on commit 9e00566

Please sign in to comment.