Skip to content

Commit

Permalink
Merge pull request #862 from helmholtz-analytics/feature/861-signbit
Browse files Browse the repository at this point in the history
add signbit
  • Loading branch information
coquelin77 authored Sep 16, 2021
2 parents 631f113 + 37eae3d commit 9db0a93
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
### Logical
- [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit`
### Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
Expand Down
21 changes: 21 additions & 0 deletions heat/core/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"logical_not",
"logical_or",
"logical_xor",
"signbit",
]


Expand Down Expand Up @@ -508,3 +509,23 @@ def sanitize_input_type(

else:
return x, y


def signbit(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
"""
Checks if signbit is set element-wise (less than zero).
Parameters
----------
x : DNDarray
The input array.
out : DNDarray, optional
The output array.
Examples
--------
>>> a = ht.array([2, -1.3, 0])
>>> ht.signbit(a)
DNDarray([False, True, False], dtype=ht.bool, device=cpu:0, split=None)
"""
return _operations.__local_op(torch.signbit, x, out, no_cast=True)
11 changes: 11 additions & 0 deletions heat/core/tests/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,14 @@ def test_logical_xor(self):
ht.array([[False, False], [False, False]]),
)
)

def test_signbit(self):
a = ht.array([2, -1.3, 0, -5], split=0)

sb = ht.signbit(a)
cmp = ht.array([False, True, False, True])

self.assertEqual(sb.dtype, ht.bool)
self.assertEqual(sb.split, 0)
self.assertEqual(sb.device, a.device)
self.assertTrue(ht.equal(sb, cmp))

0 comments on commit 9db0a93

Please sign in to comment.