Fix DataCollatorForWholeWordMask #8379

Merged 2 commits on Nov 7, 2020
24 changes: 3 additions & 21 deletions src/transformers/data/data_collator.py
@@ -315,7 +315,7 @@ def __call__(
             input_ids = examples
             examples = [{"input_ids": e} for e in examples]
 
-        batch_input = self._tensorize_batch(input_ids)
+        batch_input = _collate_batch(input_ids, self.tokenizer)
 
         mask_labels = []
         for e in examples:
@@ -332,7 +332,7 @@ def __call__(
                     if i in ref_pos:
                         ref_tokens[i] = "##" + ref_tokens[i]
             mask_labels.append(self._whole_word_mask(ref_tokens))
-        batch_mask = self._tensorize_batch(mask_labels)
+        batch_mask = _collate_batch(mask_labels, self.tokenizer)
         inputs, labels = self.mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}

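With both call sites above routed through the shared collation helper, `DataCollatorForWholeWordMask` can be handed features of unequal length directly. A minimal usage sketch follows; the checkpoint name and sentences are illustrative only (not from this PR), and a BERT-style WordPiece tokenizer with a pad token is assumed:

```python
import torch
from transformers import BertTokenizer, DataCollatorForWholeWordMask

# Illustrative checkpoint; the collator relies on BERT-style "##" sub-word prefixes.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

# Two features of different lengths, as a dataset without padding would produce them.
features = [
    {"input_ids": torch.tensor(tokenizer("whole word masking")["input_ids"])},
    {"input_ids": torch.tensor(tokenizer("a longer example sentence to be padded")["input_ids"])},
]

batch = collator(features)
print(batch["input_ids"].shape)  # (2, length of the longest feature), right-padded with the pad token id
print(batch["labels"].shape)     # same shape; -100 everywhere except the masked word positions
```

Previously these calls went through a `_tensorize_batch` method on the collator itself; the hunk below removes that method in favour of the shared helper.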
@@ -511,28 +511,10 @@ def __call__(
     ) -> Dict[str, torch.Tensor]:
         if isinstance(examples[0], (dict, BatchEncoding)):
             examples = [e["input_ids"] for e in examples]
-        batch = self._tensorize_batch(examples)
+        batch = _collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
 
-    def _tensorize_batch(
-        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
-    ) -> torch.Tensor:
-        # In order to accept both lists of lists and lists of Tensors
-        if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-
     def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
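The removed `_tensorize_batch` method duplicated padding logic on the collator class; the call sites above now go through the module-level `_collate_batch(examples, tokenizer)` helper shared with the other collators in `data_collator.py`, which is why the tokenizer is passed explicitly. A rough sketch of the pad-or-stack behaviour such a helper provides, assuming it mirrors the deleted method (the function below is illustrative, not the library's actual implementation):

```python
from typing import List, Union

import torch


def collate_batch(examples: List[Union[List[int], torch.Tensor]], tokenizer) -> torch.Tensor:
    """Stack examples of equal length; otherwise right-pad them with the tokenizer's pad token."""
    # Accept both lists of ints and 1-D tensors.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    # Fast path: every example already has the same length, so a plain stack suffices.
    if all(x.size(0) == examples[0].size(0) for x in examples):
        return torch.stack(examples, dim=0)

    if tokenizer.pad_token_id is None:
        raise ValueError(
            f"Padding is required but {tokenizer.__class__.__name__} has no pad token."
        )

    # Right-pad into a (batch_size, max_length) tensor filled with the pad token id.
    max_length = max(x.size(0) for x in examples)
    result = examples[0].new_full((len(examples), max_length), tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        result[i, : example.size(0)] = example
    return result
```

Keeping the padding in one free function rather than re-implementing `_tensorize_batch` on each collator also explains the changed call signature: the tokenizer has to be passed in because the helper no longer has access to `self`.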