Skip to content

Commit

Permalink
Make Trainer evaluation handle dynamic seq_length (huggingface#8336)
Browse files Browse the repository at this point in the history
* Make Trainer evaluation handle dynamic seq_length

* Document behavior.

* Fix test

* Better fix

* Fixes for realsies this time

* Address review comments

* Without forgetting to save...
  • Loading branch information
sgugger authored and fabiocapsouza committed Nov 15, 2020
1 parent 121c0e4 commit b8f23eb
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 13 deletions.
10 changes: 8 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,12 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
.. note::
If your predictions or labels have different sequence length (for instance because you're doing dynamic
padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.
Returns: `NamedTuple` A namedtuple with the following keys:
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
Expand Down Expand Up @@ -1412,9 +1418,9 @@ def prediction_loop(
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None:
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
Expand Down
84 changes: 74 additions & 10 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,50 @@
logger = logging.get_logger(__name__)


def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
return torch.cat((tensor1, tensor2), dim=0)

# Let's figure out the new shape
new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]

# Now let's fill the result tensor
result = tensor1.new_full(new_shape, padding_index)
result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
return result


def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
return np.concatenate((array1, array2), dim=0)

# Let's figure out the new shape
new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]

# Now let's fill the result tensor
result = np.full_like(array1, padding_index, shape=new_shape)
result[: array1.shape[0], : array1.shape[1]] = array1
result[array1.shape[0] :, : array2.shape[1]] = array2
return result


def nested_concat(tensors, new_tensors, padding_index=-100):
"""
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
nested list/tuples of tensors.
"""
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
elif isinstance(tensors, torch.Tensor):
return torch.cat((tensors, new_tensors), dim=dim)
return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
elif isinstance(tensors, np.ndarray):
return np.concatenate((tensors, new_tensors), axis=dim)
return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
else:
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

Expand Down Expand Up @@ -190,11 +223,21 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def nested_new_like(arrays, num_samples):
def nested_new_like(arrays, num_samples, padding_index=-100):
""" Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
if isinstance(arrays, (list, tuple)):
return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype)
return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))


def nested_expand_like(arrays, new_seq_length, padding_index=-100):
""" Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
if isinstance(arrays, (list, tuple)):
return type(arrays)(nested_expand_like(x, new_seq_length, padding_index=padding_index) for x in arrays)

result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
result[:, : arrays.shape[1]] = arrays
return result


def nested_truncate(tensors, limit):
Expand All @@ -204,6 +247,13 @@ def nested_truncate(tensors, limit):
return tensors[:limit]


def _get_first_shape(arrays):
"""Return the shape of the first array found in the nested struct `arrays`."""
if isinstance(arrays, (list, tuple)):
return _get_first_shape(arrays[0])
return arrays.shape


class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
Expand Down Expand Up @@ -247,16 +297,19 @@ class DistributedTensorGatherer:
make_multiple_of (:obj:`int`, `optional`):
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
(by adding samples).
padding_index (:obj:`int`, `optional`, defaults to -100):
The padding index to use if the arrays don't all have the same sequence length.
"""

def __init__(self, world_size, num_samples, make_multiple_of=None):
def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
self.world_size = world_size
self.num_samples = num_samples
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
self.process_length = self.total_samples // world_size
self._storage = None
self._offsets = None
self.padding_index = padding_index

def add_arrays(self, arrays):
"""
Expand All @@ -266,8 +319,14 @@ def add_arrays(self, arrays):
if arrays is None:
return
if self._storage is None:
self._storage = nested_new_like(arrays, self.total_samples)
self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
self._offsets = list(range(0, self.total_samples, self.process_length))
else:
storage_shape = _get_first_shape(self._storage)
arrays_shape = _get_first_shape(arrays)
if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]:
# If we get new arrays that are too big too fit, we expand the shape fo the storage
self._storage = nested_expand_like(self._storage, arrays_shape[1], padding_index=self.padding_index)
slice_len = self._nested_set_tensors(self._storage, arrays)
for i in range(self.world_size):
self._offsets[i] += slice_len
Expand All @@ -283,7 +342,12 @@ def _nested_set_tensors(self, storage, arrays):

slice_len = arrays.shape[0] // self.world_size
for i in range(self.world_size):
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
if len(arrays.shape) == 1:
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
else:
storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
i * slice_len : (i + 1) * slice_len
]
return slice_len

def finalize(self):
Expand Down
54 changes: 53 additions & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ def __getitem__(self, i):
return result


class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length
np.random.seed(seed)
sizes = np.random.randint(1, 20, (length // batch_size,))
# For easy batching, we make every batch_size consecutive samples the same size.
self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_x": self.xs[i], "labels": self.ys[i]}


class AlmostAccuracy:
def __init__(self, thresh=0.25):
self.thresh = thresh
Expand Down Expand Up @@ -282,7 +298,7 @@ def test_train_and_eval_dataloaders(self):
self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu))
self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu))

# Check passing a new dataset for evaluation wors
# Check passing a new dataset for evaluation works
new_eval_dataset = RegressionDataset(length=128)
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))

Expand Down Expand Up @@ -340,6 +356,42 @@ def test_predict(self):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

def test_dynamic_shapes(self):
eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
model = RegressionModel(a=2, b=1)
args = TrainingArguments("./regression")
trainer = Trainer(model, args, eval_dataset=eval_dataset)

# Check evaluation can run to completion
_ = trainer.evaluate()

# Check predictions
preds = trainer.predict(eval_dataset)
for expected, seen in zip(eval_dataset.ys, preds.label_ids):
self.assertTrue(np.array_equal(expected, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

for expected, seen in zip(eval_dataset.xs, preds.predictions):
self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

# Same tests with eval accumulation
args = TrainingArguments("./regression", eval_accumulation_steps=2)
trainer = Trainer(model, args, eval_dataset=eval_dataset)

# Check evaluation can run to completion
_ = trainer.evaluate()

# Check predictions
preds = trainer.predict(eval_dataset)
for expected, seen in zip(eval_dataset.ys, preds.label_ids):
self.assertTrue(np.array_equal(expected, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

for expected, seen in zip(eval_dataset.xs, preds.predictions):
self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

@require_datasets
def test_trainer_with_datasets(self):
import datasets
Expand Down

0 comments on commit b8f23eb

Please sign in to comment.