Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Nov 14, 2022
1 parent 6fa9c0e commit 0f83fb7
Showing 1 changed file with 43 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,19 @@ class ModelArguments:
freeze_encoder: bool = field(
default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
)
forced_decoder_ids: list = field(
default=None,
metadata={
"help": (
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
"will always be a token of index 123."
)
},
)
suppress_tokens: list = field(
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
)


@dataclass
Expand Down Expand Up @@ -190,6 +203,19 @@ class DataTrainingArguments:
default=True,
metadata={"help": "Whether the target text should be lower cased."},
)
language: str = field(
default=None,
metadata={
"help": (
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
"only. For English speech recognition, it should be set to `None`."
)
},
)
task: str = field(
default="transcribe",
metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
)


@dataclass
Expand All @@ -201,40 +227,30 @@ class DataCollatorSpeechSeq2SeqWithPadding:
The processor used for processing the data.
decoder_start_token_id (`int`)
The begin-of-sentence of the decoder.
eos_token_id (`int`)
The end-of-sentence of the model.
model_input_name (`str`)
Name of the pre-processed audio inputs expected by the model.
"""

processor: Any
decoder_start_token_id: int
eos_token_id: int
model_input_name: str

def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lenghts and need different padding methods
# first treat the audio inputs by padding to max length
input_features = [{self.model_input_name: feature[self.model_input_name]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# split inputs and labels since they have to be of different lengths and need
# different padding methods
model_input_name = self.processor.model_input_names[0]
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]

# now handle the target labels
for feature in features:
# if bos token is prepended in previous tokenization step,
# cut bos token here as it's prepended later anyways
if feature["labels"][0] == self.decoder_start_token_id:
feature["labels"] = feature["labels"][1:]
# if eos token is not appended in previous tokenization step,
# append eos token here as it's not appended later
if feature["labels"][-1] != self.eos_token_id and self.eos_token_id is not None:
feature["labels"].append(self.eos_token_id)
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

label_features = [{"input_ids": feature["labels"]} for feature in features]
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
labels = labels[:, 1:]

batch["labels"] = labels

return batch
Expand Down Expand Up @@ -347,6 +363,8 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)

config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})

feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
Expand Down Expand Up @@ -378,6 +396,10 @@ def main():
model.freeze_encoder()
model.model.encoder.gradient_checkpointing = False

if data_args.language is not None:
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)

# 6. Resample speech dataset if necessary
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
if dataset_sampling_rate != feature_extractor.sampling_rate:
Expand Down Expand Up @@ -472,8 +494,6 @@ def compute_metrics(pred):
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
eos_token_id=model.config.eos_token_id,
model_input_name=model_input_name,
)

# 11. Initialize Trainer
Expand Down

0 comments on commit 0f83fb7

Please sign in to comment.