Merge pull request #673 from helmholtz-analytics/features/178-tile
Features/178 tile
coquelin77 authored Aug 20, 2021
2 parents e0af0e5 + d4f450e commit e2f75c3
Showing 3 changed files with 336 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -224,6 +224,8 @@ Example on 2 processes:
- [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: distributed `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer`
- [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for `diff()`.
- [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter
- [#673](https://github.com/helmholtz-analytics/heat/pull/673) New feature: distributed `tile`
- [#674](https://github.com/helmholtz-analytics/heat/pull/674) New feature: `repeat`
- [#670](https://github.com/helmholtz-analytics/heat/pull/670) New Feature: distributed `bincount()`
- [#672](https://github.com/helmholtz-analytics/heat/pull/672) Bug / Enhancement: Remove `MPIRequest.wait()`, rewrite calls with capital letters. lower case `wait()` now falls back to the `mpi4py` function
257 changes: 257 additions & 0 deletions heat/core/manipulations.py
@@ -52,6 +52,7 @@
"squeeze",
"stack",
"swapaxes",
"tile",
"topk",
"unique",
"vsplit",
@@ -3596,6 +3597,262 @@ def vstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
return concatenate(arrays, axis=0)


def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
"""
Construct a new DNDarray by repeating 'x' the number of times given by 'reps'.
If 'reps' has length 'd', the result will have 'max(d, x.ndim)' dimensions:
- if 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes.
So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3)
for 3-D replication (if this is not the desired behavior, promote 'x' to d-dimensions
manually before calling this function);
- if 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if
      'x.shape' is (2, 3, 4, 5), a 'reps' of (2, 2) will be expanded to (1, 1, 2, 2).

    Parameters
----------
x : DNDarray
Input
    reps : Sequence[int, ...]
        Repetitions

    Returns
-------
tiled : DNDarray
Split semantics: if `x` is distributed, the tiled data will be distributed along the
same dimension. Note that nominally `tiled.split != x.split` in the case where
        `len(reps) > x.ndim`. See example below.

    Examples
--------
>>> x = ht.arange(12).reshape((4,3)).resplit_(0)
>>> x
DNDarray([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]], dtype=ht.int32, device=cpu:0, split=0)
>>> reps = (1, 2, 2)
>>> tiled = ht.tile(x, reps)
>>> tiled
DNDarray([[[ 0, 1, 2, 0, 1, 2],
[ 3, 4, 5, 3, 4, 5],
[ 6, 7, 8, 6, 7, 8],
[ 9, 10, 11, 9, 10, 11],
[ 0, 1, 2, 0, 1, 2],
[ 3, 4, 5, 3, 4, 5],
[ 6, 7, 8, 6, 7, 8],
[ 9, 10, 11, 9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1)
"""
# x can be DNDarray or scalar
try:
_ = x.larray
except AttributeError:
try:
_ = x.shape
raise TypeError("Input can be a DNDarray or a scalar, is {}".format(type(x)))
except AttributeError:
x = factories.array(x).reshape(1)

x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
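    # Illustrative sketch of the stride-0 proxy above (hypothetical shapes): the
    # `as_strided` view reports the *global* shape while referencing a single
    # element, so the output shape and the validity of `reps` can be checked with
    # plain torch calls without touching the distributed data, e.g.
    #   proxy = torch.ones((1,)).as_strided((4, 3), [0, 0])  # reports shape (4, 3)
    #   tuple(proxy.repeat((1, 2, 2)).shape)                  # -> (1, 8, 6)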

# torch-proof args/kwargs:
# torch `reps`: int or sequence of ints; numpy `reps`: can be array-like
try:
_ = x_proxy.repeat(reps)
except TypeError:
# `reps` is array-like or contains non-int elements
try:
reps = resplit(reps, None).tolist()
except AttributeError:
try:
reps = reps.tolist()
except AttributeError:
try:
_ = x_proxy.repeat(reps)
except TypeError:
raise TypeError(
"reps must be a sequence of ints, got {}".format(
list(type(i) for i in reps)
)
)
except RuntimeError:
pass
except RuntimeError:
pass

try:
reps = list(reps)
except TypeError:
# scalar to list
reps = [reps]

# torch reps vs. numpy reps: dimensions
if len(reps) != x.ndim:
added_dims = abs(len(reps) - x.ndim)
if len(reps) > x.ndim:
new_shape = added_dims * (1,) + x.gshape
new_split = None if x.split is None else x.split + added_dims
x = x.reshape(new_shape, new_split=new_split)
else:
reps = added_dims * [1] + reps
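    # Illustrative examples of the reconciliation above (hypothetical shapes):
    #   x.gshape == (4, 3),       reps == (2, 2, 2) -> x is reshaped to (1, 4, 3)
    #   x.gshape == (2, 3, 4, 5), reps == (2, 2)    -> reps becomes [1, 1, 2, 2]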

out_gshape = tuple(x_proxy.repeat(reps).shape)

if not x.is_distributed() or reps[x.split] == 1:
# no repeats along the split axis: local operation
t_tiled = x.larray.repeat(reps)
return DNDarray(
t_tiled,
out_gshape,
dtype=x.dtype,
split=x.split,
device=x.device,
comm=x.comm,
balanced=x.balanced,
)
# repeats along the split axis, work along dim 0
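    # Illustrative sketch, assuming 2 processes: for x of global shape (4, 3) with
    # split=0 and reps=(2, 1), rank 0 holds input rows 0-1 and rank 1 rows 2-3,
    # while the balanced output of shape (8, 3) places rows 0-3 on rank 0 and rows
    # 4-7 on rank 1. Each input chunk therefore has to show up at one output offset
    # per repetition (shifted by x.gshape[x.split]), and the overlapping index
    # ranges are exchanged via Alltoallv below.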
size = x.comm.Get_size()
rank = x.comm.Get_rank()
trans_axes = list(range(x.ndim))
if x.split != 0:
trans_axes[0], trans_axes[x.split] = x.split, 0
reps[0], reps[x.split] = reps[x.split], reps[0]
x = linalg.transpose(x, trans_axes)
x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
out_gshape = tuple(x_proxy.repeat(reps).shape)

local_x = x.larray

# allocate tiled DNDarray, at first tiled along split axis only
split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)]
split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape)
tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=x.split, comm=x.comm)
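    # e.g. (hypothetical): gshape (4, 3), split 0, reps [2, 1] -> the buffer above
    # has global shape (8, 3); tiling along the remaining (non-split) axes is
    # deferred until after the communication step.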
# collect slicing information from all processes.
slices_map = []
for array in [x, tiled]:
counts, displs = array.counts_displs()
t_slices_starts = torch.tensor(displs, device=local_x.device)
t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device)
slices_map.append([t_slices_starts, t_slices_ends])

t_slices_x, t_slices_tiled = slices_map

# keep track of repetitions:
# local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size)
    reps_indices = [x.gshape[x.split] * rep for rep in range(reps[x.split])]
t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=local_x.device).reshape(
len(reps_indices), 1
)
for i, t in enumerate(t_slices_x):
t = t.repeat((reps[x.split], 1))
t += t_reps_indices
t_slices_x[i] = t
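    # Worked example of the bookkeeping above (hypothetical: 2 processes,
    # x.gshape[x.split] == 4, counts == [2, 2], reps[x.split] == 2):
    #   per-rank input slice starts [0, 2], ends [2, 4]
    #   repetition offsets [0, 4] added row-wise:
    #     starts -> [[0, 2], [4, 6]], ends -> [[2, 4], [6, 8]]
    # i.e. every rank's chunk occupies one index range per repetition of the
    # (virtual) tiled global array along the split axis.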

# distribution logic on current rank:
distr_map = []
slices_map = []
for i in range(2):
if i == 0:
# send logic for x slices on rank
local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1)
local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1)
t_tiled_starts, t_tiled_ends = t_slices_tiled
else:
# recv logic for tiled slices on rank
local_x_starts, local_x_ends = t_slices_x
t_tiled_starts = t_slices_tiled[0][rank]
t_tiled_ends = t_slices_tiled[1][rank]
t_max_starts = torch.max(local_x_starts, t_tiled_starts)
t_min_ends = torch.min(local_x_ends, t_tiled_ends)
coords = torch.where(t_min_ends - t_max_starts > 0)
# remove repeat offset from slices if sending
if i == 0:
t_max_starts -= t_reps_indices
t_min_ends -= t_reps_indices
starts = t_max_starts[coords].unsqueeze_(0)
ends = t_min_ends[coords].unsqueeze_(0)
slices_map.append(torch.cat((starts, ends), dim=0))
distr_map.append(coords)
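    # The test above is a plain interval intersection: two index ranges
    # [a_start, a_end) and [b_start, b_end) share elements iff
    # max(a_start, b_start) < min(a_end, b_end); the positive entries of
    # `t_min_ends - t_max_starts` mark exactly those (repetition, rank) pairs.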

# bookkeeping in preparation for Alltoallv
send_map, recv_map = distr_map
send_rep, send_to_ranks = send_map
recv_rep, recv_from_ranks = recv_map
send_slices, recv_slices = slices_map

# do not assume that `x` is balanced
_, displs = x.counts_displs()
offset_x = displs[rank]
# impose load-balance on output
offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split)
t_tiled = tiled.larray

active_send_counts = send_slices.clone()
active_send_counts[0] *= -1
active_send_counts = active_send_counts.sum(0)
active_recv_counts = recv_slices.clone()
active_recv_counts[0] *= -1
active_recv_counts = active_recv_counts.sum(0)
send_slices -= offset_x
recv_slices -= offset_tiled
recv_buf = t_tiled.clone()
# we need as many Alltoallv calls as repeats along the split axis
for rep in range(reps[x.split]):
# send_data, send_counts, send_displs on rank
all_send_counts = [0] * size
all_send_displs = [0] * size
send_this_rep = torch.where(send_rep == rep)[0].tolist()
dest_this_rep = send_to_ranks[send_this_rep].tolist()
for i, j in zip(send_this_rep, dest_this_rep):
all_send_counts[j] = active_send_counts[i].item()
all_send_displs[j] = send_slices[0][i].item()
local_send_slice = [slice(None)] * x.ndim
local_send_slice[x.split] = slice(
all_send_displs[0], all_send_displs[0] + sum(all_send_counts)
)
send_buf = local_x[local_send_slice].clone()

# recv_data, recv_counts, recv_displs on rank
all_recv_counts = [0] * size
all_recv_displs = [0] * size
recv_this_rep = torch.where(recv_rep == rep)[0].tolist()
orig_this_rep = recv_from_ranks[recv_this_rep].tolist()
for i, j in zip(recv_this_rep, orig_this_rep):
all_recv_counts[j] = active_recv_counts[i].item()
all_recv_displs[j] = recv_slices[0][i].item()
local_recv_slice = [slice(None)] * x.ndim
local_recv_slice[x.split] = slice(
all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts)
)
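        # Alltoallv semantics (standard MPI, noted here for orientation): along the
        # split dimension every rank sends `all_send_counts[j]` index blocks of
        # `send_buf` starting at `all_send_displs[j]` to rank j, and receives
        # `all_recv_counts[j]` blocks from rank j into `recv_buf` at
        # `all_recv_displs[j]`.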
x.comm.Alltoallv(
(send_buf, all_send_counts, all_send_displs),
(recv_buf, all_recv_counts, all_recv_displs),
)
t_tiled[local_recv_slice] = recv_buf[local_recv_slice]

# finally tile along non-split axes if needed
reps[x.split] = 1
tiled = DNDarray(
t_tiled.repeat(reps),
out_gshape,
dtype=x.dtype,
split=x.split,
device=x.device,
comm=x.comm,
balanced=True,
)
if trans_axes != list(range(x.ndim)):
# transpose back to original shape
x = linalg.transpose(x, trans_axes)
tiled = linalg.transpose(tiled, trans_axes)

return tiled


def topk(
a: DNDarray,
k: int,
77 changes: 77 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -3267,6 +3267,83 @@ def test_swapaxes(self):
with self.assertRaises(TypeError):
ht.swapaxes(x, 4.9, "abc")

def test_tile(self):
# test local tile, tuple reps
x = ht.arange(12).reshape((4, 3))
reps = (2, 1)
ht_tiled = ht.tile(x, reps)
np_tiled = np.tile(x.numpy(), reps)
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is x.dtype)

# test scalar x
x = ht.array(9.0)
reps = (2, 1)
ht_tiled = ht.tile(x, reps)
np_tiled = np.tile(x.numpy(), reps)
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is x.dtype)

# test distributed tile along split axis
# len(reps) > x.ndim
split = 1
x = ht.random.randn(4, 3, split=split)
reps = ht.random.randint(2, 10, size=(4,))
tiled_along_split = ht.tile(x, reps)
np_tiled_along_split = np.tile(x.numpy(), reps.tolist())
self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
self.assertTrue(tiled_along_split.dtype is x.dtype)

        # test distributed tile along split axis 0
# len(reps) > x.ndim
split = 0
x = ht.random.randn(4, 3, split=split)
reps = np.random.randint(2, 10, size=(4,))
tiled_along_split = ht.tile(x, reps)
np_tiled_along_split = np.tile(x.numpy(), reps)
self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
self.assertTrue(tiled_along_split.dtype is x.dtype)

# test distributed tile() on imbalanced DNDarray
x = ht.random.randn(100, split=0)
x = x[ht.where(x > 0)]
reps = 5
imbalanced_tiled_along_split = ht.tile(x, reps)
np_imbalanced_tiled_along_split = np.tile(x.numpy(), reps)
self.assertTrue(
(imbalanced_tiled_along_split.numpy() == np_imbalanced_tiled_along_split).all()
)
self.assertTrue(imbalanced_tiled_along_split.dtype is x.dtype)
self.assertTrue(imbalanced_tiled_along_split.is_balanced(force_check=True))

# test tile along non-split axis
# len(reps) < x.ndim
split = 1
x = ht.random.randn(4, 5, 3, 10, dtype=ht.float64, split=split)
reps = (2, 2)
tiled_along_non_split = ht.tile(x, reps)
np_tiled_along_non_split = np.tile(x.numpy(), reps)
self.assertTrue((tiled_along_non_split.numpy() == np_tiled_along_non_split).all())
self.assertTrue(tiled_along_non_split.dtype is x.dtype)

# test tile along split axis
# len(reps) = x.ndim
split = 1
x = ht.random.randn(3, 3, dtype=ht.float64, split=split)
reps = (2, 3)
tiled_along_split = ht.tile(x, reps)
np_tiled_along_split = np.tile(x.numpy(), reps)
self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
self.assertTrue(tiled_along_split.dtype is x.dtype)

# test exceptions
float_reps = (1, 2, 2, 1.5)
with self.assertRaises(TypeError):
tiled_along_split = ht.tile(x, float_reps)
arraylike_float_reps = torch.tensor(float_reps)
with self.assertRaises(TypeError):
tiled_along_split = ht.tile(x, arraylike_float_reps)
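        # Note: the distributed code paths above are only exercised when this test
        # runs on several MPI processes, e.g. (illustrative invocation)
        #   mpirun -np 3 python -m pytest heat/core/tests/test_manipulations.py -k test_tile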

def test_topk(self):
size = ht.MPI_WORLD.size
if size == 1:
