Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 41 additions & 31 deletions backend/app/services/music_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,21 @@ def _pad_audio_token(token):
else:
pipeline._unload()

# Postprocess - load codec and decode
# -------------------------------------------------------------------------
# "Decoding audio..." step
# -------------------------------------------------------------------------
# 1. HeartCodec (if lazily loaded) is loaded on codec_device (the CPU on Apple MPS
# setups, the GPU on CUDA).
# 2. The token frames produced by HeartMuLa are moved to the codec device and passed
# to pipeline.codec.detokenize(codes). Detokenize:
# - Pads/repeats codes to min_samples, slices into overlapping segments.
# - For each segment calls FlowMatching.inference_codes(codes_input, ...).
# - FlowMatching uses ResidualVQ.get_output_from_indices(codes): the codes
# are indices into a codebook of size config.codebook_size (8192), so the
# valid indices are 0..codebook_size-1. Any token >= codebook_size
# (e.g. EOS or a model glitch) causes "index X is out of bounds for
# dimension 1 with size 8192". We therefore clamp tokens to
# [0, codebook_size-1]; an out-of-range token is replaced by the nearest
# valid index rather than dropped.
# 3. Output wav is float waveform; we then save via torchaudio.
# -------------------------------------------------------------------------
if callback is not None:
callback(95, "Decoding audio...")

Expand All @@ -851,9 +865,11 @@ def _pad_audio_token(token):
raise RuntimeError("Cannot load HeartCodec: codec_path not available")

# Move frames to codec device (keep dtype as long for indexing)
# frames contains token IDs (integers) used as indices, so dtype must remain long
# Explicitly preserve torch.long dtype when moving to device (critical for MPS)
frames_for_codec = frames.to(device=pipeline.codec_device, dtype=torch.long)
# Clamp token IDs to codec codebook range [0, codebook_size-1] to avoid
# "index X is out of bounds for dimension 1 with size 8192" in ResidualVQ.
codebook_size = getattr(pipeline.codec.config, "codebook_size", 8192)
frames_for_codec = frames_for_codec.clamp(0, codebook_size - 1)
wav = pipeline.codec.detokenize(frames_for_codec)

# Cleanup codec if using lazy loading (free VRAM for next generation)
Expand All @@ -875,7 +891,12 @@ def _pad_audio_token(token):
elif not lazy_codec:
pipeline._unload()

torchaudio.save(save_path, wav.to(torch.float32).cpu(), 48000)
# Ensure a plain CPU tensor before saving: some backends (e.g. torchaudio on macOS) reject
# tensors that originated on MPS ("invalid type: 'torch.mps.FloatTensor'"). Round-tripping
# through a copied numpy array strips any device association, so the backend receives a
# standard CPU float32 tensor.
wav_cpu = wav.to(torch.float32).cpu()
wav_plain = torch.from_numpy(wav_cpu.numpy().copy()).to(torch.float32)
torchaudio.save(save_path, wav_plain, 48000)

# Store the custom method on the pipeline instance
pipeline.generate_with_callback = generate_with_callback
Expand Down Expand Up @@ -1549,45 +1570,38 @@ def patched_warmup(model, device_map, hf_quantizer):
os.environ["HF_HUB_DISABLE_CACHING_ALLOCATOR_WARMUP"] = "1"

try:
# IMPORTANT: For MPS, we need to use float16 (not float32) for optimal performance
# MPS has native support for float16 operations which are much faster than float32
print("[Apple Metal] Loading models with float16 precision for optimal MPS performance", flush=True)
# HeartMuLa on MPS for fast generation; HeartCodec on CPU to avoid
# "invalid type: 'torch.mps.FloatTensor'" during decode (detokenize/save path).
cpu_device = torch.device("cpu")
mps_device = torch.device("mps")
print("[Apple Metal] Loading HeartMuLa on MPS (generation), HeartCodec on CPU (decode)", flush=True)
pipeline = HeartMuLaGenPipeline.from_pretrained(
model_path,
device={
"mula": torch.device("mps"),
"codec": torch.device("mps"),
"mula": mps_device,
"codec": cpu_device,
},
dtype={
"mula": torch.float16, # Use float16 for MPS acceleration
"codec": torch.float16, # Use float16 for MPS acceleration
"mula": torch.float16,
"codec": torch.float32,
},
version=version,
)

# Verify models are on MPS and explicitly set device attributes
mps_device = torch.device("mps")

# Ensure pipeline device attributes are set correctly
pipeline.mula_device = mps_device
pipeline.codec_device = mps_device
pipeline.codec_device = cpu_device
pipeline.mula_dtype = torch.float16
pipeline.codec_dtype = torch.float16
pipeline.codec_dtype = torch.float32

# Verify and correct model device placement
# Note: Accessing _mula and _codec (private attributes) is necessary here
# because the pipeline library doesn't provide public methods for device verification
if hasattr(pipeline, '_mula') and pipeline._mula is not None:
try:
# Get first parameter's device, or handle case where model has no parameters
mula_params = list(pipeline._mula.parameters())
if mula_params:
mula_device = mula_params[0].device
print(f"[Apple Metal] HeartMuLa model device: {mula_device}", flush=True)
if mula_device.type != 'mps':
logger.warning(f"[MPS] HeartMuLa model is on {mula_device}, not MPS! This will be slow.")
print(f"[Apple Metal] WARNING: HeartMuLa is on {mula_device}, moving to MPS...", flush=True)
# Explicitly set both device and dtype for consistency
pipeline._mula = pipeline._mula.to(device=mps_device, dtype=torch.float16)
print(f"[Apple Metal] HeartMuLa moved to MPS with float16 precision", flush=True)
else:
Expand All @@ -1597,24 +1611,20 @@ def patched_warmup(model, device_map, hf_quantizer):

if hasattr(pipeline, '_codec') and pipeline._codec is not None:
try:
# Get first parameter's device, or handle case where model has no parameters
codec_params = list(pipeline._codec.parameters())
if codec_params:
codec_device = codec_params[0].device
print(f"[Apple Metal] HeartCodec model device: {codec_device}", flush=True)
if codec_device.type != 'mps':
logger.warning(f"[MPS] HeartCodec model is on {codec_device}, not MPS! This will be slow.")
print(f"[Apple Metal] WARNING: HeartCodec is on {codec_device}, moving to MPS...", flush=True)
# Explicitly set both device and dtype for consistency
pipeline._codec = pipeline._codec.to(device=mps_device, dtype=torch.float16)
print(f"[Apple Metal] HeartCodec moved to MPS with float16 precision", flush=True)
if codec_device.type != 'cpu':
print(f"[Apple Metal] Moving HeartCodec to CPU for reliable decode (avoids MPS tensor errors)...", flush=True)
pipeline._codec = pipeline._codec.to(device=cpu_device, dtype=torch.float32)
print(f"[Apple Metal] HeartCodec on CPU (float32)", flush=True)
else:
logger.warning("[MPS] HeartCodec model has no parameters - cannot verify device")
except Exception as e:
logger.warning(f"[MPS] Failed to verify HeartCodec device: {e}")

print("[Apple Metal] MPS pipeline loaded successfully with float16 precision", flush=True)
print("[Apple Metal] All models are on MPS device for hardware acceleration", flush=True)
print("[Apple Metal] MPS pipeline loaded: HeartMuLa on MPS, HeartCodec on CPU", flush=True)
return patch_pipeline_with_callback(pipeline, sequential_offload=False)
finally:
# Restore original function if we patched it
Expand Down