Skip to content

Commit

Permalink
Data collator for token classification pads labels column when receiv…
Browse files Browse the repository at this point in the history
…es pytorch tensors (#20244)

* token cls data_collator pads labels column

* remove walrus operator for code quality

* remove redundat space

* remove comment that was fixed

* PR comments fix

Co-authored-by: Alexander Markov <amarkov.me@gmail.com>
  • Loading branch information
Alexander Markov and markovalexander authored Nov 16, 2022
1 parent d4d2314 commit 610acc5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,30 +305,38 @@ def torch_call(self, features):

label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

batch = self.tokenizer.pad(
features,
no_labels_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
return_tensors="pt",
)

if labels is None:
return batch

sequence_length = torch.tensor(batch["input_ids"]).shape[1]
sequence_length = batch["input_ids"].shape[1]
padding_side = self.tokenizer.padding_side

def to_list(tensor_or_iterable):
if isinstance(tensor_or_iterable, torch.Tensor):
return tensor_or_iterable.tolist()
return list(tensor_or_iterable)

if padding_side == "right":
batch[label_name] = [
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]
else:
batch[label_name] = [
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
]

batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
return batch

def tf_call(self, features):
Expand Down
45 changes: 45 additions & 0 deletions tests/trainer/test_data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,51 @@ def test_data_collator_for_token_classification(self):
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)

for feature in features:
feature.pop("labels")

batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)

def test_data_collator_for_token_classification_works_with_pt_tensors(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
{"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([0, 1, 2])},
{"input_ids": torch.tensor([0, 1, 2, 3, 4, 5]), "labels": torch.tensor([0, 1, 2, 3, 4, 5])},
]

data_collator = DataCollatorForTokenClassification(tokenizer)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)

data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))

data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))

data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)

for feature in features:
feature.pop("labels")

batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)

def _test_no_pad_and_pad(self, no_pad_features, pad_features):
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
Expand Down

0 comments on commit 610acc5

Please sign in to comment.