diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ef13590e23f133..44305ee4dfbb22 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -607,13 +607,14 @@ def _set_signature_columns_if_needed(self): # Inspect model forward signature to keep only the arguments it accepts. signature = inspect.signature(self.model.forward) self._signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): if not self.args.remove_unused_columns: return dataset self._set_signature_columns_if_needed() - # Labels may be named label or label_ids, the default data collator handles that. - signature_columns = self._signature_columns + ["label", "label_ids"] + signature_columns = self._signature_columns ignored_columns = list(set(dataset.column_names) - set(signature_columns)) if len(ignored_columns) > 0: @@ -642,7 +643,7 @@ def _get_collator_with_removed_columns( if not self.args.remove_unused_columns: return data_collator self._set_signature_columns_if_needed() - signature_columns = self._signature_columns + self.label_names + signature_columns = self._signature_columns remove_columns_collator = RemoveColumnsCollator( data_collator=data_collator, diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 8c76efa65ccb27..e418009af09e53 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -658,7 +658,7 @@ class FSDPOption(ExplicitEnum): class RemoveColumnsCollator: - """Wrap the data collator to remove unused columns from its output.""" + """Wrap the data collator to remove unused columns before they are passed to the collator.""" def __init__( self, @@ -690,4 +690,5 @@ def _remove_columns(self, feature: dict) -> dict: return {k: v for k, v in feature.items() if k in self.signature_columns} def __call__(self, features: List[dict]): - return self._remove_columns(self.data_collator(features)) + features = [self._remove_columns(feature) for feature in features] + return self.data_collator(features)