Ensure tensors are at least 1d for pad and concat (huggingface#17179)
* Ensure tensors are at least 1d for pad and concat

* Compatibility

* Fix

* Fix

* Add test

* Retrigger CI

* Consistency with master

* Retrigger CI
Yard1 authored May 11, 2022
1 parent c76afa5 commit 47412c7
Showing 2 changed files with 34 additions and 5 deletions.
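
For context on the fix: during evaluation, the Trainer gathers per-step outputs and concatenates them, and torch.cat raises a RuntimeError on zero-dimensional (scalar) tensors. A minimal sketch of the failure mode being addressed, assuming a 0-d tensor reaches the concat path:

    import torch

    a = torch.tensor(1.0)  # 0-d (scalar) tensor, e.g. a reduced loss or metric
    b = torch.tensor(2.0)
    # torch.cat((a, b), dim=0)             # RuntimeError: zero-dimensional tensor(s) cannot be concatenated
    torch.cat((a[None], b[None]), dim=0)   # tensor([1., 2.]) -- promoting to 1d first avoids the error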
25 changes: 20 additions & 5 deletions src/transformers/trainer_pt_utils.py
@@ -55,8 +55,22 @@
 logger = logging.get_logger(__name__)


+def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
+    if isinstance(tensor_or_array, torch.Tensor):
+        if hasattr(torch, "atleast_1d"):
+            tensor_or_array = torch.atleast_1d(tensor_or_array)
+        elif tensor_or_array.ndim < 1:
+            tensor_or_array = tensor_or_array[None]
+    else:
+        tensor_or_array = np.atleast_1d(tensor_or_array)
+    return tensor_or_array
+
+
 def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
     """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
+    tensor1 = atleast_1d(tensor1)
+    tensor2 = atleast_1d(tensor2)
+
     if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
         return torch.cat((tensor1, tensor2), dim=0)

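The new helper dispatches to torch.atleast_1d or np.atleast_1d, with a manual None-indexing fallback for older PyTorch releases that lack torch.atleast_1d (the hasattr guard). A minimal sketch of its contract, assuming it is imported from transformers.trainer_pt_utils:

    import numpy as np
    import torch
    from transformers.trainer_pt_utils import atleast_1d

    atleast_1d(torch.tensor(3.0)).shape  # torch.Size([1]) -- 0-d tensors are promoted to 1d
    atleast_1d(torch.ones(2, 3)).shape   # torch.Size([2, 3]) -- >=1-d inputs pass through unchanged
    atleast_1d(np.float64(3.0)).shape    # (1,) -- non-torch inputs are routed to np.atleast_1d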
@@ -72,6 +86,9 @@ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):

 def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
     """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
+    array1 = atleast_1d(array1)
+    array2 = atleast_1d(array2)
+
     if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
         return np.concatenate((array1, array2), axis=0)

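Both pad-and-concat helpers take the early-return path above when the second dimensions already agree; otherwise they pad the narrower input with padding_index before concatenating (that branch sits below the hunk boundary). A sketch of the expected behavior with the default padding_index of -100:

    import numpy as np
    from transformers.trainer_pt_utils import numpy_pad_and_concatenate

    a = np.ones((2, 3))
    b = np.ones((1, 5))
    out = numpy_pad_and_concatenate(a, b)
    print(out.shape)   # (3, 5) -- concatenated on the first axis
    print(out[0, 3:])  # [-100. -100.] -- rows from `a` padded up to width 5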
@@ -149,8 +166,7 @@ def nested_xla_mesh_reduce(tensors, name):

         if isinstance(tensors, (list, tuple)):
             return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
-        if tensors.ndim == 0:
-            tensors = tensors[None]
+        tensors = atleast_1d(tensors)
         return xm.mesh_reduce(name, tensors, torch.cat)
     else:
         raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
@@ -160,8 +176,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
-        if len(tensor.shape) <= 0:
-            tensor = tensor[None]
+        tensor = atleast_1d(tensor)
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
         dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
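The promotion matters here because each rank's tensor is gathered with dist.all_gather and then concatenated with torch.cat, which rejects 0-d inputs. A hypothetical single-process sketch of that sequence (gloo backend, world_size=1, address and port chosen arbitrarily):

    import torch
    import torch.distributed as dist

    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    scalar = torch.tensor(3.0)               # 0-d tensor produced on this rank
    tensor = torch.atleast_1d(scalar)        # the promotion distributed_concat now performs
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    print(torch.cat(output_tensors, dim=0))  # tensor([3.])
    dist.destroy_process_group()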
@@ -1031,7 +1046,7 @@ def smp_gather(tensor):
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
)
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
all_tensors = [atleast_1d(t) for t in all_tensors]
return torch.cat([t.cpu() for t in all_tensors], dim=0)

def smp_nested_concat(tensor):
14 changes: 14 additions & 0 deletions tests/trainer/test_trainer_utils.py
@@ -41,6 +41,8 @@
         SequentialDistributedSampler,
         ShardSampler,
         get_parameter_names,
+        numpy_pad_and_concatenate,
+        torch_pad_and_concatenate,
     )

     class TstLayer(nn.Module):
@@ -459,6 +461,18 @@ def mock_training_loop_function(batch_size):
             mock_training_loop_function()
         self.assertEqual("CUDA out of memory", cm.args[0])

+    def test_pad_and_concatenate_with_1d(self):
+        """Tests whether pad_and_concatenate works with scalars."""
+        array1 = 1.0
+        array2 = 2.0
+        result = numpy_pad_and_concatenate(array1, array2)
+        self.assertTrue(np.array_equal(np.array([1.0, 2.0]), result))
+
+        tensor1 = torch.tensor(1.0)
+        tensor2 = torch.tensor(2.0)
+        result = torch_pad_and_concatenate(tensor1, tensor2)
+        self.assertTrue(torch.equal(result, torch.Tensor([1.0, 2.0])))
+
     def test_remove_columns_collator(self):
         class MockLogger:
             def __init__(self) -> None:
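To run just the new test from a development checkout, something like the following should work (assuming pytest is installed):

    python -m pytest tests/trainer/test_trainer_utils.py -k test_pad_and_concatenate_with_1d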
