Moshi Generation Does Not Work as Expected #36160

@SeungyounShin

System Info

🐛 Bug Report

Description

The provided Moshi example code does not work correctly with the Transformers library: the generate function fails when attempting to generate new tokens, and the expected input formats also cause problems.

Attachments: moshi_output.wav (the generated audio) and moshi_bug_report.mp4.

I tried different temperature settings, generation configurations, and other audio samples, but the output is always a static 'chijijik...' noise.
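
For concreteness, the sweep I ran looked roughly like the sketch below (the temperature values are hypothetical examples; it reuses model and the inputs built in the Reproduction section further down):

# Sketch of the settings sweep (assumes model, input_ids, user_input_values,
# and moshi_input_values from the Reproduction section below).
for temperature in (0.6, 0.8, 1.0):  # hypothetical values
    output = model.generate(
        input_ids=input_ids,
        user_input_values=user_input_values.input_values,
        moshi_input_values=moshi_input_values,
        max_new_tokens=50,
        do_sample=True,
        temperature=temperature,
    )
    # every setting produced the same static noise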

cc. @ylacombe

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import math
import os

import soundfile as sf
import torch
import transformers
from datasets import load_dataset, Audio
from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer

# Disable all automatic compilation features
os.environ['TORCH_COMPILE'] = '0'
os.environ['TORCHDYNAMO_DISABLE'] = '1'  # Fully disables TorchDynamo
os.environ['TORCHDYNAMO_VERBOSE'] = '0'  # Suppresses unnecessary logs
os.environ['TORCHDYNAMO_RECOMPILE_LIMIT'] = '0'  # Avoid recompile limits

# Apply global config settings for eager mode
torch._dynamo.config.suppress_errors = True  # Avoids crashes and falls back to eager mode
torch._dynamo.config.cache_size_limit = 0  # Prevents recompilation limits
torch._dynamo.reset()  # Clears any cached compile traces


librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

device = "cuda"
# prepare user input audio 
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=24000))
audio_sample = librispeech_dummy[-1]["audio"]["array"] # (107520,)
# WAV_PATH = f"./audio/moshi_opening.wav"
# audio_sample, sample_rate = sf.read(WAV_PATH)
waveform_to_token_ratio = 1 / 1920
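# At 24 kHz, each Mimi codec frame covers 1920 waveform samples
# (12.5 frames per second), hence the 1/1920 samples-to-tokens ratio.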

model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", attn_implementation="eager", torch_dtype=torch.float16)
feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko")
tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
model = model.to(device)

user_input_values = feature_extractor(raw_audio=audio_sample, sampling_rate=24000, return_tensors="pt").to(device=device, dtype=torch.float16)

# prepare moshi input values - we suppose moshi didn't say anything while the user spoke
moshi_input_values = torch.zeros_like(user_input_values.input_values) # (1, 1, 107520)

# prepare moshi input ids - we suppose moshi didn't say anything while the user spoke
num_tokens = math.ceil(moshi_input_values.shape[-1] * waveform_to_token_ratio)
input_ids = torch.ones((1, num_tokens), device=device, dtype=torch.int64) * tokenizer.encode("<pad>")[0]
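# For this clip that is ceil(107520 / 1920) = 56 "<pad>" text tokens,
# one per codec frame, standing in for Moshi's silent text stream.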

# Force disable torch.compile inside Transformers
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation
)
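# Note: torch._dynamo.disable wraps each unbound method so it always runs
# eagerly; assigning the wrapper back onto the class affects all instances.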

# generate 50 new tokens (around 4s of audio at 12.5 tokens/s)
output = model.generate(
    input_ids=input_ids,
    user_input_values=user_input_values.input_values,
    moshi_input_values=moshi_input_values,
    max_new_tokens=50,
    temperature=0.8,
    do_sample=True,
)

text_tokens = output.sequences
# decode text tokens
text = tokenizer.decode(text_tokens[0], skip_special_tokens=True)
print(text)

# decode audio tokens
audio_waveforms = output.audio_sequences.squeeze(0).squeeze(0) # (L,)
audio_waveforms = audio_waveforms.double()

# trim the generated audio to the user input length
audio_waveforms = audio_waveforms[:user_input_values.input_values.shape[-1]]

# save audio
sf.write("moshi_output.wav", audio_waveforms.cpu().numpy(), 24000)
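
To quantify the "static" output, here is a minimal check on the saved file (a sketch, assuming moshi_output.wav was written by the script above):

import numpy as np
import soundfile as sf

audio, sr = sf.read("moshi_output.wav")  # mono, 24 kHz
rms = float(np.sqrt(np.mean(np.square(audio))))
print(f"sr={sr}, overall RMS={rms:.4f}")

# Speech alternates loud and quiet frames; broadband static has a
# near-constant per-frame RMS, i.e. a low std/mean ratio below.
frame = sr // 10  # 100 ms frames
frames = audio[: len(audio) // frame * frame].reshape(-1, frame)
frame_rms = np.sqrt(np.mean(np.square(frames), axis=1))
print(f"frame RMS std/mean = {frame_rms.std() / frame_rms.mean():.3f}")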

Expected behavior

Generation should produce intelligible speech audio rather than static noise.
