Skip to content

Commit f7f93e7

Browse files
thedebugger and eustlb
authored and committed
Fix device mismatch error in Whisper model during feature extraction (huggingface#35866)
* Fix device mismatch error in whisper feature extraction * Set default device * Address code review feedback --------- Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
1 parent 2f1864d commit f7f93e7

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

src/transformers/models/whisper/feature_extraction_whisper.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,13 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu")
129129
Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
130130
yielding results similar to cpu computing with 1e-5 tolerance.
131131
"""
132-
waveform = torch.from_numpy(waveform).type(torch.float32)
132+
waveform = torch.from_numpy(waveform).to(device, torch.float32)
133+
window = torch.hann_window(self.n_fft, device=device)
133134

134-
window = torch.hann_window(self.n_fft)
135-
if device != "cpu":
136-
waveform = waveform.to(device)
137-
window = window.to(device)
138135
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
139136
magnitudes = stft[..., :-1].abs() ** 2
140137

141-
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
142-
if device != "cpu":
143-
mel_filters = mel_filters.to(device)
138+
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
144139
mel_spec = mel_filters.T @ magnitudes
145140

146141
log_spec = torch.clamp(mel_spec, min=1e-10).log10()

tests/models/whisper/test_feature_extraction_whisper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,9 @@ def test_torch_integration_batch(self):
298298
)
299299
# fmt: on
300300

301-
input_speech = self._load_datasamples(3)
302-
feature_extractor = WhisperFeatureExtractor()
303-
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
301+
with torch.device("cuda"):
302+
input_speech = self._load_datasamples(3)
303+
feature_extractor = WhisperFeatureExtractor()
304+
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
304305
self.assertEqual(input_features.shape, (3, 80, 3000))
305306
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)

0 commit comments

Comments (0)