System Info
🐛 Bug Report
Description
The provided Moshi example code does not work correctly with the Transformers library: the generate function fails to produce usable audio when generating new tokens, and there are issues with the expected input formats.
Attached are the generated audio (moshi_output.wav) and a screen recording of the run:
moshi_bug_report.mp4
I tried different temperature settings, generation configurations, and other audio samples, but the output is always a static 'chijijik...' sound.
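For example, variations along these lines (illustrative values, not an exhaustive record of what I tried; same inputs as in the Reproduction section below):

for cfg in [
    dict(do_sample=True, temperature=0.6),
    dict(do_sample=True, temperature=1.0, top_k=250),
    dict(do_sample=False),  # greedy decoding
]:
    output = model.generate(
        input_ids=input_ids,
        user_input_values=user_input_values.input_values,
        moshi_input_values=moshi_input_values,
        max_new_tokens=50,
        **cfg,
    )
    # every configuration still yields the same static noise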
cc @ylacombe
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
from datasets import load_dataset, Audio
import torch, math
from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer
import soundfile as sf
import transformers
import os
# Disable all automatic compilation features
os.environ['TORCH_COMPILE'] = '0'
os.environ['TORCHDYNAMO_DISABLE'] = '1' # Fully disables TorchDynamo
os.environ['TORCHDYNAMO_VERBOSE'] = '0' # Suppresses unnecessary logs
os.environ['TORCHDYNAMO_RECOMPILE_LIMIT'] = '0' # Avoid recompile limits
# Apply global config settings for eager mode
torch._dynamo.config.suppress_errors = True # Avoids crashes and falls back to eager mode
torch._dynamo.config.cache_size_limit = 0 # With a zero cache limit, Dynamo falls back to eager immediately
torch._dynamo.reset() # Clears any cached compile traces
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
device = "cuda"
# prepare user input audio
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=24000))
audio_sample = librispeech_dummy[-1]["audio"]["array"] # (107520,)
# WAV_PATH = f"./audio/moshi_opening.wav"
# audio_sample, sample_rate = sf.read(WAV_PATH)
waveform_to_token_ratio = 1 / 1920
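# the ratio follows from Mimi's frame rate: 24000 Hz / 12.5 frames per second = 1920 samples per token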
model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", attn_implementation="eager", torch_dtype=torch.float16)
feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko")
tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
model = model.to(device)
user_input_values = feature_extractor(raw_audio=audio_sample, sampling_rate=24000, return_tensors="pt").to(device=device, dtype=torch.float16)
# prepare moshi input values - we suppose moshi didn't say anything while the user spoke
moshi_input_values = torch.zeros_like(user_input_values.input_values) # (1, 1, 107520)
# prepare moshi input ids - we suppose moshi didn't say anything while the user spoke
num_tokens = math.ceil(moshi_input_values.shape[-1] * waveform_to_token_ratio)
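# e.g. ceil(107520 / 1920) = 56 pad tokens (~4.5 s of user audio)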
input_ids = torch.ones((1, num_tokens), device=device, dtype=torch.int64) * tokenizer.encode("<pad>")[0]
# Force disable torch.compile inside Transformers
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward = torch._dynamo.disable(
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate = torch._dynamo.disable(
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation = torch._dynamo.disable(
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation
)
# generate 50 new tokens (around 4s of audio)
output = model.generate(
input_ids=input_ids,
user_input_values=user_input_values.input_values,
moshi_input_values=moshi_input_values,
max_new_tokens=50,
temperature=0.8,
do_sample=True,
)
text_tokens = output.sequences
# decode text tokens
text = tokenizer.decode(text_tokens[0], skip_special_tokens=True)
print(text)
# decode audio tokens
audio_waveforms = output.audio_sequences.squeeze(0).squeeze(0) # (L,)
audio_waveforms = audio_waveforms.double()
# cut audio for input length
audio_waveforms = audio_waveforms[:user_input_values.input_values.shape[-1]]
# save audio
sf.write("moshi_output.wav", audio_waveforms.cpu().numpy(), 24000)
Expected behavior
The script should produce intelligible speech in moshi_output.wav instead of static noise.
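To make "static" concrete, a minimal sanity check (a sketch; it assumes the moshi_output.wav written by the script above; white noise has spectral flatness near 1, while speech is far lower):

import numpy as np
import soundfile as sf

audio, sr = sf.read("moshi_output.wav")
spectrum = np.abs(np.fft.rfft(audio)) ** 2 + 1e-12            # power spectrum
flatness = np.exp(np.log(spectrum).mean()) / spectrum.mean()  # geometric / arithmetic mean
print(f"RMS: {np.sqrt((audio ** 2).mean()):.4f}, spectral flatness: {flatness:.3f}")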