Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions audiotools/core/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setup_whisper(
).to(self.whisper_device)
self.is_initialized = True

def get_whisper_features(self) -> torch.Tensor:
def get_whisper_features(self, **kwargs) -> torch.Tensor:
"""Preprocess audio signal as per the whisper model's training config.

Returns
Expand Down Expand Up @@ -49,11 +49,12 @@ def get_whisper_features(self) -> torch.Tensor:
raw_speech,
sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
return_tensors="pt",
**kwargs
).input_features

return input_features

def get_whisper_transcript(self) -> str:
def get_whisper_transcript(self, **kwargs) -> str:
"""Get the transcript of the audio signal using the whisper model.

Returns
Expand All @@ -69,12 +70,12 @@ def get_whisper_transcript(self) -> str:

with torch.inference_mode():
input_features = input_features.to(self.whisper_device)
generated_ids = self.whisper_model.generate(inputs=input_features)
generated_ids = self.whisper_model.generate(input_features=input_features, **kwargs)

transcription = self.whisper_processor.batch_decode(generated_ids)
transcription = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
return transcription[0]

def get_whisper_embeddings(self) -> torch.Tensor:
def get_whisper_embeddings(self, **kwargs) -> torch.Tensor:
"""Get the last hidden state embeddings of the audio signal using the whisper model.

Returns
Expand All @@ -92,6 +93,6 @@ def get_whisper_embeddings(self) -> torch.Tensor:

with torch.inference_mode():
input_features = input_features.to(self.whisper_device)
embeddings = encoder(input_features)
embeddings = encoder(input_features, **kwargs)

return embeddings.last_hidden_state