Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement vecdot #840

Merged
merged 9 commits into from
Aug 2, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension

# Feature Additions
## Feature Additions

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
## Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`

Expand Down
55 changes: 54 additions & 1 deletion heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@
from .. import sanitation
from .. import types

__all__ = ["dot", "matmul", "norm", "outer", "projection", "trace", "transpose", "tril", "triu"]
__all__ = [
"dot",
"matmul",
"norm",
"outer",
"projection",
"trace",
"transpose",
"tril",
"triu",
"vecdot",
]


def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
Expand All @@ -39,6 +50,11 @@ def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDar
Second input DNDarray
out : DNDarray, optional
Output buffer.

See Also
--------
vecdot
Supports (vector) dot along an axis.
"""
if isinstance(a, (float, int)) or isinstance(b, (float, int)) or a.ndim == 0 or b.ndim == 0:
# 3. If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
Expand Down Expand Up @@ -1638,3 +1654,40 @@ def triu(m: DNDarray, k: int = 0) -> DNDarray:

DNDarray.triu: Callable[[DNDarray, int], DNDarray] = lambda self, k=0: triu(self, k)
DNDarray.triu.__doc__ = triu.__doc__


def vecdot(
x1: DNDarray, x2: DNDarray, axis: Optional[int] = None, keepdim: Optional[bool] = None
) -> DNDarray:
"""
Computes the (vector) dot product of two DNDarrays.

Parameters
----------
x1 : DNDarray
first input array.
x2 : DNDarray
second input array. Must be compatible with x1.
axis : int, optional
axis over which to compute the dot product. The last dimension is used if 'None'.
keepdim : bool, optional
If this is set to 'True', the axes which are reduced are left in the result as dimensions with size one.

See Also
--------
dot
NumPy-like dot function.

Examples
--------
>>> ht.vecdot(ht.full((3,3,3),3), ht.ones((3,3)), axis=0)
DNDarray([[9., 9., 9.],
[9., 9., 9.],
[9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)
"""
m = arithmetics.mul(x1, x2)

if axis is None:
axis = m.ndim - 1

return arithmetics.sum(m, axis=axis, keepdim=keepdim)
18 changes: 18 additions & 0 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,3 +1672,21 @@ def test_triu(self):
self.assertTrue(result.larray[-1, 0] == 0)
if result.comm.rank == result.shape[0] - 1:
self.assertTrue(result.larray[0, -1] == 1)

def test_vecdot(self):
a = ht.array([1, 1, 1])
b = ht.array([1, 2, 3])

c = ht.linalg.vecdot(a, b)

self.assertEqual(c.dtype, ht.int64)
self.assertEqual(c.device, a.device)
self.assertTrue(ht.equal(c, ht.array([6])))

a = ht.full((4, 4), 2, split=0)
b = ht.ones(4)

c = ht.linalg.vecdot(a, b, axis=0, keepdim=True)
self.assertEqual(c.dtype, ht.float32)
self.assertEqual(c.device, a.device)
self.assertTrue(ht.equal(c, ht.array([[8, 8, 8, 8]])))