From 0f83fb79ed82a450e1281cac3e62af6d931da9e9 Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Mon, 14 Nov 2022 11:37:49 +0000
Subject: [PATCH] final fixes

---
 .../run_speech_recognition_seq2seq.py | 66 ++++++++++++-------
 1 file changed, 43 insertions(+), 23 deletions(-)

diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
index 33c1d20bbee934..6056720a6f025f 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
@@ -100,6 +100,19 @@ class ModelArguments:
     freeze_encoder: bool = field(
         default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
     )
+    forced_decoder_ids: list = field(
+        default=None,
+        metadata={
+            "help": (
+                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
+                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
+                "will always be a token of index 123."
+            )
+        },
+    )
+    suppress_tokens: list = field(
+        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
+    )
 
 
 @dataclass
@@ -190,6 +203,19 @@ class DataTrainingArguments:
         default=True,
         metadata={"help": "Whether the target text should be lower cased."},
     )
+    language: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
+                "only. For English speech recognition, it should be set to `None`."
+            )
+        },
+    )
+    task: str = field(
+        default="transcribe",
+        metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
+    )
 
 
 @dataclass
@@ -201,40 +227,30 @@ class DataCollatorSpeechSeq2SeqWithPadding:
             The processor used for processing the data.
         decoder_start_token_id (`int`)
             The begin-of-sentence of the decoder.
-        eos_token_id (`int`)
-            The end-of-sentence of the model.
-        model_input_name (`str`)
-            Name of the pre-processed audio inputs expected by the model.
     """
 
     processor: Any
     decoder_start_token_id: int
-    eos_token_id: int
-    model_input_name: str
 
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-        # split inputs and labels since they have to be of different lenghts and need different padding methods
-        # first treat the audio inputs by padding to max length
-        input_features = [{self.model_input_name: feature[self.model_input_name]} for feature in features]
-        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
+        # split inputs and labels since they have to be of different lengths and need
+        # different padding methods
+        model_input_name = self.processor.model_input_names[0]
+        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
+        label_features = [{"input_ids": feature["labels"]} for feature in features]
 
-        # now handle the target labels
-        for feature in features:
-            # if bos token is prepended in previous tokenization step,
-            # cut bos token here as it's prepended later anyways
-            if feature["labels"][0] == self.decoder_start_token_id:
-                feature["labels"] = feature["labels"][1:]
-            # if eos token is not appended in previous tokenization step,
-            # append eos token here as it's not appended later
-            if feature["labels"][-1] != self.eos_token_id and self.eos_token_id is not None:
-                feature["labels"].append(self.eos_token_id)
+        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
 
-        label_features = [{"input_ids": feature["labels"]} for feature in features]
         labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
 
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
 
+        # if bos token is prepended in previous tokenization step,
+        # cut bos token here as it's prepended later anyways
+        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+            labels = labels[:, 1:]
+
         batch["labels"] = labels
 
         return batch
@@ -347,6 +363,8 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
+    config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
+
     feature_extractor = AutoFeatureExtractor.from_pretrained(
         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
@@ -378,6 +396,10 @@ def main():
         model.freeze_encoder()
         model.model.encoder.gradient_checkpointing = False
 
+    if data_args.language is not None:
+        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+
     # 6. Resample speech dataset if necessary
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
     if dataset_sampling_rate != feature_extractor.sampling_rate:
@@ -472,8 +494,6 @@ def compute_metrics(pred):
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
         processor=processor,
         decoder_start_token_id=model.config.decoder_start_token_id,
-        eos_token_id=model.config.eos_token_id,
-        model_input_name=model_input_name,
    )
 
     # 11. Initialize Trainer