Commit
Add Wav2Vec2 base model (#4513)
Summary:
Pull Request resolved: #4513

As titled.

Reviewed By: zonglinpengmeta

Differential Revision: D60619295

fbshipit-source-id: 00fd48029bc2413cf2a4a1453c80bbf65d29c57f
mcremon-meta authored and facebook-github-bot committed Aug 2, 2024
1 parent 1090bcd commit 4483bb6
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions examples/cadence/models/wav2vec2.py
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting a Wav2Vec2 base model to flatbuffer

import logging

from executorch.backends.cadence.aot.ops_registrations import * # noqa

import torch

from executorch.backends.cadence.aot.export_example import export_model
from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def main() -> None:
    # The wrapper is needed to avoid tracing issues with the optional second
    # argument (lengths) of Wav2Vec2Model.forward(), which also returns a
    # (features, lengths) tuple; the wrapper keeps only the features.
    class Wav2Vec2ModelWrapper(torch.nn.Module):
        def __init__(self, model: Wav2Vec2Model):
            super().__init__()
            self.model = model

        def forward(self, x):
            # Drop the optional lengths output; return only the features.
            out, _ = self.model(x)
            return out

    # Build a Wav2Vec2 model with base-sized hyperparameters (random weights).
    _model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.1,
        aux_num_out=None,
    )
    _model.eval()

    model = Wav2Vec2ModelWrapper(_model)
    model.eval()

    # Test input: a batch of one random waveform of 1680 samples.
    audio_len = 1680
    example_inputs = (torch.rand(1, audio_len),)

    export_model(model, example_inputs)


if __name__ == "__main__":
    main()
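
For reference, a quick eager-mode sanity check of the same configuration (a sketch, not part of this commit; it assumes torchaudio is installed and uses its default conv feature extractor, whose total stride is 320 samples with a receptive field of 400, so a 1680-sample input yields (1680 - 400) // 320 + 1 = 5 frames):

import torch
from torchaudio.models.wav2vec2.model import wav2vec2_model

# Same base-sized configuration as in the script above (random weights).
model = wav2vec2_model(
    extractor_mode="layer_norm",
    extractor_conv_layer_config=None,
    extractor_conv_bias=False,
    encoder_embed_dim=768,
    encoder_projection_dropout=0.1,
    encoder_pos_conv_kernel=128,
    encoder_pos_conv_groups=16,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_attention_dropout=0.1,
    encoder_ff_interm_features=3072,
    encoder_ff_interm_dropout=0.0,
    encoder_dropout=0.1,
    encoder_layer_norm_first=False,
    encoder_layer_drop=0.1,
    aux_num_out=None,
)
model.eval()

with torch.no_grad():
    features, _ = model(torch.rand(1, 1680))

# (1680 - 400) // 320 + 1 = 5 frames of encoder_embed_dim features.
print(features.shape)  # torch.Size([1, 5, 768])

Running the script itself (python examples/cadence/models/wav2vec2.py) hands the wrapped model to export_model, which performs the flatbuffer serialization (see executorch/backends/cadence/aot/export_example.py).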
