Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device mismatch error in Whisper model during feature extraction #35866

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,13 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu")
Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
yielding results similar to cpu computing with 1e-5 tolerance.
"""
waveform = torch.from_numpy(waveform).type(torch.float32)
waveform = torch.from_numpy(waveform).to(device, torch.float32)
window = torch.hann_window(self.n_fft, device=device)

window = torch.hann_window(self.n_fft)
if device != "cpu":
waveform = waveform.to(device)
window = window.to(device)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
if device != "cpu":
mel_filters = mel_filters.to(device)
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
mel_spec = mel_filters.T @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
Expand Down
7 changes: 4 additions & 3 deletions tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,9 @@ def test_torch_integration_batch(self):
)
# fmt: on

input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
with torch.device("cuda"):
input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEqual(input_features.shape, (3, 80, 3000))
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)