Support different lshape maps in binary ops #887

Closed
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -37,7 +37,7 @@ These tools only profile the memory used by each process, not the entire functio
- with `split=None` and `split not None`

Python has an embedded profiler: https://docs.python.org/3.9/library/profile.html
Again, this will only provile the performance on each process. Printing the results with many processes
Again, this will only profile the performance on each process. Printing the results with many processes
my be illegible. It may be easiest to save the output of each to a file.
--->
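
To make the template's suggestion concrete, here is a minimal sketch (not part of the template; it assumes mpi4py is available for the rank number, and the workload and file names are placeholders) that profiles each process separately and writes one stats file per rank:

```python
# Hypothetical sketch: profile a Heat workload on each MPI process and write
# every rank's statistics to its own file instead of printing them.
import cProfile

import heat as ht
from mpi4py import MPI  # assumption: mpi4py is installed alongside Heat


def workload():
    # placeholder workload: the distributed binary op discussed in this PR
    a = ht.ones(10000, split=0)
    b = ht.zeros(10000, split=0)
    return a[:-1] + b[1:]


if __name__ == "__main__":
    rank = MPI.COMM_WORLD.rank
    # one stats file per process; inspect later with pstats.Stats(...)
    cProfile.run("workload()", filename=f"profile_rank_{rank}.prof")
```

Launched with, e.g., `mpirun -n 4 python profile_workload.py`, each process writes its own `profile_rank_<rank>.prof`, which avoids the illegible interleaved output mentioned above.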

6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -10,8 +10,9 @@
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element.

## Feature Additions
### Linear Algebra
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`

### Arithmetics
- [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.

### Communication
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`
@@ -21,6 +22,7 @@

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
### Logical
- [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit`
2 changes: 1 addition & 1 deletion heat/cluster/tests/test_kmedoids.py
@@ -50,7 +50,7 @@ def create_spherical_dataset(
cluster4 = ht.stack((x - 2 * offset, y - 2 * offset, z - 2 * offset), axis=1)

data = ht.concatenate((cluster1, cluster2, cluster3, cluster4), axis=0)
# Note: enhance when shuffel is available
# Note: enhance when shuffle is available
return data

def test_clusterer(self):
120 changes: 89 additions & 31 deletions heat/core/_operations.py
@@ -51,6 +51,18 @@ def __binary_op(
-------
result: ht.DNDarray
A DNDarray containing the results of element-wise operation.

Warning
-------
If both operands are distributed, they must be distributed along the same dimension, i.e. `t1.split = t2.split`.

MPI communication is necessary when both operands are distributed along the same dimension, but the distribution maps do not match. E.g.:
```
a = ht.ones(10000, split=0)
b = ht.zeros(10000, split=0)
c = a[:-1] + b[1:]
```
In such cases, one of the operands is redistributed IN PLACE to match the distribution map of the other operand.
Review comment (Collaborator): Why in-place? This means that a binop can manipulate its arguments.

"""
promoted_type = types.result_type(t1, t2).torch_type()

@@ -97,20 +109,43 @@ def __binary_op(

elif isinstance(t2, DNDarray):
if t1.split is None:
t1 = factories.array(
t1, split=t2.split, copy=False, comm=t1.comm, device=t1.device, ndmin=-t2.ndim
)
# do not distribute t1 if size along distribution axis is 1
if (
t2.split is not None
and t2.split >= abs(t2.ndim - t1.ndim)
and t1.shape[t2.split - abs(t2.ndim - t1.ndim)] > 1
):
t1 = factories.array(
t1,
split=t2.split,
copy=False,
comm=t1.comm,
device=t1.device,
ndmin=-t2.ndim,
)
output_split = t2.split
elif t2.split is None:
t2 = factories.array(
t2, split=t1.split, copy=False, comm=t2.comm, device=t2.device, ndmin=-t1.ndim
)
# do not distribute t2 if size along distribution axis is 1
if (
t1.split is not None
and t1.split >= abs(t2.ndim - t1.ndim)
and t2.shape[t1.split - abs(t2.ndim - t1.ndim)] > 1
):
t2 = factories.array(
t2,
split=t1.split,
copy=False,
comm=t2.comm,
device=t2.device,
ndmin=-t1.ndim,
)
output_split = t1.split
elif t1.split != t2.split:
# It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0
# and split=1
raise NotImplementedError("Not implemented for other splittings")

output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
output_split = t1.split
output_device = t1.device
output_comm = t1.comm

@@ -119,57 +154,79 @@ def __binary_op(
# warnings.warn(
# "Broadcasting requires transferring data of first operator between MPI ranks!"
# )
color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
newcomm = t1.comm.Split(color, t1.comm.rank)
if t1.comm.rank > 0 and color == 0:
t1.larray = torch.zeros(
t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
)
newcomm.Bcast(t1)
newcomm.Free()

t1.resplit_(None)
# color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
# newcomm = t1.comm.Split(color, t1.comm.rank)
# if t1.comm.rank > 0 and color == 0:
# t1.larray = torch.zeros(
# t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
# )
# newcomm.Bcast(t1)
# newcomm.Free()
# t1.__lshape = t1.gshape
# t1.__split = None
# t1.__lshape_map = torch.tensor(t1.gshape).unsqueeze_(0)
# t1.__balanced = True
if t2.split is not None:
if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of second operator between MPI ranks!"
# )
color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
newcomm = t2.comm.Split(color, t2.comm.rank)
if t2.comm.rank > 0 and color == 0:
t2.larray = torch.zeros(
t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
)
newcomm.Bcast(t2)
newcomm.Free()

t2.resplit_(None)
# color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
# newcomm = t2.comm.Split(color, t2.comm.rank)
# if t2.comm.rank > 0 and color == 0:
# t2.larray = torch.zeros(
# t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
# )
# newcomm.Bcast(t2)
# newcomm.Free()
# t2.__lshape = t2.gshape
# t2.__split = None
# t2.__lshape_map = torch.tensor(t2.gshape).unsqueeze_(0)
# t2.__balanced = True
else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)
else:
raise NotImplementedError("Not implemented for non scalar")

# sanitize output
if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device)

# promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type()
if t1.split is not None:
output_split = t1.split
# TODO: implement `dndarray.create_bulk_lshape_maps`
t1.create_lshape_map()
t2.create_lshape_map()
if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0:
result = t1.larray.type(promoted_type)
else:
if (
t2.split is not None
and not (t2.lshape_map[:, t2.split] == t1.lshape_map[:, t1.split]).all()
):
t2.redistribute_(target_map=t1.lshape_map)
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
elif t2.split is not None:

elif t2.split is not None:
output_split = t2.split
# TODO: implement `dndarray.create_bulk_lshape_maps`
t1.create_lshape_map()
t2.create_lshape_map()
if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0:
result = t2.larray.type(promoted_type)
else:
if (
t1.split is not None
and not (t2.lshape_map[:, t2.split] == t1.lshape_map[:, t1.split]).all()
):
t1.redistribute_(target_map=t2.lshape_map)
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
else:
output_split = None
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
@@ -178,6 +235,7 @@ def __binary_op(
result = torch.tensor(result, device=output_device.torch_device)

if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device)
out_dtype = out.dtype
out.larray = result
out._DNDarray__comm = output_comm
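
The redistribution branch above (`redistribute_(target_map=...)`) is what makes the docstring's example work. As a hedged illustration of the user-visible effect, a minimal sketch, assuming execution under MPI with at least two processes (exact local shapes depend on the number of ranks):

```python
import heat as ht

a = ht.ones(10, split=0)
b = ht.zeros(10, split=0)

x = a[:-1]           # global shape (9,), split=0
y = b[1:]            # same global shape and split axis, but a different distribution map
before = y.lshape    # local shape of y on this rank before the operation

z = x + y            # __binary_op aligns the maps by redistributing one operand in place
print(before, y.lshape)    # the local chunk of y may have changed on this rank
print((z == 1).all())      # the element-wise result is nevertheless correct
```

With two processes, `x` is typically chunked as `[5, 4]` and `y` as `[4, 5]`; `y` is then redistributed in place to `[5, 4]` before the element-wise torch call, consistent with the test assertion `c.lshape == a[:-1].lshape` below.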
42 changes: 25 additions & 17 deletions heat/core/tests/test_arithmetics.py
@@ -21,9 +21,10 @@ def setUpClass(cls):
cls.another_tensor = ht.array([[2.0, 2.0], [2.0, 2.0]])
cls.a_split_tensor = cls.another_tensor.copy().resplit_(0)

cls.errorneous_type = (2, 2)
cls.erroneous_type = (2, 2)

def test_add(self):
# test basics
result = ht.array([[3.0, 4.0], [5.0, 6.0]])

self.assertTrue(ht.equal(ht.add(self.a_scalar, self.a_scalar), ht.float32(4.0)))
@@ -45,10 +46,17 @@ def test_add(self):
else:
self.assertEqual(c.larray.size()[0], 0)

# test with differently distributed DNDarrays
a = ht.ones(10, split=0)
b = ht.zeros(10, split=0)
c = a[:-1] + b[1:]
self.assertTrue((c == 1).all())
self.assertTrue(c.lshape == a[:-1].lshape)

with self.assertRaises(ValueError):
ht.add(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.add(self.a_tensor, self.errorneous_type)
ht.add(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.add("T", "s")

@@ -76,7 +84,7 @@ def test_bitwise_and(self):
with self.assertRaises(ValueError):
ht.bitwise_and(an_int_vector, another_int_vector)
with self.assertRaises(TypeError):
ht.bitwise_and(self.a_tensor, self.errorneous_type)
ht.bitwise_and(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.bitwise_and("T", "s")
with self.assertRaises(TypeError):
@@ -112,7 +120,7 @@ def test_bitwise_or(self):
with self.assertRaises(ValueError):
ht.bitwise_or(an_int_vector, another_int_vector)
with self.assertRaises(TypeError):
ht.bitwise_or(self.a_tensor, self.errorneous_type)
ht.bitwise_or(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.bitwise_or("T", "s")
with self.assertRaises(TypeError):
@@ -148,7 +156,7 @@ def test_bitwise_xor(self):
with self.assertRaises(ValueError):
ht.bitwise_xor(an_int_vector, another_int_vector)
with self.assertRaises(TypeError):
ht.bitwise_xor(self.a_tensor, self.errorneous_type)
ht.bitwise_xor(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.bitwise_xor("T", "s")
with self.assertRaises(TypeError):
@@ -257,17 +265,16 @@ def test_cumsum(self):
def test_diff(self):
ht_array = ht.random.rand(20, 20, 20, split=None)
arb_slice = [0] * 3
for dim in range(3): # loop over 3 dimensions
for dim in range(0, 3): # loop over 3 dimensions
arb_slice[dim] = slice(None)
tup_arb = tuple(arb_slice)
np_array = ht_array[tup_arb].numpy()
for ax in range(dim + 1): # loop over the possible axis values
for sp in range(dim + 1): # loop over the possible split values
lp_array = ht.manipulations.resplit(ht_array[tup_arb], sp)
# loop to 3 for the number of times to do the diff
for nl in range(1, 4):
# only generating the number once and then
tup_arb = tuple(arb_slice)
lp_array = ht.manipulations.resplit(ht_array[tup_arb], sp)
np_array = ht_array[tup_arb].numpy()

ht_diff = ht.diff(lp_array, n=nl, axis=ax)
np_diff = ht.array(np.diff(np_array, n=nl, axis=ax))

@@ -280,10 +287,11 @@ def test_diff(self):
ht_append = ht.ones(
append_shape, dtype=lp_array.dtype, split=lp_array.split
)

ht_diff_pend = ht.diff(lp_array, n=nl, axis=ax, prepend=0, append=ht_append)
np_append = np.ones(append_shape, dtype=lp_array.larray.numpy().dtype)
np_diff_pend = ht.array(
np.diff(np_array, n=nl, axis=ax, prepend=0, append=ht_append.numpy()),
dtype=ht_diff_pend.dtype,
np.diff(np_array, n=nl, axis=ax, prepend=0, append=np_append)
)
self.assertTrue(ht.equal(ht_diff_pend, np_diff_pend))
self.assertEqual(ht_diff_pend.split, sp)
@@ -333,7 +341,7 @@ def test_div(self):
with self.assertRaises(ValueError):
ht.div(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.div(self.a_tensor, self.errorneous_type)
ht.div(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.div("T", "s")

@@ -362,7 +370,7 @@ def test_fmod(self):
with self.assertRaises(ValueError):
ht.fmod(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.fmod(self.a_tensor, self.errorneous_type)
ht.fmod(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.fmod("T", "s")

@@ -419,7 +427,7 @@ def test_mul(self):
with self.assertRaises(ValueError):
ht.mul(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.mul(self.a_tensor, self.errorneous_type)
ht.mul(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.mul("T", "s")

@@ -464,7 +472,7 @@ def test_pow(self):
with self.assertRaises(ValueError):
ht.pow(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.pow(self.a_tensor, self.errorneous_type)
ht.pow(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.pow("T", "s")

@@ -601,7 +609,7 @@ def test_sub(self):
with self.assertRaises(ValueError):
ht.sub(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.sub(self.a_tensor, self.errorneous_type)
ht.sub(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.sub("T", "s")
