From 2152d7986c4e727f840437371ee2dad686ad350b Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Tue, 11 Jun 2024 11:37:49 -0700
Subject: [PATCH] Cadence - Add RNNT joiner from torchaudio (#3920)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3920

As titled.

Reviewed By: tarun292

Differential Revision: D58271588

fbshipit-source-id: e1e1d33a5a28ec2c0147b79b6e48cf59bdd01b7e
---
 examples/cadence/models/rnnt_joiner.py | 65 ++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)
 create mode 100644 examples/cadence/models/rnnt_joiner.py

diff --git a/examples/cadence/models/rnnt_joiner.py b/examples/cadence/models/rnnt_joiner.py
new file mode 100644
index 0000000000..b6aaa7c91b
--- /dev/null
+++ b/examples/cadence/models/rnnt_joiner.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script for exporting an RNN-T joiner (from torchaudio) to flatbuffer.
+
+import logging
+
+import torch
+
+from executorch.backends.cadence.aot.ops_registrations import *  # noqa
+
+from executorch.backends.cadence.aot.export_example import export_model
+
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+if __name__ == "__main__":
+
+    class Joiner(torch.nn.Module):
+        def __init__(
+            self, input_dim: int, output_dim: int, activation: str = "relu"
+        ) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
+            if activation == "relu":
+                # pyre-fixme[4]: Attribute must be annotated.
+                self.activation = torch.nn.ReLU()
+            elif activation == "tanh":
+                self.activation = torch.nn.Tanh()
+            else:
+                raise ValueError(f"Unsupported activation {activation}")
+
+        def forward(
+            self,
+            source_encodings: torch.Tensor,
+            target_encodings: torch.Tensor,
+        ) -> torch.Tensor:
+            # Broadcast-add the encoder and predictor outputs to form the
+            # (batch, source length, target length, dim) joint lattice.
+            joint_encodings = (
+                source_encodings.unsqueeze(2).contiguous()
+                + target_encodings.unsqueeze(1).contiguous()
+            )
+            activation_out = self.activation(joint_encodings)
+            output = self.linear(activation_out)
+            return output
+
+    # Joiner
+    model = Joiner(256, 128)
+
+    # Get dummy joiner inputs
+    source_encodings = torch.randn(1, 25, 256)
+    target_encodings = torch.randn(1, 10, 256)
+
+    example_inputs = (
+        source_encodings,
+        target_encodings,
+    )
+
+    export_model(model, example_inputs)
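
Note on the joiner's shape bookkeeping: the broadcast add in forward() pairs
every encoder frame with every predictor token, which is the core of the RNN-T
joint network. The sketch below is a minimal standalone check, not part of the
patch: it reuses the 256-dim inputs and 128-dim output from the example above,
and applies torch.relu directly in place of the module's default ReLU.

    import torch

    source = torch.randn(1, 25, 256)  # (batch, T, dim): encoder output frames
    target = torch.randn(1, 10, 256)  # (batch, U, dim): predictor output tokens

    # unsqueeze(2) -> (1, 25, 1, 256); unsqueeze(1) -> (1, 1, 10, 256).
    # Broadcasting the sum yields one joint vector per (frame, token) pair.
    joint = source.unsqueeze(2) + target.unsqueeze(1)
    print(joint.shape)  # torch.Size([1, 25, 10, 256])

    # The linear layer then maps each joint vector to output_dim logits,
    # matching Joiner(256, 128) in the patch.
    linear = torch.nn.Linear(256, 128)
    out = linear(torch.relu(joint))
    print(out.shape)  # torch.Size([1, 25, 10, 128])

This also explains the dummy inputs passed to export_model: the (1, 25, 256)
and (1, 10, 256) tensors trace the model with a 25-frame by 10-token lattice.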