Add whisper masking (#146)

- Added masking in the Whisper encoder so that padded audio frames are ignored, keeping behavior consistent between training and inference (a toy illustration follows below).
- Simplified release_config.yaml so it can serve as an example configuration.
zqhuang211 authored Nov 9, 2024
1 parent 3787d64 commit 812f58c
Showing 7 changed files with 50 additions and 93 deletions.
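
Before the per-file diffs, the idea behind the masking change in one toy example: when a batch is padded to a common length, unmasked attention lets real frames attend to padding, so batched (padded) training computes slightly different activations than single-sample inference, which sees no padding. The sketch below is illustrative only (made-up tensors, plain scaled dot-product attention), not the project's code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
valid, padded = 4, 6                          # 4 real frames, padded out to length 6
q = torch.randn(1, valid, 8)                  # queries for the real frames only
k = torch.randn(1, padded, 8)
k[:, valid:] = 0.0                            # padding key vectors are zeros

scores = q @ k.transpose(-1, -2) / 8 ** 0.5   # 1 x 4 x 6 attention scores

# Without a mask, softmax still assigns weight to the padded positions.
unmasked = F.softmax(scores, dim=-1)

# With an additive mask (large negative on padding), that weight drops to ~0,
# matching what inference on the unpadded sample would compute.
mask = torch.zeros(1, 1, padded)
mask[..., valid:] = torch.finfo(scores.dtype).min
masked = F.softmax(scores + mask, dim=-1)

print("attention weight on padding, unmasked:", unmasked[..., valid:].sum().item())
print("attention weight on padding, masked:  ", masked[..., valid:].sum().item())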
1 change: 1 addition & 0 deletions ultravox/model/data_processing.py
@@ -66,6 +66,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
inputs["audio_values"].squeeze_(0)
inputs["audio_token_start_idx"].squeeze_(0)
inputs["audio_token_len"].squeeze_(0)
inputs["audio_len"].squeeze_(0)

# No need to shift the labels as the model does it internally
labels = input_ids.clone()
1 change: 1 addition & 0 deletions ultravox/model/data_processing_test.py
@@ -31,6 +31,7 @@ def fake_process(text, audio, return_tensors="pt", sampling_rate=16000):
"audio_values": torch.tensor([[[0.1, 0.2, 0.3]]]),
"audio_token_start_idx": torch.tensor([1]),
"audio_token_len": torch.tensor([2]),
"audio_len": torch.tensor([10]),
}


40 changes: 35 additions & 5 deletions ultravox/model/ultravox_model.py
@@ -148,6 +148,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_len: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
# the alt_* fields are needed for KL divergence loss
@@ -189,7 +190,7 @@ def forward(

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
audio_values.to(self.audio_tower.dtype)
audio_values.to(self.audio_tower.dtype), audio_len=audio_len
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

@@ -235,6 +236,7 @@ def prepare_inputs_for_generation(
audio_values: Optional[torch.FloatTensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
audio_len: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -263,6 +265,7 @@ def prepare_inputs_for_generation(
audio_token_start_idx - prefill_start_idx
)
model_input["audio_token_len"] = audio_token_len
model_input["audio_len"] = audio_len

return model_input

@@ -508,7 +511,9 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
return hidden_states


class ModifiedWhisperEncoder(whisper.WhisperEncoder):
class ModifiedWhisperEncoder(
whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin
):
"""
Encoder portion of OpenAI's Whisper model.
@@ -527,7 +532,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
def forward(
self,
input_features,
attention_mask=None,
audio_len=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
@@ -570,6 +575,31 @@ def forward(
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

# Create attention mask based on audio lengths to mask out padding tokens
# For each sample in batch:
# - Convert raw audio length to feature length after convolutions
# - Create boolean mask that is True for valid positions and False for padding
# - Convert to extended attention mask format expected by transformer layers
# (1.0 for positions to attend to, large negative for positions to ignore)
# This masking ensures consistent behavior between training and inference
# by preventing the model from attending to padding tokens in both cases
attention_mask = None
if audio_len is not None:
audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
batch_size = hidden_states.shape[0]
max_seq_len = hidden_states.shape[1]
attention_mask = (
torch.arange(max_seq_len, device=hidden_states.device)[None, :]
.expand(batch_size, -1)
.lt(audio_feature_len.view(batch_size, 1))
)
attention_mask = self.get_extended_attention_mask(
attention_mask,
None,
device=hidden_states.device,
dtype=hidden_states.dtype,
)

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
@@ -593,14 +623,14 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
None,
attention_mask,
layer_head_mask=(
head_mask[idx] if head_mask is not None else None
),
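
For reference, a minimal, self-contained sketch of what the masking block added above computes. The helper name and the (len - 1) // 2 + 1 downsampling formula are assumptions mirroring Whisper's conv stack; the actual encoder relies on its inherited _get_feat_extract_output_lengths and ModuleUtilsMixin.get_extended_attention_mask:

import torch

def feat_extract_output_lengths(audio_len: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in for WhisperEncoder._get_feat_extract_output_lengths:
    # the second conv layer has stride 2, so mel-frame counts roughly halve.
    return (audio_len - 1) // 2 + 1

def build_extended_attention_mask(
    audio_len: torch.Tensor, max_seq_len: int, dtype=torch.float32
) -> torch.Tensor:
    feat_len = feat_extract_output_lengths(audio_len)        # B
    positions = torch.arange(max_seq_len)[None, :]           # 1 x T
    bool_mask = positions < feat_len[:, None]                # B x T, True = real frame
    # Extended-mask convention expected by the encoder layers:
    # 0.0 where attention is allowed, a large negative value on padding,
    # broadcastable over heads and query positions (B x 1 x 1 x T).
    return (1.0 - bool_mask[:, None, None, :].to(dtype)) * torch.finfo(dtype).min

audio_len = torch.tensor([3000, 1600])     # valid mel frames per sample (hypothetical)
mask = build_extended_attention_mask(audio_len, max_seq_len=1500)
print(mask.shape)                          # torch.Size([2, 1, 1, 1500])
print((mask[1, 0, 0] == 0).sum().item())   # 800 attendable positions for the shorter clip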
11 changes: 11 additions & 0 deletions ultravox/model/ultravox_processing.py
@@ -154,13 +154,24 @@ def __call__(
sampling_rate=sampling_rate,
padding="longest",
max_length=audio_len,
return_attention_mask=True,
**kwargs,
)
if "input_features" in x:
data["audio_values"] = x.input_features
else:
data["audio_values"] = x.input_values

# data["audio_len"] is the number of frames in the audio, used for creating attention masks in whisper encoder
if (
self.audio_padding == "max_length"
): # audio is padded to max length, so we rely on the attention mask to determine audio_len
data["audio_len"] = (
x.attention_mask.sum(-1) - 1
) # Whisper attention mask includes an extra 1 at the end that needs to be subtracted
else: # audio is not padded, so we can directly use the audio length
data["audio_len"] = [torch.as_tensor(data["audio_values"]).shape[-1]]

if text is not None:
assert isinstance(
text, str
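
A rough usage sketch of the processor-side change, assuming the stock Hugging Face WhisperFeatureExtractor in max_length padding mode (the checkpoint name is only an example); the trailing-1 correction mirrors the comment in the diff above:

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
waveform = np.random.randn(16000 * 5).astype(np.float32)   # 5 s of fake 16 kHz audio

x = feature_extractor(
    waveform,
    sampling_rate=16000,
    padding="max_length",          # Whisper-style: pad/truncate to 30 s of mel frames
    return_attention_mask=True,
    return_tensors="pt",
)

audio_values = x.input_features           # 1 x 80 x 3000 log-mel features
audio_len = x.attention_mask.sum(-1) - 1  # valid mel frames, minus the extra trailing 1
print(audio_values.shape, audio_len)      # roughly 500 of the 3000 frames are valid here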
6 changes: 0 additions & 6 deletions ultravox/training/configs/llama3_whisper.yaml

This file was deleted.

39 changes: 0 additions & 39 deletions ultravox/training/configs/llama3_whisper_kd.yaml

This file was deleted.

45 changes: 2 additions & 43 deletions ultravox/training/configs/release_config.yaml
@@ -1,6 +1,4 @@
# SLM with ultravox & llama3.1, trained with knowledge distillation.
exp_name: "ultravox-v0_4"

# Make sure to accept the license agreement on huggingface hub
text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
audio_model: "openai/whisper-medium"
@@ -12,51 +10,12 @@ loss_config:
train_sets:
- name: librispeech-clean-continuation
- name: librispeech-other-continuation
- name: peoplespeech-clean-continuation
weight: 8
- name: commonvoice-en-continuation
weight: 8
- name: commonvoice-ar-continuation
weight: 0.2
- name: commonvoice-de-continuation
weight: 4
- name: commonvoice-es-continuation
weight: 3
- name: commonvoice-fr-continuation
weight: 4
- name: commonvoice-it-continuation
weight: 1.2
- name: commonvoice-ja-continuation
weight: 0.1
- name: commonvoice-pt-continuation
weight: 0.2
- name: commonvoice-ru-continuation
weight: 0.2
- name: librispeech-clean-transcription
- name: librispeech-other-transcription
- name: peoplespeech-clean-transcription
weight: 0.8
- name: commonvoice-en-transcription
weight: 0.8
- name: commonvoice-ar-transcription
weight: 0.02
- name: commonvoice-de-transcription
weight: 0.4
- name: commonvoice-es-transcription
weight: 0.3
- name: commonvoice-fr-transcription
weight: 0.4
- name: commonvoice-it-transcription
weight: 0.12
- name: commonvoice-ja-transcription
weight: 0.01
- name: commonvoice-pt-transcription
weight: 0.02
- name: commonvoice-ru-transcription
weight: 0.02

# Temporarily remove heysquad_human from val_sets as it causes the training to fail.
val_sets:
- name: covost2-en-de
- name: covost2-zh-en
- name: peoplespeech-clean-transcription

batch_size: 24
