diff --git a/examples/cadence/models/wav2vec2.py b/examples/cadence/models/wav2vec2.py new file mode 100644 index 0000000000..5db9ea2a6d --- /dev/null +++ b/examples/cadence/models/wav2vec2.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch + +from executorch.backends.cadence.aot.export_example import export_model +from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + # The wrapper is needed to avoid issues with the optional second arguments + # of Wav2Vec2Models. + class Wav2Vec2ModelWrapper(torch.nn.Module): + def __init__(self, model: Wav2Vec2Model): + super().__init__() + self.model = model + + def forward(self, x): + out, _ = self.model(x) + return out + + _model = wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=0.1, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=0.1, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=0.0, + encoder_dropout=0.1, + encoder_layer_norm_first=False, + encoder_layer_drop=0.1, + aux_num_out=None, + ) + _model.eval() + + model = Wav2Vec2ModelWrapper(_model) + model.eval() + + # test input + audio_len = 1680 + example_inputs = (torch.rand(1, audio_len),) + + export_model(model, example_inputs) + + +if __name__ == "__main__": + main()