2 changes: 2 additions & 0 deletions examples/multimodal_audio/whisper_example.py
@@ -3,6 +3,7 @@
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from llmcompressor import oneshot
+from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation

@@ -12,6 +13,7 @@
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
model.config.forced_decoder_ids = None
processor = WhisperProcessor.from_pretrained(MODEL_ID)
+model = prepare_for_calibration(model)  # patch model (see #1574)

# Configure the processor for the dataset task.
processor.tokenizer.set_prefix_tokens(language="en", task="transcribe")
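For orientation, here is how the patched example fits together end to end. Only the import and the `prepare_for_calibration` call come from this diff; the model ID, dataset, and `oneshot` arguments below are illustrative assumptions in the style of the repo's other quantization examples, not part of the PR.

```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "openai/whisper-large-v2"  # assumed checkpoint

# Load the model, then swap in the calibration-friendly encoder (see #1574)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
model.config.forced_decoder_ids = None
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = prepare_for_calibration(model)

# Quantize with GPTQ; dataset and sample counts are illustrative
oneshot(
    model=model,
    dataset="peoples_speech",  # assumed calibration dataset
    recipe=GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
    max_seq_length=2048,
    num_calibration_samples=512,
)
```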
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -8,13 +8,13 @@ target-version = ['py38']

[tool.isort]
profile = "black"
-skip = ["src/llmcompressor/transformers/tracing/", "src/llmcompressor/version.py"]
+skip = ["src/llmcompressor/transformers/tracing/", "src/llmcompressor/modeling/", "src/llmcompressor/version.py"]

[tool.mypy]
files = "src/guidellm"

[tool.ruff]
-exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/transformers/tracing/"]
+exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/transformers/tracing/", "src/llmcompressor/modeling/"]
lint.select = ["E", "F", "W"]
lint.extend-ignore = ["E203", "W605"]

3 changes: 3 additions & 0 deletions src/llmcompressor/modeling/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .prepare import *
20 changes: 20 additions & 0 deletions src/llmcompressor/modeling/prepare.py
@@ -0,0 +1,20 @@
from compressed_tensors.utils import replace_module
from transformers import PreTrainedModel

from llmcompressor.modeling.whisper import replace as replace_WhisperEncoder

__all__ = ["prepare_for_calibration"]

replacements = {
"WhisperEncoder": replace_WhisperEncoder,
}


def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
    """Swap registered submodules in-place with calibration-friendly versions."""
for name, module in model.named_modules():
cls_name = module.__class__.__name__
if cls_name in replacements:
new_module = replacements[cls_name](module)
replace_module(model, name, new_module)

return model
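The registry above makes patches opt-in per architecture: `prepare_for_calibration` walks `named_modules()` and swaps any module whose class name appears in `replacements`, splicing the new module back into the parent via `compressed_tensors.utils.replace_module`. Below is a minimal self-contained sketch of the same mechanic with toy classes (`ToyEncoder` and friends are illustrative stand-ins, not part of this PR):

```python
from torch import nn

class ToyEncoder(nn.Module):  # stand-in for an architecture that needs patching
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class ToyEncoderPatched(ToyEncoder):  # calibration-friendly variant
    def __init__(self, original: ToyEncoder):
        super().__init__()
        self.proj = original.proj  # reuse the original weights

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

replacements = {"ToyEncoder": ToyEncoderPatched}

model = ToyModel()
for name, module in model.named_modules():
    if module.__class__.__name__ in replacements:
        # setattr stands in for replace_module, which additionally handles
        # dotted module paths and offloading bookkeeping.
        setattr(model, name, replacements[module.__class__.__name__](module))

assert isinstance(model.encoder, ToyEncoderPatched)
```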
149 changes: 149 additions & 0 deletions src/llmcompressor/modeling/whisper.py
@@ -0,0 +1,149 @@
# flake8: noqa
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# vllm-project: no copyright
# Adapted from modeling_whisper.py

import math

import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.whisper.modeling_whisper import WhisperEncoder


class WhisperEncoderPatched(WhisperEncoder):
"""
Patches whisper model to support CPU offloading, which is required for
the sequential calibration pipelines.

For the diff, see https://github.com/huggingface/transformers/pull/38994
"""
def __init__(
self, config, conv1, conv2, embed_positions, layers, layer_norm
):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop

embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

self.conv1 = conv1
self.conv2 = conv2

self.embed_positions = embed_positions

self.layers = layers
self.layer_norm = layer_norm

self.gradient_checkpointing = False
self.post_init()

def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
if input_features.shape[-1] != expected_seq_length:
raise ValueError(
f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
)

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # PATCH (see https://github.com/huggingface/transformers/pull/38994):
        # index the positions through the embedding module's forward call rather
        # than reading self.embed_positions.weight directly, so that offloading
        # hooks fire and the weight lands on the correct device.
all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)

hidden_states = inputs_embeds + self.embed_positions(all_positions)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (len(self.layers)), (
f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
)

for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop: # skip the layer
to_drop = True

if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
None,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

def replace(module: WhisperEncoder) -> WhisperEncoderPatched:
return WhisperEncoderPatched(
module.config,
module.conv1,
module.conv2,
module.embed_positions,
module.layers,
module.layer_norm
)
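A quick sanity check for the patch (not part of the PR): build the patched encoder from a small pretrained checkpoint and confirm it matches the stock encoder on random mel features. The `openai/whisper-tiny` checkpoint and the tolerance are assumptions for illustration.

```python
import torch
from transformers import WhisperForConditionalGeneration

from llmcompressor.modeling.whisper import replace

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
encoder = model.model.encoder
patched = replace(encoder).eval()

# whisper-tiny expects 80 mel bins x 3000 frames
# (max_source_positions * conv1.stride * conv2.stride = 1500 * 1 * 2)
features = torch.randn(1, model.config.num_mel_bins, 3000)
with torch.no_grad():
    reference = encoder(features).last_hidden_state
    candidate = patched(features).last_hidden_state

assert torch.allclose(reference, candidate, atol=1e-5)
```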