[Examples] Generalise Seq2Seq ASR to handle Whisper #19519

Merged · 13 commits · Nov 14, 2022
@@ -97,6 +97,22 @@ class ModelArguments:
     freeze_feature_encoder: bool = field(
         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
     )
+    freeze_encoder: bool = field(
+        default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
+    )
+    forced_decoder_ids: List[List[int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
+                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
+                "will always be a token of index 123."
+            )
+        },
+    )
+    suppress_tokens: List[int] = field(
+        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
+    )
 
 
 @dataclass
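
Note: a quick way to see the format these two flags expect — `get_decoder_prompt_ids` on the Whisper processor builds exactly this position-to-token mapping (checkpoint name and language here are placeholders, not part of the PR):

    from transformers import WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    # Returns (position, token_id) pairs, e.g. [(1, 50265), (2, 50359), (3, 50363)]
    print(processor.get_decoder_prompt_ids(language="french", task="transcribe"))

Each pair forces the given token id at the given generation index, which is how Whisper pins its language/task prefix before any freely generated tokens.
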
@@ -187,14 +203,27 @@ class DataTrainingArguments:
         default=True,
         metadata={"help": "Whether the target text should be lower cased."},
     )
+    language: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
+                "only. For English speech recognition, it should be set to `None`."
+            )
+        },
+    )
+    task: str = field(
+        default="transcribe",
+        metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
+    )
 
 
 @dataclass
 class DataCollatorSpeechSeq2SeqWithPadding:
     """
     Data collator that will dynamically pad the inputs received.
     Args:
-        processor ([`Wav2Vec2Processor`])
+        processor ([`WhisperProcessor`])
             The processor used for processing the data.
         decoder_start_token_id (`int`)
             The begin-of-sentence of the decoder.
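
Note on `--language`: the Whisper tokenizer accepts full language names and normalises them to its two-letter codes internally — a minimal sketch, assuming the `TO_LANGUAGE_CODE` table in current transformers:

    from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE

    # Full names are mapped to Whisper's language codes
    print(TO_LANGUAGE_CODE["french"])  # fr

so `--language french` and the `<|fr|>` prefix token refer to the same thing.
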
@@ -206,7 +235,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        model_input_name = self.processor.model_input_names[0]
+        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
         label_features = [{"input_ids": feature["labels"]} for feature in features]
 
         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
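
The rename works because every feature extractor advertises its primary input under `model_input_names` — a sketch of the two cases this collator now covers (checkpoint names are illustrative):

    from transformers import AutoFeatureExtractor

    wav2vec2 = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    whisper = AutoFeatureExtractor.from_pretrained("openai/whisper-small")
    print(wav2vec2.model_input_names[0])  # input_values
    print(whisper.model_input_names[0])   # input_features
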
@@ -333,6 +363,8 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
+    config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
+
     feature_extractor = AutoFeatureExtractor.from_pretrained(
         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
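
Worth noting: `PretrainedConfig.update` simply setattrs each key, so the line above is equivalent to the sketch below — and when neither CLI flag is passed, any `forced_decoder_ids`/`suppress_tokens` baked into the pretrained checkpoint are overwritten with `None`:

    # Equivalent effect of the update() call (sketch)
    config.forced_decoder_ids = model_args.forced_decoder_ids  # None clears the checkpoint default
    config.suppress_tokens = model_args.suppress_tokens
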
@@ -360,6 +392,14 @@ def main():
     if model_args.freeze_feature_encoder:
         model.freeze_feature_encoder()
 
+    if model_args.freeze_encoder:
+        model.freeze_encoder()
+        model.model.encoder.gradient_checkpointing = False
+
+    if data_args.language is not None:
+        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+
     # 6. Resample speech dataset if necessary
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
     if dataset_sampling_rate != feature_extractor.sampling_rate:
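
A sketch of what `set_prefix_tokens` changes for the encoded labels (checkpoint and text are placeholders):

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
    tokenizer.set_prefix_tokens(language="french", task="transcribe")
    ids = tokenizer("bonjour").input_ids
    print(tokenizer.convert_ids_to_tokens(ids)[:4])
    # ['<|startoftranscript|>', '<|fr|>', '<|transcribe|>', '<|notimestamps|>']

Gradient checkpointing is switched off for the frozen encoder since checkpointing only pays off when gradients flow through the module.
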
@@ -388,8 +428,8 @@ def prepare_dataset(batch):
         sample = batch[audio_column_name]
         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
         # process audio length
-        batch[model_input_name] = inputs.input_values[0]
-        batch["input_length"] = len(batch["input_values"])
+        batch[model_input_name] = inputs.get(model_input_name)[0]
+        batch["input_length"] = len(sample["array"])
 
         # process targets
         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
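
Why `input_length` now comes from the raw array: Whisper's feature extractor pads or truncates every clip to 30 seconds of log-Mel frames, so feature length says nothing about true duration. A sketch, assuming a 5-second clip at 16 kHz:

    import numpy as np
    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
    audio = np.zeros(5 * 16000, dtype=np.float32)  # hypothetical 5 s clip
    inputs = feature_extractor(audio, sampling_rate=16000)
    print(np.asarray(inputs.input_features[0]).shape)  # (80, 3000): always 30 s of frames
    print(len(audio))  # 80000 samples: the length the min/max duration filter needs
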
@@ -452,7 +492,8 @@ def compute_metrics(pred):
 
     # 10. Define data collator
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
-        processor=processor, decoder_start_token_id=model.config.decoder_start_token_id
+        processor=processor,
+        decoder_start_token_id=model.config.decoder_start_token_id,
     )
 
     # 11. Initialize Trainer
@@ -492,7 +533,9 @@ def compute_metrics(pred):
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate(
-            metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams
+            metric_key_prefix="eval",
+            max_length=training_args.generation_max_length,
+            num_beams=training_args.generation_num_beams,
         )
         max_eval_samples = (
             data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
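
With this change, evaluation follows the `Seq2SeqTrainingArguments` generation settings instead of whatever the checkpoint's config happens to carry. Illustrative values only (not defaults introduced by this PR):

    from transformers import Seq2SeqTrainingArguments

    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-small-fr",   # placeholder
        predict_with_generate=True,
        generation_max_length=225,         # read by the evaluate() call above
        generation_num_beams=1,
    )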