diff --git a/heat/core/signal.py b/heat/core/signal.py
index b7556e50af..aab6b4d113 100644
--- a/heat/core/signal.py
+++ b/heat/core/signal.py
@@ -1,13 +1,13 @@
 """Provides a collection of signal-processing operations"""
 import torch
-from typing import Union, Tuple, Sequence
+import numpy as np
 
 from .communication import MPI
 from .dndarray import DNDarray
 from .types import promote_types
-from .manipulations import pad
-from .factories import array
+from .manipulations import pad, flip
+from .factories import array, zeros
 import torch.nn.functional as fc
 
 __all__ = ["convolve"]
 
@@ -15,14 +15,14 @@
 
 def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
     """
-    Returns the discrete, linear convolution of two one-dimensional `DNDarray`s.
+    Returns the discrete, linear convolution of two one-dimensional `DNDarray`s or scalars.
 
     Parameters
     ----------
-    a : DNDarray
-        One-dimensional signal `DNDarray` of shape (N,)
-    v : DNDarray
-        One-dimensional filter weight `DNDarray` of shape (M,).
+    a : DNDarray or scalar
+        One-dimensional signal `DNDarray` of shape (N,) or scalar.
+    v : DNDarray or scalar
+        One-dimensional filter weight `DNDarray` of shape (M,) or scalar.
     mode : str
         Can be 'full', 'valid', or 'same'. Default is 'full'.
         'full':
@@ -40,15 +40,6 @@
           overlap completely. Values outside the signal boundary have no
           effect.
 
-    Notes
-    -----
-    Contrary to the original `numpy.convolve`, this function does not
-    swap the input arrays if the second one is larger than the first one.
-    This is because `a`, the signal, might be memory-distributed,
-    whereas the filter `v` is assumed to be non-distributed,
-    i.e. a copy of `v` will reside on each process.
-
-
     Examples
     --------
     Note how the convolution operator flips the second array
@@ -62,7 +53,27 @@
     DNDarray([1., 3., 3., 3., 3.])
     >>> ht.convolve(a, v, mode='valid')
     DNDarray([3., 3., 3.])
+    >>> a = ht.ones(10, split = 0)
+    >>> v = ht.arange(3, split = 0).astype(ht.float)
+    >>> ht.convolve(a, v, mode='valid')
+    DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])
+
+    [0/3] DNDarray([3., 3., 3.])
+    [1/3] DNDarray([3., 3., 3.])
+    [2/3] DNDarray([3., 3.])
+    >>> a = ht.ones(10, split = 0)
+    >>> v = ht.arange(3, split = 0)
+    >>> ht.convolve(a, v)
+    DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)
+
+    [0/3] DNDarray([0., 1., 3., 3.])
+    [1/3] DNDarray([3., 3., 3., 3.])
+    [2/3] DNDarray([3., 3., 3., 2.])
     """
+    if np.isscalar(a):
+        a = array([a])
+    if np.isscalar(v):
+        v = array([v])
     if not isinstance(a, DNDarray):
         try:
             a = array(a)
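The scalar branch added above mirrors `numpy.convolve`, which likewise promotes scalar inputs to one-element arrays. A doctest-style sanity check (a sketch assuming a build with this patch; the expected result matches the new edge-case test at the bottom of this diff):

    >>> import heat as ht
    >>> ht.equal(ht.convolve(1, 5), ht.array([5]))
    True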
@@ -77,24 +88,25 @@
         a = a.astype(promoted_type)
         v = v.astype(promoted_type)
 
-    if v.is_distributed():
-        raise TypeError("Distributed filter weights are not supported")
     if len(a.shape) != 1 or len(v.shape) != 1:
         raise ValueError("Only 1-dimensional input DNDarrays are allowed")
-    if a.shape[0] <= v.shape[0]:
-        raise ValueError("Filter size must not be greater than or equal to signal size")
     if mode == "same" and v.shape[0] % 2 == 0:
         raise ValueError("Mode 'same' cannot be used with even-sized kernel")
+    if not v.is_balanced():
+        raise ValueError("Only balanced kernel weights are allowed")
+
+    if v.shape[0] > a.shape[0]:
+        a, v = v, a
 
     # compute halo size
-    halo_size = v.shape[0] // 2
+    halo_size = torch.max(v.lshape_map[:, 0]).item() // 2
 
     # pad DNDarray with zeros according to mode
     if mode == "full":
         pad_size = v.shape[0] - 1
         gshape = v.shape[0] + a.shape[0] - 1
     elif mode == "same":
-        pad_size = halo_size
+        pad_size = v.shape[0] // 2
         gshape = a.shape[0]
     elif mode == "valid":
         pad_size = 0
@@ -105,8 +117,10 @@
     a = pad(a, pad_size, "constant", 0)
 
     if a.is_distributed():
-        if (v.shape[0] > a.lshape_map[:, 0]).any():
-            raise ValueError("Filter weight is larger than the local chunks of signal")
+        if (v.lshape_map[:, 0] > a.lshape_map[:, 0]).any():
+            raise ValueError(
+                "Local chunk of filter weight is larger than the local chunks of signal"
+            )
         # fetch halos and store them in a.halo_next/a.halo_prev
         a.get_halo(halo_size)
         # apply halos to local array
@@ -114,11 +128,21 @@
     else:
         signal = a.larray
 
+    # flip filter for convolution as Pytorch conv1d computes correlations
+    v = flip(v, [0])
+    if v.larray.shape != v.lshape_map[0]:
+        # pads weights if input kernel is uneven
+        target = torch.zeros(v.lshape_map[0][0], dtype=v.larray.dtype, device=v.larray.device)
+        pad_size = v.lshape_map[0][0] - v.larray.shape[0]
+        target[pad_size:] = v.larray
+        weight = target
+    else:
+        weight = v.larray
+
+    t_v = weight  # stores temporary weight
+
     # make signal and filter weight 3D for Pytorch conv1d function
     signal = signal.reshape(1, 1, signal.shape[0])
-
-    # flip filter for convolution as Pytorch conv1d computes correlations
-    weight = v.larray.flip(dims=(0,))
     weight = weight.reshape(1, 1, weight.shape[0])
 
     # cast to float if on GPU
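For context on the flip introduced above: `torch.nn.functional.conv1d` computes a cross-correlation, so the kernel has to be reversed to obtain a true convolution. A minimal PyTorch sketch (illustration only, not part of the patch):

    import torch
    import torch.nn.functional as fc

    a = torch.arange(5.0).reshape(1, 1, 5)  # signal [0., 1., 2., 3., 4.]
    v = torch.tensor([0.0, 1.0, 2.0])       # kernel

    corr = fc.conv1d(a, v.reshape(1, 1, 3))          # tensor([[[ 5.,  8., 11.]]])
    conv = fc.conv1d(a, v.flip(0).reshape(1, 1, 3))  # tensor([[[1., 4., 7.]]])
    # conv matches numpy.convolve([0, 1, 2, 3, 4], [0, 1, 2], mode="valid")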
@@ -126,23 +150,56 @@
         float_type = promote_types(signal.dtype, torch.float32).torch_type()
         signal = signal.to(float_type)
         weight = weight.to(float_type)
+        t_v = t_v.to(float_type)
 
-    # apply torch convolution operator
-    signal_filtered = fc.conv1d(signal, weight)
-
-    # unpack 3D result into 1D
-    signal_filtered = signal_filtered[0, 0, :]
-
-    # if kernel shape along split axis is even we need to get rid of duplicated values
-    if a.comm.rank != 0 and v.shape[0] % 2 == 0:
-        signal_filtered = signal_filtered[1:]
-
-    return DNDarray(
-        signal_filtered.contiguous(),
-        (gshape,),
-        signal_filtered.dtype,
-        a.split,
-        a.device,
-        a.comm,
-        balanced=False,
-    ).astype(a.dtype.torch_type())
+    if v.is_distributed():
+        size = v.comm.size
+
+        for r in range(size):
+            rec_v = v.comm.bcast(t_v, root=r)
+            t_v1 = rec_v.reshape(1, 1, rec_v.shape[0])
+            local_signal_filtered = fc.conv1d(signal, t_v1)
+            # unpack 3D result into 1D
+            local_signal_filtered = local_signal_filtered[0, 0, :]
+
+            if a.comm.rank != 0 and v.lshape_map[0][0] % 2 == 0:
+                local_signal_filtered = local_signal_filtered[1:]
+
+            # accumulate filtered signal on the fly
+            global_signal_filtered = array(
+                local_signal_filtered, is_split=0, device=a.device, comm=a.comm
+            )
+            if r == 0:
+                # initialize signal_filtered, starting point of slice
+                signal_filtered = zeros(
+                    gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm
+                )
+                start_idx = 0
+
+            # accumulate relevant slice of filtered signal
+            # note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op()
+            signal_filtered += global_signal_filtered[start_idx : start_idx + gshape]
+            if r != size - 1:
+                start_idx += v.lshape_map[r + 1][0].item()
+        return signal_filtered
+
+    else:
+        # apply torch convolution operator
+        signal_filtered = fc.conv1d(signal, weight)
+
+        # unpack 3D result into 1D
+        signal_filtered = signal_filtered[0, 0, :]
+
+        # if kernel shape along split axis is even we need to get rid of duplicated values
+        if a.comm.rank != 0 and v.shape[0] % 2 == 0:
+            signal_filtered = signal_filtered[1:]
+
+        return DNDarray(
+            signal_filtered.contiguous(),
+            (gshape,),
+            signal_filtered.dtype,
+            a.split,
+            a.device,
+            a.comm,
+            balanced=False,
+        ).astype(a.dtype.torch_type())
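The distributed-kernel branch above relies on the linearity of convolution: the full result is the sum of the partial convolutions with each rank's kernel chunk, each shifted by that chunk's offset within the kernel. This is what the `for r in range(size)` loop accumulates, one broadcast chunk at a time. A serial NumPy sketch of the same decomposition (illustration only; the chunk sizes stand in for an arbitrary split):

    import numpy as np

    a = np.arange(16)                  # signal
    v = np.ones(5)                     # kernel, imagined split across 3 ranks
    chunks = [v[0:2], v[2:4], v[4:5]]  # local chunks of v

    result = np.zeros(len(a) + len(v) - 1)  # 'full' output size
    offset = 0                              # start index of the chunk within v
    for c in chunks:
        part = np.convolve(a, c, mode="full")
        result[offset : offset + len(part)] += part
        offset += len(c)

    assert np.allclose(result, np.convolve(a, v, mode="full"))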
diff --git a/heat/core/tests/test_signal.py b/heat/core/tests/test_signal.py
index a471218f89..7abd69e183 100644
--- a/heat/core/tests/test_signal.py
+++ b/heat/core/tests/test_signal.py
@@ -20,59 +20,94 @@ def test_convolve(self):
             [0, 1, 3, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 42, 29, 15]
         ).astype(ht.int)
 
-        signal = ht.arange(0, 16, split=0).astype(ht.int)
+        dis_signal = ht.arange(0, 16, split=0).astype(ht.int)
+        signal = ht.arange(0, 16).astype(ht.int)
+        full_ones = ht.ones(7, split=0).astype(ht.int)
         kernel_odd = ht.ones(3).astype(ht.int)
         kernel_even = [1, 1, 1, 1]
+        dis_kernel_odd = ht.ones(3, split=0).astype(ht.int)
+        dis_kernel_even = ht.ones(4, split=0).astype(ht.int)
 
         with self.assertRaises(TypeError):
             signal_wrong_type = [0, 1, 2, "tre", 4, "five", 6, "ʻehiku", 8, 9, 10]
             ht.convolve(signal_wrong_type, kernel_odd, mode="full")
         with self.assertRaises(TypeError):
             filter_wrong_type = [1, 1, "pizza", "pineapple"]
-            ht.convolve(signal, filter_wrong_type, mode="full")
+            ht.convolve(dis_signal, filter_wrong_type, mode="full")
         with self.assertRaises(ValueError):
-            ht.convolve(signal, kernel_odd, mode="invalid")
+            ht.convolve(dis_signal, kernel_odd, mode="invalid")
         with self.assertRaises(ValueError):
-            s = signal.reshape((2, -1))
+            s = dis_signal.reshape((2, -1))
             ht.convolve(s, kernel_odd)
         with self.assertRaises(ValueError):
             k = ht.eye(3)
-            ht.convolve(signal, k)
-        with self.assertRaises(ValueError):
-            ht.convolve(kernel_even, full_even)
+            ht.convolve(dis_signal, k)
         with self.assertRaises(ValueError):
-            ht.convolve(signal, kernel_even, mode="same")
+            ht.convolve(dis_signal, kernel_even, mode="same")
         if self.comm.size > 1:
-            with self.assertRaises(TypeError):
-                k = ht.ones(4, split=0).astype(ht.int)
-                ht.convolve(signal, k)
-        if self.comm.size >= 5:
             with self.assertRaises(ValueError):
-                ht.convolve(signal, kernel_even, mode="valid")
+                ht.convolve(full_ones, kernel_even, mode="valid")
+            with self.assertRaises(ValueError):
+                ht.convolve(kernel_even, full_ones, mode="valid")
+        if self.comm.size > 5:
+            with self.assertRaises(ValueError):
+                ht.convolve(dis_signal, kernel_even)
 
         # test modes, avoid kernel larger than signal chunk
         if self.comm.size <= 3:
             modes = ["full", "same", "valid"]
             for i, mode in enumerate(modes):
                 # odd kernel size
-                conv = ht.convolve(signal, kernel_odd, mode=mode)
+                conv = ht.convolve(dis_signal, kernel_odd, mode=mode)
+                gathered = manipulations.resplit(conv, axis=None)
+                self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i], gathered))
+
+                conv = ht.convolve(dis_signal, dis_kernel_odd, mode=mode)
+                gathered = manipulations.resplit(conv, axis=None)
+                self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i], gathered))
+
+                conv = ht.convolve(signal, dis_kernel_odd, mode=mode)
                 gathered = manipulations.resplit(conv, axis=None)
                 self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i], gathered))
+
                 # different data types
-                conv = ht.convolve(signal.astype(ht.float), kernel_odd)
+                conv = ht.convolve(dis_signal.astype(ht.float), kernel_odd)
+                gathered = manipulations.resplit(conv, axis=None)
+                self.assertTrue(ht.equal(full_odd.astype(ht.float), gathered))
+
+                conv = ht.convolve(dis_signal.astype(ht.float), dis_kernel_odd)
+                gathered = manipulations.resplit(conv, axis=None)
+                self.assertTrue(ht.equal(full_odd.astype(ht.float), gathered))
+
+                conv = ht.convolve(signal.astype(ht.float), dis_kernel_odd)
                 gathered = manipulations.resplit(conv, axis=None)
                 self.assertTrue(ht.equal(full_odd.astype(ht.float), gathered))
 
                 # even kernel size
                 # skip mode 'same' for even kernels
                 if mode != "same":
-                    conv = ht.convolve(signal, kernel_even, mode=mode)
+                    conv = ht.convolve(dis_signal, kernel_even, mode=mode)
+                    dis_conv = ht.convolve(dis_signal, dis_kernel_even, mode=mode)
                     gathered = manipulations.resplit(conv, axis=None)
+                    dis_gathered = manipulations.resplit(dis_conv, axis=None)
 
                     if mode == "full":
                         self.assertTrue(ht.equal(full_even, gathered))
+                        self.assertTrue(ht.equal(full_even, dis_gathered))
                     else:
                         self.assertTrue(ht.equal(full_even[3:-3], gathered))
+                        self.assertTrue(ht.equal(full_even[3:-3], dis_gathered))
+
+                # distributed large signal and kernel
+                np.random.seed(12)
+                np_a = np.random.randint(1000, size=4418)
+                np_b = np.random.randint(1000, size=1543)
+                np_conv = np.convolve(np_a, np_b, mode=mode)
+
+                a = ht.array(np_a, split=0, dtype=ht.int32)
+                b = ht.array(np_b, split=0, dtype=ht.int32)
+                conv = ht.convolve(a, b, mode=mode)
+                self.assert_array_equal(conv, np_conv)
 
         # test edge cases
         # non-distributed signal, size-1 kernel
@@ -81,3 +116,6 @@ def test_convolve(self):
         kernel = ht.ones(1).astype(ht.int)
         conv = ht.convolve(alt_signal, kernel)
         self.assertTrue(ht.equal(signal, conv))
+
+        conv = ht.convolve(1, 5)
+        self.assertTrue(ht.equal(ht.array([5]), conv))
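The new large-input test can also be run standalone to compare a fully distributed convolution against NumPy (a sketch assuming this patch is installed; the script name is arbitrary, launch e.g. with `mpirun -np 4 python check_convolve.py`):

    import numpy as np
    import heat as ht

    np.random.seed(12)
    np_a = np.random.randint(1000, size=4418)
    np_b = np.random.randint(1000, size=1543)

    # both signal and kernel distributed along axis 0
    a = ht.array(np_a, split=0, dtype=ht.int32)
    b = ht.array(np_b, split=0, dtype=ht.int32)

    conv = ht.convolve(a, b, mode="full")
    assert (conv.numpy() == np.convolve(np_a, np_b, mode="full")).all()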