Implement sort() #312

Merged
merged 40 commits on Aug 2, 2019

Changes from 38 commits

40 commits
02aae24
implemented parallel sort step 1 - 4
TheSlimvReal May 15, 2019
1aac9d3
rewrote steps to work with multidimensional sorting
TheSlimvReal May 15, 2019
f37956b
fixed typo
TheSlimvReal May 28, 2019
fae549c
Merge remote-tracking branch 'origin/master' into feature/144-sort
Jun 4, 2019
7fc0bb2
added sort to the list of available functions
Jun 5, 2019
2b815d8
Calculating the number of elements each process at every position get…
Jun 14, 2019
4030950
added todo for next steps
Jun 18, 2019
923aaec
Values are distributed to the correct process, now the processes need…
Jun 18, 2019
2581ca9
redistributing values works, cases where most values are at the begin…
Jun 18, 2019
05fbf88
descending works too
Jun 21, 2019
e24779d
working on special treatment for empty local data
Jun 21, 2019
86c540f
Merge branch 'bugfix/MPI_IN_PLACE' into feature/144-sort
Jun 21, 2019
f4b5476
Finished tests for sort function
Jun 24, 2019
3e66bb4
fixed the out parameter and added test for out buffer
Jun 24, 2019
41810f6
fixed a unit test
Jun 24, 2019
2f26036
indices of the elements in the original data are now returned as well
Jun 25, 2019
acb0181
removed debugging code
Jun 25, 2019
12d9b90
Fixed a bug
Jun 25, 2019
36e4838
Added documentation
Jun 25, 2019
fa4057e
Merge branch 'master' into feature/144-sort
Markus-Goetz Jun 27, 2019
3117f26
Improved performance by using less communication calls
Jun 28, 2019
6629de0
removed debugging code
Jun 28, 2019
3b419cb
Merge remote-tracking branch 'origin/feature/144-sort' into feature/1…
Jun 28, 2019
3b15d38
Merge branch 'master' into feature/144-sort
Markus-Goetz Jul 1, 2019
9d5655d
Merge branch 'master' into feature/144-sort
Markus-Goetz Jul 1, 2019
64153c4
first sharing of values is now done with alltoallv which has less com…
Jul 3, 2019
42e3070
fixed the tests and removed some bugs
Jul 12, 2019
65ba89b
fixed a bug on re-balancing
Jul 12, 2019
c0ed6fa
fixed a bug in the send_vec building process
Jul 12, 2019
f1cc8d2
removed the debug code
Jul 12, 2019
7522356
Merge branch 'master' into feature/144-sort
TheSlimvReal Jul 12, 2019
983020a
removed a torch to numpy conversion
Jul 12, 2019
7e8cae2
Merge branch 'feature/144-sort' of https://github.com/helmholtz-analy…
Jul 12, 2019
7eef116
Added some more comments and test cases
Jul 12, 2019
b6abcf3
Merge branch 'master' into feature/144-sort
Markus-Goetz Jul 22, 2019
dde7937
Merge branch 'master' into feature/144-sort
TheSlimvReal Jul 29, 2019
7472550
fixed the type of the created indices
Jul 29, 2019
5988391
removed debug code
Jul 29, 2019
84c6733
Merge branch 'master' into feature/144-sort
Markus-Goetz Aug 2, 2019
0ef42d7
Merge branch 'master' into feature/144-sort
Markus-Goetz Aug 2, 2019
Empty file added heat/core/manipulation.py
Empty file.
269 changes: 267 additions & 2 deletions heat/core/manipulations.py
@@ -1,5 +1,8 @@
import torch
import operator

import numpy as np
import torch
from mpi4py import MPI

from .communication import MPI

@@ -11,6 +14,7 @@
__all__ = [
'concatenate',
'expand_dims',
'sort',
'squeeze',
'unique'
]
@@ -312,7 +316,7 @@ def expand_dims(a, axis):
>>> y.shape
(1, 2)

y = ht.expand_dims(x, axis=1)
>>> y = ht.expand_dims(x, axis=1)
>>> y
array([[1],
[2]])
@@ -335,6 +339,267 @@ def expand_dims(a, axis):
)


def sort(a, axis=None, descending=False, out=None):
"""
Sorts the elements of the DNDarray a along the given dimension by value, in ascending order by default.

The sort is not stable, which means that equal elements in the result may be ordered differently than in the
original array.

Sorting along the split axis (`axis == a.split`) requires a large amount of MPI communication between the processes.

Parameters
----------
a : ht.DNDarray
Input array to be sorted.
axis : int, optional
The dimension to sort along.
Default is the last axis.
descending : bool, optional
If set to True, values are sorted in descending order.
Default is False.
out : ht.DNDarray or None, optional
A location in which to store the results. If provided, it must have a broadcastable shape. If not provided
or set to None, a fresh tensor is allocated.

Returns
-------
values : ht.DNDarray
The sorted local results.
indices : ht.DNDarray
The indices of the elements in the original data.

Raises
------
ValueError
If the axis is not in range of the axes.

Examples
--------
>>> x = ht.array([[4, 1], [2, 3]], split=0)
>>> x.shape
(1, 2)
(1, 2)

>>> y = ht.sort(x, axis=0)
>>> y
(array([[2, 1]]), array([[1, 0]]))
(array([[4, 3]]), array([[0, 1]]))

>>> ht.sort(x, descending=True)
(array([[4, 1]]), array([[0, 1]]))
(array([[3, 2]]), array([[1, 0]]))
"""
# default: using last axis
if axis is None:
axis = len(a.shape) - 1

stride_tricks.sanitize_axis(a.shape, axis)

if a.split is None or axis != a.split:
# sorting is not affected by split -> we can just sort along the axis
final_result, final_indices = torch.sort(a._DNDarray__array, dim=axis, descending=descending)

else:
# sorting is affected by split, processes need to communicate results
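# Overall approach (a sample-sort scheme): sort locally, pick evenly spaced local pivot
# candidates, let the root select size - 1 global pivots, bucket the local values by pivot and
# exchange them with Alltoallv, rebalance so every process ends up with its original number of
# elements, then sort locally once more and fix up the indices.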
# transpose so we can work along the 0 axis
transposed = a._DNDarray__array.transpose(axis, 0)
local_sorted, local_indices = torch.sort(transposed, dim=0, descending=descending)

size = a.comm.Get_size()
rank = a.comm.Get_rank()
counts, disp, _ = a.comm.counts_displs_shape(a.gshape, axis=axis)
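# disp[rank] is this process's offset along the split axis, counts[i] the chunk size owned by process i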

actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank]

length = local_sorted.size()[0]

# Separate the sorted tensor into size + 1 equal length partitions
partitions = [x * length // (size + 1) for x in range(1, size + 1)]
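# the values at these partition boundaries serve as this process's `size` pivot candidates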
local_pivots = local_sorted[partitions] if counts[rank] else torch.empty(
(0, ) + local_sorted.size()[1:], dtype=local_sorted.dtype)

# Only processes with elements should share their pivots
gather_counts = [int(x > 0) * size for x in counts]
gather_displs = (0, ) + tuple(np.cumsum(gather_counts[:-1]))

pivot_dim = list(transposed.size())
pivot_dim[0] = size * sum([1 for x in counts if x > 0])

# share the local pivots with root process
pivot_buffer = torch.empty(pivot_dim, dtype=a.dtype.torch_type())
a.comm.Gatherv(local_pivots, (pivot_buffer, gather_counts, gather_displs), root=0)

pivot_dim[0] = size - 1
global_pivots = torch.empty(pivot_dim, dtype=a.dtype.torch_type())

# root process creates new pivots and shares them with other processes
if rank == 0:
sorted_pivots, _ = torch.sort(pivot_buffer, descending=descending, dim=0)
length = sorted_pivots.size()[0]
global_partitions = [x * length // size for x in range(1, size)]
global_pivots = sorted_pivots[global_partitions]

a.comm.Bcast(global_pivots, root=0)
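# every process now holds the same size - 1 global pivots, which define one bucket per process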

lt_partitions = torch.empty((size, ) + local_sorted.shape, dtype=torch.int64)
last = torch.zeros_like(local_sorted, dtype=torch.int64)
comp_op = torch.gt if descending else torch.lt
# Iterate over all pivots and store which pivot is the first one greater than the element's value
for idx, p in enumerate(global_pivots):
lt = comp_op(local_sorted, p)
if idx > 0:
lt_partitions[idx] = lt - last
else:
lt_partitions[idx] = lt
last = lt
lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last
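# lt_partitions[k] is now a 0/1 mask marking the local elements that fall into bucket k, i.e. are destined for process k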

# Matrix holding information on how many values will be sent where
local_partitions = torch.sum(lt_partitions, dim=1)

partition_matrix = torch.empty_like(local_partitions)
a.comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM)

# Matrix that holds information on which value will be shipped where
index_matrix = torch.empty_like(local_sorted, dtype=torch.int64)

# Matrix holding information on which process gets how many values from where
shape = (size, ) + transposed.size()[1:]
send_recv_matrix = torch.zeros(shape, dtype=partition_matrix.dtype)

for i, x in enumerate(lt_partitions):
index_matrix[x > 0] = i
send_recv_matrix[i] += torch.sum(x, dim=0)

a.comm.Alltoall(MPI.IN_PLACE, send_recv_matrix)
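# after the all-to-all exchange, row i of send_recv_matrix holds how many values process i will send to this process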

scounts = local_partitions
rcounts = send_recv_matrix

shape = (partition_matrix[rank].max(), ) + transposed.size()[1:]
first_result = torch.empty(shape, dtype=local_sorted.dtype)
first_indices = torch.empty_like(first_result)

# Iterate through one layer and send values with alltoallv
for idx in np.ndindex(local_sorted.shape[1:]):
idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx]

send_count = scounts[idx_slice].reshape(-1).tolist()
send_disp = [0] + list(np.cumsum(send_count[:-1]))
s_val = torch.tensor(local_sorted[idx_slice])
s_ind = torch.tensor(actual_indices[idx_slice])

recv_count = rcounts[idx_slice].reshape(-1).tolist()
recv_disp = [0] + list(np.cumsum(recv_count[:-1]))
rcv_length = rcounts[idx_slice].sum().item()
r_val = torch.empty((rcv_length, ) + s_val.shape[1:], dtype=local_sorted.dtype)
r_ind = torch.empty_like(r_val)

a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp))
a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp))
first_result[idx_slice][:rcv_length] = r_val
first_indices[idx_slice][:rcv_length] = r_ind

# The process might not have the correct number of values, therefore the tensors need to be rebalanced
send_vec = torch.zeros(local_sorted.shape[1:] + (size, size), dtype=torch.int64)
target_cumsum = np.cumsum(counts)
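# counts[i] is the number of elements process i should own after rebalancing; target_cumsum are the desired prefix sums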
for idx in np.ndindex(local_sorted.shape[1:]):
idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx]
current_counts = partition_matrix[idx_slice].reshape(-1).tolist()
current_cumsum = list(np.cumsum(current_counts))
for proc in range(size):
if current_cumsum[proc] > target_cumsum[proc]:
# process has too many values, which will be sent to higher ranks
first = next(i for i in range(size) if send_vec[idx][:, i].sum() < counts[i])
last = next(i for i in range(size + 1) if i == size or current_cumsum[proc] < target_cumsum[i])
sent = 0
for i, x in enumerate(counts[first: last]):
# Each following process gets as many elements as it needs
amount = int(x - send_vec[idx][:, first + i].sum())
send_vec[idx][proc][first + i] = amount
current_counts[first + i] += amount
sent += amount
if last < size:
# Send all leftover values to the last (highest) receiving process
amount = partition_matrix[proc][idx]
send_vec[idx][proc][last] = int(amount - sent)
current_counts[last] += int(amount - sent)
elif current_cumsum[proc] < target_cumsum[proc]:
# process needs values from higher rank
first = 0 if proc == 0 else next(i for i, x in enumerate(current_cumsum)
if target_cumsum[proc - 1] < x)
last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x)
for i, x in enumerate(partition_matrix[idx_slice][first: last]):
# Taking as many elements as possible from each following process
send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum())
current_counts[first + i] = 0
# Take just enough elements from the last process to fill the current process's tensor
send_vec[idx][last][proc] = int(target_cumsum[proc] - current_cumsum[last - 1])
current_counts[last] -= int(target_cumsum[proc] - current_cumsum[last - 1])
else:
# process doesn't need more values
send_vec[idx][proc][proc] = partition_matrix[proc][idx] - send_vec[idx][proc].sum()
current_counts[proc] = counts[proc]
current_cumsum = list(np.cumsum(current_counts))

# Iterate through one layer again to create the final balanced local tensors
second_result = torch.empty_like(local_sorted)
second_indices = torch.empty_like(second_result)
for idx in np.ndindex(local_sorted.shape[1:]):
idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx]

send_count = send_vec[idx][rank]
send_disp = [0] + list(np.cumsum(send_count[:-1]))

recv_count = send_vec[idx][:, rank]
recv_disp = [0] + list(np.cumsum(recv_count[:-1]))

end = partition_matrix[rank][idx]
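# sort the values received in the first exchange so that the chunks carved out by send_count are in value order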
s_val, indices = first_result[0: end][idx_slice].sort(descending=descending, dim=0)
s_ind = first_indices[0: end][idx_slice][indices].reshape_as(s_val)

r_val = torch.empty((counts[rank], ) + s_val.shape[1:], dtype=local_sorted.dtype)
r_ind = torch.empty_like(r_val)

a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp))
a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp))

second_result[idx_slice] = r_val
second_indices[idx_slice] = r_ind

# print('second_result', second_result, 'tmp_indices', second_indices)

second_result, tmp_indices = second_result.sort(dim=0, descending=descending)
final_result = second_result.transpose(0, axis)
final_indices = torch.empty_like(second_indices)
# Update the indices in case the ordering changed during the last sort
for idx in np.ndindex(tmp_indices.shape):
val = tmp_indices[idx]
final_indices[idx] = second_indices[val][idx[1:]]
final_indices = final_indices.transpose(0, axis)

return_indices = factories.array(
final_indices,
dtype=dndarray.types.int32,
is_split=a.split,
device=a.device,
comm=a.comm
)
if out is not None:
out._DNDarray__array = final_result
return return_indices
else:
tensor = factories.array(
final_result,
dtype=a.dtype,
is_split=a.split,
device=a.device,
comm=a.comm
)
return tensor, return_indices


def squeeze(x, axis=None):
"""
Remove single-dimensional entries from the shape of a tensor.