Skip to content

Commit

Permalink
Remove columns before passing to data collator (huggingface#17187)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored May 11, 2022
1 parent 934e21c commit 7b95825
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,13 +607,14 @@ def _set_signature_columns_if_needed(self):
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return dataset
self._set_signature_columns_if_needed()
# Labels may be named label or label_ids, the default data collator handles that.
signature_columns = self._signature_columns + ["label", "label_ids"]
signature_columns = self._signature_columns

ignored_columns = list(set(dataset.column_names) - set(signature_columns))
if len(ignored_columns) > 0:
Expand Down Expand Up @@ -642,7 +643,7 @@ def _get_collator_with_removed_columns(
if not self.args.remove_unused_columns:
return data_collator
self._set_signature_columns_if_needed()
signature_columns = self._signature_columns + self.label_names
signature_columns = self._signature_columns

remove_columns_collator = RemoveColumnsCollator(
data_collator=data_collator,
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ class FSDPOption(ExplicitEnum):


class RemoveColumnsCollator:
"""Wrap the data collator to remove unused columns from its output."""
"""Wrap the data collator to remove unused columns before they are passed to the collator."""

def __init__(
self,
Expand Down Expand Up @@ -690,4 +690,5 @@ def _remove_columns(self, feature: dict) -> dict:
return {k: v for k, v in feature.items() if k in self.signature_columns}

def __call__(self, features: List[dict]):
return self._remove_columns(self.data_collator(features))
features = [self._remove_columns(feature) for feature in features]
return self.data_collator(features)

0 comments on commit 7b95825

Please sign in to comment.