[BUG] DAC.from_latents does not match the forward pass with missing STE #43819

@harshaljanjani

System Info

  • transformers version: 5.0.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • huggingface_hub version: 1.3.2
  • safetensors version: 0.7.0
  • accelerate version: 1.12.0
  • Accelerate config: not installed
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
  • GPU type: NVIDIA L4
  • NVIDIA driver version: 550.90.07
  • CUDA version: 12.4

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from datasets import load_dataset, Audio
from transformers import DacModel, AutoProcessor

model_id = "descript/dac_16khz"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = DacModel.from_pretrained(model_id).to(device).eval()
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[0]["audio"]["array"]
inputs = processor(
  raw_audio=audio_sample,
  sampling_rate=processor.sampling_rate,
  return_tensors="pt",
).to(device)
input_values = inputs["input_values"]

with torch.no_grad():
  encoded = model.encoder(input_values)
  quant_repr_fwd, audio_codes, proj_latents, _, _ = model.quantizer(encoded)
  quant_repr_from_lat, _ = model.quantizer.from_latents(proj_latents)
  print(quant_repr_fwd)
  print(quant_repr_from_lat)
  # Take the absolute value before the max so negative deviations are not missed.
  print((quant_repr_fwd - quant_repr_from_lat).abs().max().item())

DAC.from_latents in DacResidualVectorQuantizer does not apply the straight-through estimator (STE) before out_proj, whereas DacVectorQuantize.forward does, so the two code paths return different quantized representations for the same latents. This also breaks CI: the DacIntegrationTest tests in tests/models/dac/test_modeling_dac.py fail.
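For readers unfamiliar with the pattern, here is a minimal sketch of a straight-through estimator on a toy nearest-neighbour quantizer. It is not the transformers implementation: quantize_nearest, straight_through, and the toy codebook are hypothetical names used only for illustration. It shows that the STE output carries the quantized values forward while gradients flow to the latents as if quantization were the identity, and that the forward value matches the quantized tensor only up to floating-point rounding.

```python
import torch


def quantize_nearest(latents: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Toy quantizer (hypothetical): map each latent row to its nearest codebook row."""
    distances = torch.cdist(latents, codebook)  # shape: (num_latents, num_codes)
    indices = distances.argmin(dim=1)
    return codebook[indices]


def straight_through(latents: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    """Forward: (approximately) the quantized values. Backward: identity w.r.t. latents."""
    return latents + (quantized - latents).detach()


torch.manual_seed(0)
codebook = torch.randn(8, 4)                     # 8 codes of dimension 4 (illustrative sizes)
latents = torch.randn(5, 4, requires_grad=True)  # 5 latent vectors

quantized = quantize_nearest(latents, codebook)
ste_out = straight_through(latents, quantized)

# Backward pass: the gradient reaches the latents unchanged.
ste_out.sum().backward()
grad_is_identity = torch.equal(latents.grad, torch.ones_like(latents))

# Forward pass: equal to the quantized values only up to rounding,
# because x + (q - x) is not guaranteed to be bit-identical to q.
forward_close = torch.allclose(ste_out, quantized, atol=1e-6)

print(grad_is_identity, forward_close)
```

If one code path feeds the STE output into a projection and another feeds the raw quantized tensor, the results can differ by rounding error even under torch.no_grad(), which would be consistent with the mismatch reported here.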

Current Output:

Image

Expected behavior

DAC.from_latents should exactly match the quantizer forward pass output.
The following integration tests in tests/models/dac/test_modeling_dac.py should pass without regressions:

  • DacIntegrationTest::test_quantizer_from_latents_integration_0_dac_16khz
  • DacIntegrationTest::test_quantizer_from_latents_integration_1_dac_24khz
  • DacIntegrationTest::test_quantizer_from_latents_integration_2_dac_44khz

Output After the Fix:

Image
