Skip to content

Commit

Permalink
Merge pull request #465 from helmholtz-analytics/bug/464-unique-torch…
Browse files Browse the repository at this point in the history
…-tensor

Bug/464 unique torch tensor
  • Loading branch information
ClaudiaComito authored Jan 24, 2020
2 parents 075c138 + 5ee0c04 commit 5c3801a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This version adds support for PyTorch 1.4.0. There are also several minor featur
- added option for neutral elements to be used in the place of empty tensors in reduction operations (`operations.__reduce_op`) (cf. [#369](https://github.com/helmholtz-analytics/heat/issues/369) and [#444](https://github.com/helmholtz-analytics/heat/issues/444))
- `var` and `std` both now support iterable axis arguments
- updated pull request template
- bug fix: `x.unique()` returns a DNDarray both in distributed and non-distributed mode (cf. [#464](https://github.com/helmholtz-analytics/heat/issues/464))

# v0.2.1

Expand Down
10 changes: 8 additions & 2 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,10 +1036,16 @@ def unique(a, sorted=False, return_inverse=False, axis=None):
[3, 1]])
"""
if a.split is None:
# Trivial case, result can just be forwarded
return torch.unique(
torch_output = torch.unique(
a._DNDarray__array, sorted=sorted, return_inverse=return_inverse, dim=axis
)
if isinstance(torch_output, tuple):
heat_output = tuple(
factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output
)
else:
heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device)
return heat_output

local_data = a._DNDarray__array
unique_axis = None
Expand Down
11 changes: 10 additions & 1 deletion heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,8 +1069,17 @@ def test_unique(self):
exp_res, exp_inv = torch_array.unique(return_inverse=True, sorted=True)

data_split_none = ht.array(torch_array, device=ht_device)
res = ht.unique(data_split_none, sorted=True)
self.assertIsInstance(res, ht.DNDarray)
self.assertEqual(res.split, None)
self.assertEqual(res.dtype, data_split_none.dtype)
self.assertEqual(res.device, data_split_none.device)
res, inv = ht.unique(data_split_none, return_inverse=True, sorted=True)
self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype)))
self.assertIsInstance(inv, ht.DNDarray)
self.assertEqual(inv.split, None)
self.assertEqual(inv.dtype, data_split_none.dtype)
self.assertEqual(inv.device, data_split_none.device)
self.assertTrue(torch.equal(inv._DNDarray__array, exp_inv))

data_split_zero = ht.array(torch_array, split=0, device=ht_device)
res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True)
Expand Down

0 comments on commit 5c3801a

Please sign in to comment.