
Add Moonshine to KerasHub #2093


Status: Open — wants to merge 57 commits into base branch master.

Changes from all commits (57 commits)
8037ed0
init: Add MoonshineBackbone files
harshaljanjani Feb 10, 2025
51a40b8
feat: Make backbone test suite more robust
harshaljanjani Feb 10, 2025
098781e
feat: Exactness to the original and robustness of test cases
harshaljanjani Feb 11, 2025
047de1f
fix: Support stacked encoder layers from original implementation
harshaljanjani Feb 11, 2025
885f77f
TODO: Fix layer names
harshaljanjani Feb 12, 2025
805a806
fix: Add __init__ file
harshaljanjani Feb 12, 2025
9f579c0
Merge branch 'master' into moonshine
harshaljanjani Feb 12, 2025
aebeac7
fix: Correct subclassing and make ops more robust
harshaljanjani Feb 12, 2025
60112d5
feat: Incorporate feedback for Moonshine
harshaljanjani Feb 16, 2025
10cff1e
refactor: Move super.build() calls to the beginning of build() functions
harshaljanjani Feb 18, 2025
8dac22f
fix: Resolve API issue and fix duplicate parameters in attention
harshaljanjani Feb 18, 2025
2bacaf2
init: Add MoonshineDecoderBlock files (TODO: MoonshineDecoder)
harshaljanjani Feb 19, 2025
3af8498
feat: Add MoonshineDecoder with questionable tolerance
harshaljanjani Feb 20, 2025
2a2fcb9
fix: Fix decoder numerics (TODO: serialization and tokenizer)
harshaljanjani Feb 21, 2025
e05d1ed
feat: Add Tokenizer and SentencePiece model files
harshaljanjani Feb 22, 2025
9130d2c
refactor: API modification and temporarily removed TestCase
harshaljanjani Feb 22, 2025
b4e1ae9
chore: Update HF params (TODO: Resolve numerics issue)
harshaljanjani Feb 24, 2025
524e052
fix: Refactor model components: improve documentation, fix numerical …
harshaljanjani Feb 27, 2025
6cd6cd9
fix: Update checkpoint paths to include base directory for encoder, p…
harshaljanjani Feb 27, 2025
1e330b8
refactor: Rename arguments for clarity in Moonshine layers and unit t…
harshaljanjani Feb 28, 2025
64bcd63
feat: Add decoder to MoonshineBackbone, enable mixed-precision traini…
harshaljanjani Feb 28, 2025
c8f82aa
feat: Revamp test suites, finalize MoonshineBackbone, and improve doc…
harshaljanjani Mar 3, 2025
706078f
test: Add unit tests for Moonshine layers including InvFreqInitialize…
harshaljanjani Mar 3, 2025
57ad858
fix: TensorFlow compatibility in MoonshineInvFreqInitializer test
harshaljanjani Mar 3, 2025
67d01a9
refactor: Remove MoonshinePreprocessor and update related tests and i…
harshaljanjani Mar 8, 2025
2c76289
clean up: Use ReversibleEmbedding and shorten the weights conversion …
harshaljanjani Mar 9, 2025
a32d292
refactor: Simplify input handling in MoonshineBackbone and remove unu…
harshaljanjani Mar 9, 2025
189d39e
refactor: Update MoonshineBackbone, remove testable components, and f…
harshaljanjani Mar 10, 2025
3e236bb
feat: Add padding mask support, make the logits() function for a trai…
harshaljanjani Mar 11, 2025
e993ead
feat: Add trainable conditional generation task model, fix nits
harshaljanjani Mar 14, 2025
34ea915
refactor: Reformat MoonshineForConditionalGeneration and MoonshineBac…
harshaljanjani Mar 14, 2025
5599073
Merge branch 'keras-team:master' into moonshine
harshaljanjani Mar 14, 2025
f57fcd1
may fix JAX (Keras 3.5) backend tests: Update input handling to use a…
harshaljanjani Mar 15, 2025
c9e4d76
cleanup: Merge MoonshineAttention into a single class, remove unneces…
harshaljanjani Mar 15, 2025
80b1d9d
may fix JAX (Keras 3.5) backend: Fix initializer error in MoonshineRo…
harshaljanjani Mar 16, 2025
efc1424
finalizing changes: Complete generate() API with caching speedup
harshaljanjani Mar 18, 2025
18c06ef
refactor: Apply BART-inspired structural changes and optimize generate()
harshaljanjani Mar 23, 2025
d719ca2
bug fix: Fix the build() method in MoonshineAudioConverter, thus reso…
harshaljanjani Mar 24, 2025
578c7d0
task: Complete rewrite of the generation strategy in one go (TODO: up…
harshaljanjani Mar 31, 2025
0705d58
feat: Update weights conversion script
harshaljanjani Mar 31, 2025
3224a28
revert: Leave comments in the code for next review and revert caching…
harshaljanjani Apr 2, 2025
f5541f4
fix nits: Add warnings and the missing decoder_attention_mask param; …
harshaljanjani Apr 3, 2025
63a457f
another refactor: Single caching strategy, easily integrable into Ker…
harshaljanjani Apr 6, 2025
f961a06
fix nits: Remove unused encoder packer init, re-enable MHA tests; the…
harshaljanjani Apr 9, 2025
4f53d78
hooraayyy: The tests are yet to be fixed, but the task model works on…
harshaljanjani Apr 12, 2025
17ec26e
TODO: Fix JAX backend
harshaljanjani Apr 14, 2025
c59607d
end of sprint: Complete JAX backend implementation
harshaljanjani Apr 15, 2025
1a50443
outdated docstring: Update call_decoder_with_cache() docstring
harshaljanjani Apr 16, 2025
c8b9f41
chore: sweeping up the pixie dust, literally nothing new here
harshaljanjani Apr 21, 2025
29ddf2b
design choice: Revert dynamic mode in MoonshineRotaryEmbedding
harshaljanjani Apr 23, 2025
fb88c98
Merge branch 'master' into moonshine
harshaljanjani Apr 23, 2025
dc6c7c0
ughh: Messed up the merge probably
harshaljanjani Apr 23, 2025
e872fde
chore: Unused params from debugging removed, check out new api_gen.py…
harshaljanjani Apr 28, 2025
0f1ab28
refactor: Address comments - (Dynamic shapes + JIT compile)
harshaljanjani Apr 30, 2025
6727a3d
hoorayyy: All tests passing across all backends
harshaljanjani May 2, 2025
41b9c17
docstring nit: Remove cache_mode parameter from MoonshineMultiHeadAtt…
harshaljanjani May 3, 2025
8fc2460
checkpoint conv: Load the model from the preset instead to verify sav…
harshaljanjani May 5, 2025
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -93,6 +93,9 @@
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
    MobileNetImageConverter as MobileNetImageConverter,
)
from keras_hub.src.models.moonshine.moonshine_audio_converter import (
    MoonshineAudioConverter as MoonshineAudioConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
    PaliGemmaImageConverter as PaliGemmaImageConverter,
)
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -357,6 +357,18 @@
from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
    MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
    MoonshineAudioToText as MoonshineAudioToText,
)
from keras_hub.src.models.moonshine.moonshine_backbone import (
    MoonshineBackbone as MoonshineBackbone,
)
from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
    MoonshineSeq2SeqLMPreprocessor as MoonshineSeq2SeqLMPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
    MoonshineTokenizer as MoonshineTokenizer,
)
from keras_hub.src.models.object_detector import (
    ObjectDetector as ImageObjectDetector,
)
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -55,6 +55,9 @@
from keras_hub.src.models.mistral.mistral_tokenizer import (
    MistralTokenizer as MistralTokenizer,
)
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
    MoonshineTokenizer as MoonshineTokenizer,
)
from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
    PaliGemmaTokenizer as PaliGemmaTokenizer,
Empty file.
278 changes: 278 additions & 0 deletions keras_hub/src/models/moonshine/moonshine_audio_converter.py
@@ -0,0 +1,278 @@
import keras

try:
    import tensorflow as tf
except ImportError:
    tf = None

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone


@keras_hub_export("keras_hub.layers.MoonshineAudioConverter")
class MoonshineAudioConverter(AudioConverter):
    """Moonshine audio preprocessing layer.

    This layer processes raw audio waveforms for the Moonshine ASR model.
    Audio is formatted as a batched tensor at a 16kHz sample rate and
    validated for length (0.1 to 64 seconds). The layer handles padding and
    optional normalization. It does not contain trainable weights.

    Args:
        sampling_rate: int, optional. The audio sampling rate in Hz. Defaults
            to 16,000.
        padding_value: float, optional. The value for padding. Defaults to
            0.0.
        do_normalize: bool, optional. Whether to normalize inputs. Defaults
            to False.
        **kwargs: Additional keyword arguments passed to the base
            AudioConverter class for customizing the underlying preprocessing
            behavior.

    Examples:
    ```python
    import keras
    from keras_hub.layers import MoonshineAudioConverter

    # Create a dummy audio input (1 second at 16kHz).
    dummy_audio = keras.ops.convert_to_tensor(
        [[0.1] * 16000],
        dtype="float32",
    )
    dummy_audio = keras.ops.expand_dims(dummy_audio, axis=-1)

    # Initialize the preprocessor.
    preprocessor = MoonshineAudioConverter(do_normalize=True)

    # Process the audio.
    processed_audio = preprocessor(dummy_audio)

    # Output shape.
    print(processed_audio.shape)  # Expected: (1, 16000, 1) or padded length
    ```
    """

    # References:
    # Defined and formulated based on the UsefulSensors implementation of
    # audio preprocessing logic (https://github.com/usefulsensors/moonshine/blob/main/moonshine/transcribe.py).

    backbone_cls = MoonshineBackbone

    def __init__(
        self,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._convert_input_args = False
        self._allow_non_tensor_positional_args = True
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.do_normalize = do_normalize

    def call(
        self,
        inputs,
        sampling_rate=None,
        padding=None,
        max_length=None,
        pad_to_multiple_of=None,
    ):
        # Validate sampling rate.
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            raise ValueError(
                f"Expected sampling_rate {self.sampling_rate}, got "
                f"{sampling_rate}"
            )

        # Ensure inputs are (batch_size, time_steps, 1).
        input_shape = keras.ops.shape(inputs)
        input_rank = len(input_shape)
        if input_rank == 2:
            processed_inputs = keras.ops.expand_dims(inputs, axis=-1)
        elif input_rank == 3:
            processed_inputs = inputs
        else:
            raise ValueError(
                "Inputs must be mono audio: (batch_size, time_steps, 1)"
            )

        # Get original length and validate duration.
        current_shape = keras.ops.shape(processed_inputs)
        original_length = current_shape[1]
        duration = (
            keras.ops.cast(original_length, keras.backend.floatx())
            / self.sampling_rate
        )
        # Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/transcribe.py#L20
        is_invalid_duration = keras.ops.logical_or(
            keras.ops.less(duration, 0.1), keras.ops.greater(duration, 64.0)
        )

        def print_warning_fn():
            import warnings

            warnings.warn(
                "Audio duration must be between 0.1 and 64 seconds. For "
                "transcribing longer segments, pre-segment your audio and "
                "provide shorter segments."
            )
            return keras.ops.convert_to_tensor(True, dtype="bool")

        is_tf_symbolic = (
            tf is not None
            and hasattr(processed_inputs, "graph")
            and hasattr(processed_inputs.graph, "as_graph_def")
        )
        use_tf_graph_ops = tf is not None and is_tf_symbolic
        if use_tf_graph_ops:
            _ = tf.cond(
                is_invalid_duration,
                print_warning_fn,
                lambda: keras.ops.convert_to_tensor(False, dtype="bool"),
            )
        else:
            if keras.ops.convert_to_numpy(is_invalid_duration):
                print_warning_fn()

        # Handle padding.
        if padding == "longest":
            target_length = original_length
            if pad_to_multiple_of:
                target_length = (
                    (target_length + pad_to_multiple_of - 1)
                    // pad_to_multiple_of
                ) * pad_to_multiple_of

            needs_padding = keras.ops.greater(target_length, original_length)

            def pad_fn():
                padding_amount = target_length - original_length
                paddings = [[0, 0], [0, padding_amount], [0, 0]]
                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
                    return tf.pad(
                        processed_inputs,
                        paddings,
                        mode="CONSTANT",
                        constant_values=float(self.padding_value),
                    )
                else:
                    return keras.ops.pad(
                        processed_inputs,
                        paddings,
                        mode="constant",
                        constant_values=self.padding_value,
                    )

            if use_tf_graph_ops:
                processed_inputs = tf.cond(
                    needs_padding, pad_fn, lambda: processed_inputs
                )
            else:
                processed_inputs = keras.ops.cond(
                    needs_padding, pad_fn, lambda: processed_inputs
                )

        elif padding == "max_length" and max_length is not None:
            target_length_const = max_length
            if pad_to_multiple_of:
                target_length_const = (
                    (target_length_const + pad_to_multiple_of - 1)
                    // pad_to_multiple_of
                ) * pad_to_multiple_of

            needs_padding = keras.ops.less(original_length, target_length_const)
            needs_truncating = keras.ops.greater(
                original_length, target_length_const
            )

            def pad_fn():
                padding_amount = target_length_const - original_length
                paddings = [[0, 0], [0, padding_amount], [0, 0]]
                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
                    return tf.pad(
                        processed_inputs,
                        paddings,
                        mode="CONSTANT",
                        constant_values=float(self.padding_value),
                    )
                else:
                    return keras.ops.pad(
                        processed_inputs,
                        paddings,
                        mode="constant",
                        constant_values=self.padding_value,
                    )

            def trunc_fn():
                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
                    return processed_inputs[:, :target_length_const, :]
                else:
                    return keras.ops.slice(
                        processed_inputs,
                        [0, 0, 0],
                        [-1, target_length_const, -1],
                    )

            if use_tf_graph_ops:
                processed_inputs = tf.cond(
                    needs_padding,
                    pad_fn,
                    lambda: tf.cond(
                        needs_truncating, trunc_fn, lambda: processed_inputs
                    ),
                )
            else:
                needs_padding = keras.ops.less(
                    original_length, target_length_const
                )
                needs_truncating = keras.ops.greater(
                    original_length, target_length_const
                )
                needs_padding_bool = keras.ops.convert_to_numpy(needs_padding)
                needs_truncating_bool = keras.ops.convert_to_numpy(
                    needs_truncating
                )

                if needs_padding_bool:
                    padding_amount = target_length_const - original_length
                    paddings = [[0, 0], [0, padding_amount], [0, 0]]
                    processed_inputs = keras.ops.pad(
                        processed_inputs,
                        paddings,
                        mode="constant",
                        constant_values=self.padding_value,
                    )
                elif needs_truncating_bool:
                    processed_inputs = processed_inputs[
                        :, :target_length_const, :
                    ]

        # Normalize if enabled.
        if self.do_normalize:
            mean = keras.ops.mean(processed_inputs, axis=1, keepdims=True)
            var = keras.ops.var(processed_inputs, axis=1, keepdims=True)
            processed_inputs = (processed_inputs - mean) / keras.ops.sqrt(
                var + 1e-7
            )

        return processed_inputs

    def compute_output_shape(self, input_shape):
        # [batch_size, time_steps] → [batch_size, time_steps, 1].
        if len(input_shape) == 2 or len(input_shape) == 3:
            return (input_shape[0], None, 1)
        else:
            raise ValueError("Input shape must be rank 2 or 3.")

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sampling_rate": self.sampling_rate,
                "padding_value": self.padding_value,
                "do_normalize": self.do_normalize,
            }
        )
        return config
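The `pad_to_multiple_of` round-up, the pad-or-truncate decision, and the per-sample normalization in `call()` are backend-agnostic arithmetic. A minimal plain-Python sketch of the same logic (a standalone illustration operating on a 1-D sample list, not the layer itself — the helper names here are hypothetical):

```python
import math

def round_up_to_multiple(length, multiple):
    # Same arithmetic as the layer's pad_to_multiple_of handling:
    # ceiling division, then scale back up to the nearest multiple.
    return ((length + multiple - 1) // multiple) * multiple

def pad_or_truncate(samples, max_length, pad_to_multiple_of=None, padding_value=0.0):
    # Mirrors the padding="max_length" branch for a single 1-D sample list.
    target = max_length
    if pad_to_multiple_of:
        target = round_up_to_multiple(target, pad_to_multiple_of)
    if len(samples) < target:
        return samples + [padding_value] * (target - len(samples))
    return samples[:target]

def normalize(samples, eps=1e-7):
    # Per-sample zero-mean / unit-variance, as with do_normalize=True.
    mean = sum(samples) / len(samples)
    var = sum((x - mean) ** 2 for x in samples) / len(samples)
    return [(x - mean) / math.sqrt(var + eps) for x in samples]

audio = [0.1] * 100
padded = pad_or_truncate(audio, max_length=150, pad_to_multiple_of=64)
print(len(padded))  # 192 (150 rounded up to the next multiple of 64)
```

The `(n + m - 1) // m * m` idiom avoids floating-point `math.ceil` while producing the same round-up, which is why the layer can apply it to symbolic tensor lengths as well as Python ints.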