
Add Moonshine to KerasHub #2093


Open
harshaljanjani wants to merge 57 commits into master

Conversation

Collaborator

@harshaljanjani harshaljanjani commented Feb 12, 2025

Adds Moonshine ASR model to KerasHub:

Closes #2083.

Collaborator

@divyashreepathihalli divyashreepathihalli left a comment


Thank you for the PR! I left some initial comments.
I would suggest following the format, structure, and naming conventions of the Whisper model here: https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/models/whisper

  • Add docstrings.
  • Convert the backbone to a functional model (see the sketch below).
  • Add a moonshine_audio_converter.py.
  • Add a numerics verification Colab to verify the implementation.
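
A minimal sketch of what "convert the backbone to a functional model" could look like, with illustrative layer names and dimensions (this is not the final Moonshine architecture; the placeholder layers stand in for Moonshine's encoder/decoder stacks):

import keras

def build_backbone(vocabulary_size=32768, hidden_dim=256, num_mels=80):
    # Functional-model backbone: explicit symbolic inputs and outputs.
    encoder_features = keras.Input(shape=(None, num_mels), name="encoder_features")
    decoder_token_ids = keras.Input(shape=(None,), dtype="int32", name="decoder_token_ids")
    encoded = keras.layers.Dense(hidden_dim, name="encoder_projection")(encoder_features)
    embedded = keras.layers.Embedding(
        vocabulary_size, hidden_dim, name="token_embedding"
    )(decoder_token_ids)
    # Placeholder cross-attention from decoder states to encoder states.
    decoded = keras.layers.MultiHeadAttention(
        num_heads=4, key_dim=hidden_dim // 4, name="cross_attention"
    )(embedded, encoded)
    return keras.Model(
        inputs={
            "encoder_features": encoder_features,
            "decoder_token_ids": decoder_token_ids,
        },
        outputs=decoded,
        name="moonshine_backbone",
    )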

@harshaljanjani
Collaborator Author

Will make the changes as soon as possible, thanks for the review!

@divyashreepathihalli
Collaborator

You will need to run shell/api_gen.sh and also shell/format.sh at the repo root to resolve the code formatting errors.

@harshaljanjani
Collaborator Author

Thanks for the review, made the changes! The issue regarding the build still persists.

@harshaljanjani harshaljanjani self-assigned this Feb 19, 2025
@harshaljanjani
Collaborator Author

Summary of Changes:

  1. Added MoonshineDecoderBlock (passes numeric checks; a few issues with the reversible embeddings keep me from integrating the whole decoder, but I'll try to fix that and get back).
  2. Made a testable component for the encoder, subclassed from keras.Model and separate from the MoonshineBackbone class; it's easier to test weight loading this way since the preprocessor, decoder, and encoder each have separate weight files (see the sketch below).
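
A hypothetical sketch of the "testable component" pattern in (2): wrap the encoder in its own keras.Model so its weight file can be loaded and checked in isolation (class and file names here are illustrative):

import keras

class EncoderForTesting(keras.Model):
    def __init__(self, hidden_dim=256, **kwargs):
        super().__init__(**kwargs)
        self.projection = keras.layers.Dense(hidden_dim)

    def call(self, features):
        return self.projection(features)

encoder = EncoderForTesting()
encoder(keras.ops.zeros((1, 10, 80)))  # Run once so the weights are created.
# encoder.load_weights("encoder.weights.h5")  # Hypothetical per-component file.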

@harshaljanjani
Collaborator Author

TODO:

  1. Verify the build methods, as the sanity checks for serialization don't pass even though the numerics are aligned (a minimal check of this kind is sketched below).
  2. Write weight conversion scripts.
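
A minimal example of the serialization sanity check meant in (1), applicable to any Keras model: round-trip the config, copy the weights, and compare outputs.

import numpy as np
import keras

def check_serialization(model, sample_inputs, atol=1e-6):
    # Rebuild the model from its config and transfer the weights.
    clone = model.__class__.from_config(model.get_config())
    clone(sample_inputs)  # Run once so the clone's weights are created.
    clone.set_weights(model.get_weights())
    # Outputs must match if build() and get_config() are implemented correctly.
    np.testing.assert_allclose(
        keras.ops.convert_to_numpy(model(sample_inputs)),
        keras.ops.convert_to_numpy(clone(sample_inputs)),
        atol=atol,
    )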

@harshaljanjani
Collaborator Author

Status of the PR:
Weight assignment works, but the numerics differ.

Outputs of the convert_moonshine_checkpoints.py script:

  • MD5 checksum comparison (sketched below)
  • Decoder weights assignment
  • Preprocessor weights assignment
  • Encoder weights assignment
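
For illustration, the MD5 checksum comparison could be done along these lines (the script's actual implementation may differ): hash each weight tensor's raw bytes so the original and converted checkpoints can be diffed cheaply.

import hashlib
import numpy as np

def weights_md5(weights):
    # Hash the raw bytes of every weight array, in order.
    digest = hashlib.md5()
    for w in weights:
        digest.update(np.ascontiguousarray(w).tobytes())
    return digest.hexdigest()

# Compare the source checkpoint against the converted Keras model, e.g.:
# assert weights_md5(source_weights) == weights_md5(keras_model.get_weights())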

@harshaljanjani harshaljanjani marked this pull request as ready for review April 12, 2025 12:25
… the PyTorch backend, integrated into the KerasHub infra!
@sachinprasadhs sachinprasadhs removed the WIP Pull requests which are work in progress and not ready yet for review. label Apr 14, 2025
@harshaljanjani
Collaborator Author

Updated the Colab notebook with results from the latest commit. The PR is now open for review.
What's New?
The task model has been integrated across all three backends into the KerasHub infra, including the custom caching strategy used by Moonshine.

@divyashreepathihalli
Collaborator

divyashreepathihalli commented Apr 15, 2025

I don't see a demo notebook with the KerasHub model implemented here; I'm seeing a demo of the Hugging Face model in the Colab.
Please add a demo with the KH model and verify that the outputs match model.generate.

@harshaljanjani
Collaborator Author

I don't see a demo notebook with the KerasHub model implemented here; I'm seeing a demo of the Hugging Face model in the Colab.
Please add a demo with the KH model and verify that the outputs match model.generate.

@divyashreepathihalli The outputs you see across the first three cells are the KH model outputs for four test samples per preset, using the generate() function. I've run tools/checkpoint_conversion/convert_moonshine_checkpoints.py in each of those cells across the three backends, which both verifies the numerics and contains the end-to-end example.

The cell links are:

  1. PyTorch Backend Output of keras_model.generate()
  2. TensorFlow Backend Output of keras_model.generate()
  3. JAX Backend Output of keras_model.generate()

You may also review the checkpoint conversion file to verify the same.

The HF model is only used in the last cell, where I point out a bug in the HF implementation and show how, for the same sample, the KH model presets give good transcripts across all three backends. (The sample used in this test is the "Female Clear Voice (Maximum Length - 64 Sec)" one.)

@harshaljanjani
Collaborator Author

harshaljanjani commented Apr 19, 2025

@mattdangerw / @abheesht17 / @divyashreepathihalli Whenever you have a chance, could you please take a look at this PR and the notebook, thanks!

harshaljanjani and others added 5 commits April 21, 2025 14:15
The rope_scaling parameter was largely a direct port from HF, where it took a dict and pulled the type key from it. The Moonshine presets never explicitly use the dynamic mode, and it isn't crucial to the model. If it becomes necessary in the future we can add it, but for an initial port I think it's best to leave it out. It's better to inherit from the KH RotaryEmbedding class and leave the scaling to its scaling_factor arg instead; that works perfectly well as a replacement and is much better integrated with the existing infra.
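
A sketch of the replacement this commit describes: subclass KerasHub's RotaryEmbedding and defer any scaling to its scaling_factor argument rather than porting HF's rope_scaling dict (the subclass body here is illustrative; the real layer may customize more):

import keras_hub

class MoonshineRotaryEmbedding(keras_hub.layers.RotaryEmbedding):
    def __init__(self, max_wavelength=10000, scaling_factor=1.0, **kwargs):
        # All rotary logic lives in the parent; only defaults are pinned here.
        super().__init__(
            max_wavelength=max_wavelength,
            scaling_factor=scaling_factor,
            **kwargs,
        )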
@mattdangerw
Member

Dropping a few comments. I think we still need to get generation here working similarly to other models, and make the preprocessing actual preprocessing (no weights!). I still think a clearer high-level Colab with the intended usage might help clarify things.

  • Do some weight conversion and upload it to Hugging Face or Kaggle (doesn't matter which) on your own user.
  • Make a Colab that does not touch Hugging Face at all and shows the intended usage here.
  • Try to show some of the usages in #2093 (comment).

How much of this is working today? Have we tried running fine-tuning? That will run preprocessing via a tf.data.Dataset map; does that work?

!pip install git+https://github.com/harshaljanjani/keras-hub@moonshine

import os
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch" with zero other changes.

import keras
import keras_hub

audio_to_text = keras_hub.models.AudioToText.from_preset(
    "hf://harshaljanjani/keras-moonshine",
)

audio_to_text.generate(audio_tensor)
audio_to_text.generate(audio_batch)

audio_to_text.compile(sampler="top_k")
audio_to_text.generate(audio_tensor)

audio_to_text.compile(...)
audio_to_text.enable_lora(4)  # Optional.
audio_to_text.fit(audio_dataset)
audio_to_text.generate(audio_batch)

@harshaljanjani
Collaborator Author

harshaljanjani commented Apr 29, 2025

Will check the comments out, thanks for the review @mattdangerw. I left a few replies; I'd love to hear your opinion on a few non-trivial points mentioned there, and I'll proceed to make changes on the others.

How much of this is working today? Have we tried running fine-tuning? That will run preprocessing via a tf.data.Dataset map, does that work?

I haven't tested fine-tuning yet, but I'll see what I can do. Since you mentioned that the change in the generate() strategy was key, I focused on it for this round.

- MoonshineAudioConverter now has no trainable weights; all feature extraction is moved to the MoonshineBackbone.

- Removed the logits() function and used self.token_embedding(reverse=True) instead (sketched below).

- Resolved test_causal_lm_basics() for all backends, thus resolving tf.data.Dataset.map compatibility issues on the JAX and Torch backends.

- Removed the 64-second test file.
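
An illustration of the logits change above, using KerasHub's ReversibleEmbedding: the layer that embeds token ids also maps hidden states back to vocabulary logits when called with reverse=True (dimensions here are illustrative):

import keras
import keras_hub

token_embedding = keras_hub.layers.ReversibleEmbedding(
    vocabulary_size=32768, embedding_dim=256
)
hidden_states = keras.ops.zeros((2, 4, 256))
# Project hidden states onto the vocabulary via the transposed embedding.
logits = token_embedding(hidden_states, reverse=True)  # Shape: (2, 4, 32768).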
@harshaljanjani
Collaborator Author

Addressed reviews (JIT compile + dynamic shapes issue). Looking forward to guidance on that; I'll try to see if I can solve it in the meantime.

Fixed JIT compile issues on TensorFlow and JAX without unnecessary shenanigans

Reverted to the KerasNLP style of caching, without stateful cache modes (sketched below).
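
An illustrative sketch (not the PR's exact code) of the KerasNLP-style caching referred to above: the key/value cache is a plain tensor threaded through each decode step and updated functionally with ops.slice_update, which keeps generation JIT-compilable on all three backends.

from keras import ops

def update_cache(cache, key_update, value_update, cache_index):
    # cache: (batch, 2, max_len, num_heads, head_dim); axis 1 holds key/value.
    key = ops.slice_update(cache[:, 0, ...], [0, cache_index, 0, 0], key_update)
    value = ops.slice_update(cache[:, 1, ...], [0, cache_index, 0, 0], value_update)
    return ops.stack((key, value), axis=1)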
@harshaljanjani
Collaborator Author

The PR should be ready for the next round of reviews @mattdangerw. Here's the new Colab you mentioned. I've tested the functionality with dummy inputs for now; hope you don't mind! I'll check the weights upload and the presets once the design is approved.

  1. Functionality Tests Notebook Independent From HF.
  2. Same Outputs Notebook, Updated To The Current PR's Version.

@harshaljanjani harshaljanjani requested a review from mattdangerw May 2, 2025 16:19
@divyashreepathihalli
Collaborator

The PR should be ready for the next round of reviews @mattdangerw. Here's the new Colab you mentioned. I've tested the functionality with dummy inputs for now; hope you don't mind! I'll check the weights upload and the presets once the design is approved.

  1. Functionality Tests Notebook Independent From HF.
  2. Same Outputs Notebook, Updated To The Current PR's Version.

Please add demo Colabs, verifications, etc. to the PR description so that they are easier to find.

@harshaljanjani
Collaborator Author

harshaljanjani commented May 2, 2025

Please add demo Colabs, verifications, etc. to the PR description so that they are easier to find.

Apologies, the end-to-end demo notebook has been linked in the PR description from the beginning. I've just linked the functionality tests I added today in the PR description!
