Speculative decoding is surprisingly slow on Whisper-large-v3 #32366

@changyeli

System Info

  • transformers version: 4.43.3
  • Platform: Linux-6.5.0-1025-oracle-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100 80GB PCIe

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm working on a private dataset with whisper-large-v3, and I tried speculative decoding to speed up inference, as shown in this blog.

Here is the code:

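import torch
from datasets import Audio, load_dataset
from transformers import (
    WhisperForCausalLM,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)


def map_to_pred(batch):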
    """
    Perform inference on an audio batch

    Parameters:
        batch (dict): A dictionary containing audio data and other related information.

    Returns:
        dict: The input batch dictionary with added prediction and transcription fields.
    """
    audio = batch['audio']
    # handle short-form and long-form automatically
    inputs = processor(
        audio['array'],
        sampling_rate=audio['sampling_rate'],
        return_tensors="pt",
        return_attention_mask=True,
        truncation=False,
        padding="longest")
    if inputs.input_features.shape[-1] < 3000:
        inputs = processor(
            audio['array'],
            sampling_rate=audio['sampling_rate'],
            return_tensors="pt",
            return_attention_mask=True)
    inputs = inputs.to(device, torch.float16)
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features=inputs['input_features'],
            attention_mask=inputs['attention_mask'],
            begin_suppress_tokens=None,
            assistant_model=assistant_model,
        )
    preds = processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=False)[0]
    batch['path'] = audio['path']
    batch['prediction'] = preds
    return batch


if __name__ == "__main__":
    device = "cuda:2"
    torch_dtype = torch.float16
    data = load_dataset(
        "audiofolder",
        data_dir=config['DATA']['target'],)
    data = data.cast_column('audio', Audio(sampling_rate=16_000))
    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
    # main model
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-large-v3",
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="sdpa",
        device_map=device,
        torch_dtype=torch_dtype)
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="english", task="transcribe")
    # speculative decoding: assistant (draft) model
    assistant_model = WhisperForCausalLM.from_pretrained(
        "distil-whisper/distil-large-v3",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="sdpa"
    )
    assistant_model.to(device)
    # run inference over the dataset loaded above
    result = data.map(
        map_to_pred, remove_columns=['audio'])
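
To compare the two configurations on equal footing, it may help to time the per-example call directly rather than reading the rate off the `map` progress bar. The helper below is only a sketch under assumptions: `seconds_per_example`, `run_one`, and `sync` are hypothetical names, not part of transformers or datasets, and the function uses only the standard library.

```python
import time
from statistics import mean
from typing import Callable, Iterable, Optional


def seconds_per_example(
    run_one: Callable[[object], object],
    examples: Iterable[object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 1,
) -> float:
    """Return the mean wall-clock seconds that `run_one` takes per example.

    `sync` is an optional hook called after every run, for instance a
    function that waits for queued GPU work to finish, so that
    asynchronous kernels are included in the measurement.
    Raises statistics.StatisticsError if `examples` is empty.
    """
    examples = list(examples)

    # Warm-up runs are executed but not timed (first calls are often slower).
    for example in examples[:warmup]:
        run_one(example)
    if sync is not None:
        sync()

    timings = []
    for example in examples:
        start = time.perf_counter()
        run_one(example)
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    return mean(timings)
```

With the script above, one could presumably call it on a small subset twice, once with `assistant_model` passed to `generate` and once without, passing `map_to_pred` as `run_one` and `torch.cuda.synchronize` as `sync`.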

Expected behavior

As stated in the blog post, speculative decoding is expected to give about 2x faster inference. However, I'm getting ~40 s/example, and it takes ~30 hours to finish an audio dataset with ~3000 samples. Without speculative decoding, I get ~4 s/example and it takes ~8 hours to finish.
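
For context on why the speedup depends on how often the assistant agrees with the main model, here is a toy sketch of greedy speculative decoding. It is an illustration under assumptions, not the transformers implementation: `speculative_greedy_decode`, `target_next`, and `draft_next` are hypothetical names, and the "models" are plain Python functions that map a token list to the next token.

```python
from typing import Callable, List, Tuple

NextTokenFn = Callable[[List[int]], int]


def speculative_greedy_decode(
    target_next: NextTokenFn,
    draft_next: NextTokenFn,
    prompt: List[int],
    max_new_tokens: int,
    num_draft_tokens: int = 5,
) -> Tuple[List[int], int]:
    """Toy greedy speculative decoding.

    Returns (tokens, rounds). Each round stands in for one verification
    pass of the large target model. The output always equals plain greedy
    decoding with `target_next`; only the number of rounds changes.
    """
    tokens = list(prompt)
    final_len = len(prompt) + max_new_tokens
    rounds = 0
    while len(tokens) < final_len:
        # 1. The cheap draft model proposes a few tokens autoregressively.
        draft: List[int] = []
        for _ in range(num_draft_tokens):
            draft.append(draft_next(tokens + draft))

        # 2. The target model verifies the proposals. A real model scores
        #    all positions in one forward pass; here it is called per
        #    position for clarity.
        rounds += 1
        accepted: List[int] = []
        for proposed in draft:
            expected = target_next(tokens + accepted)
            if proposed == expected:
                accepted.append(proposed)
            else:
                # First mismatch: keep the target's token, drop the rest.
                accepted.append(expected)
                break
        else:
            # Every draft token matched, so the target adds one more token.
            accepted.append(target_next(tokens + accepted))
        tokens.extend(accepted)
    return tokens[:final_len], rounds
```

In this toy setting, a draft that always agrees with the target yields up to `num_draft_tokens + 1` tokens per round, while a draft that never agrees yields one token per round and still pays for every draft call, which is one way the assisted path can end up slower than plain decoding.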

I also get the following warning even though I pass attention_mask to model.generate:

The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.

Did I set something wrong here? Thanks!
