@@ -2292,16 +2292,15 @@ def get_subsampled_output_lengths(self, input_lengths):
22922292 def encoder_seq_length (self ):
22932293 return self .get_subsampled_output_lengths (self .seq_length )
22942294
2295- def create_and_check_model_forward (self , config , inputs_dict , freeze_encoder = False ):
2296- model = WhisperForAudioClassification (config = config ).to (torch_device ).eval ()
2297-
2298- if freeze_encoder :
2299- model .freeze_encoder ()
2295+ def create_and_check_model_forward (self , config , inputs_dict , use_weighted_layer_sum = False ):
2296+ config .use_weighted_layer_sum = use_weighted_layer_sum
2297+ model = WhisperForAudioClassification (config = config )
2298+ model .to (torch_device ).eval ()
23002299
23012300 input_features = inputs_dict ["input_features" ]
23022301
2303- # first forward pass
2304- last_hidden_state = model (input_features ).logits
2302+ with torch . no_grad ():
2303+ last_hidden_state = model (input_features ).logits
23052304
23062305 self .parent .assertTrue (last_hidden_state .shape , (13 , 2 ))
23072306
@@ -2336,6 +2335,14 @@ def test_forward_signature(self):
23362335 expected_arg_names = ["input_features" , "head_mask" , "encoder_outputs" ]
23372336 self .assertListEqual (arg_names [: len (expected_arg_names )], expected_arg_names )
23382337
2338+ def test_forward_pass (self ):
2339+ config_and_inputs = self .model_tester .prepare_config_and_inputs ()
2340+ self .model_tester .create_and_check_model_forward (* config_and_inputs )
2341+
2342+ def test_forward_pass_weighted_layer_sum (self ):
2343+ config_and_inputs = self .model_tester .prepare_config_and_inputs ()
2344+ self .model_tester .create_and_check_model_forward (* config_and_inputs , use_weighted_layer_sum = True )
2345+
23392346 @unittest .skip (reason = "Some undefined behavior encountered with tiny versions of this model. Skip for now." )
23402347 def test_cpu_offload (self ):
23412348 pass
0 commit comments