Skip to content

Commit 186aa6b

Browse files
[Whisper] Fix audio classification with weighted layer sum (#28563)
* fix
* tests
* fix test
1 parent 619ecfe commit 186aa6b

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/transformers/models/whisper/modeling_whisper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757

5858
logger = logging.get_logger(__name__)
5959

60+
_HIDDEN_STATES_START_POSITION = 1
61+
6062
_CONFIG_FOR_DOC = "WhisperConfig"
6163
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
6264

@@ -2957,6 +2959,11 @@ def forward(
29572959
output_hidden_states = (
29582960
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
29592961
)
2962+
if self.config.use_weighted_layer_sum:
2963+
output_hidden_states = True
2964+
elif output_hidden_states is None:
2965+
output_hidden_states = self.config.output_hidden_states
2966+
29602967
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
29612968

29622969
if encoder_outputs is None:
@@ -2969,7 +2976,8 @@ def forward(
29692976
)
29702977

29712978
if self.config.use_weighted_layer_sum:
2972-
hidden_states = torch.stack(encoder_outputs, dim=1)
2979+
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
2980+
hidden_states = torch.stack(hidden_states, dim=1)
29732981
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
29742982
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
29752983
else:

tests/models/whisper/test_modeling_whisper.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,16 +2292,15 @@ def get_subsampled_output_lengths(self, input_lengths):
22922292
def encoder_seq_length(self):
22932293
return self.get_subsampled_output_lengths(self.seq_length)
22942294

2295-
def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
2296-
model = WhisperForAudioClassification(config=config).to(torch_device).eval()
2297-
2298-
if freeze_encoder:
2299-
model.freeze_encoder()
2295+
def create_and_check_model_forward(self, config, inputs_dict, use_weighted_layer_sum=False):
2296+
config.use_weighted_layer_sum = use_weighted_layer_sum
2297+
model = WhisperForAudioClassification(config=config)
2298+
model.to(torch_device).eval()
23002299

23012300
input_features = inputs_dict["input_features"]
23022301

2303-
# first forward pass
2304-
last_hidden_state = model(input_features).logits
2302+
with torch.no_grad():
2303+
last_hidden_state = model(input_features).logits
23052304

23062305
self.parent.assertTrue(last_hidden_state.shape, (13, 2))
23072306

@@ -2336,6 +2335,14 @@ def test_forward_signature(self):
23362335
expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
23372336
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
23382337

2338+
def test_forward_pass(self):
2339+
config_and_inputs = self.model_tester.prepare_config_and_inputs()
2340+
self.model_tester.create_and_check_model_forward(*config_and_inputs)
2341+
2342+
def test_forward_pass_weighted_layer_sum(self):
2343+
config_and_inputs = self.model_tester.prepare_config_and_inputs()
2344+
self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)
2345+
23392346
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
23402347
def test_cpu_offload(self):
23412348
pass

0 commit comments

Comments (0)