Skip to content

Commit

Permalink
Update np.correlate()
Browse files Browse the repository at this point in the history
  • Loading branch information
arshajii committed Feb 4, 2025
1 parent 44c59c2 commit 4521182
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 63 deletions.
216 changes: 153 additions & 63 deletions stdlib/numpy/statistics.codon
Original file line number Diff line number Diff line change
Expand Up @@ -270,85 +270,175 @@ def corrcoef(x, y=None, rowvar=True, dtype: type = NoneType):

return c

def correlate(a, b, mode: Static[str] = 'valid'):
a = asarray(a)
b = asarray(b)
def _correlate(a, b, mode: str):

def kernel(d, dstride: int, nd: int, dtype: type,
           k, kstride: int, nk: Static[int], ktype: type,
           out, ostride: int):
    # Sliding dot product with a compile-time kernel length.
    #
    # For every output position 0..nd-1, multiplies nk consecutive (strided)
    # elements of `d` by the nk (strided) elements of `k`, accumulates the
    # products in a `dtype` accumulator and stores the sum into `out`.
    # `nk` is Static[int] and the inner loop uses staticrange.
    for pos in range(nd):
        total = util.zero(dtype)
        for tap in staticrange(nk):
            total += d[(pos + tap) * dstride] * k[tap * kstride]
        out[pos * ostride] = total

# Fast-path dispatcher for short kernels.
#
# Tries to compute `nd` outputs through kernel() with the kernel length
# supplied as a literal. Returns False (without writing anything) when the
# two element types differ or when nk is not in 1..11; returns True after a
# kernel() call has filled `out`.
def small_correlate(d, dstride: int, nd: int, dtype: type,
k, kstride: int, nk: int, ktype: type,
out, ostride: int):
# The fast path is only attempted when both inputs share one element type.
if dtype is not ktype:
return False

# Each stride is divided by the element size before being handed to kernel(),
# which multiplies indices by these values. Presumably the strides arrive in
# bytes (the call site passes a.strides[0], b.strides[0] and ret.itemsize)
# and leave as element counts -- TODO confirm. kstride and ostride reuse
# sizeof(dtype); ktype is the same type here because of the guard above.
dstride //= util.sizeof(dtype)
kstride //= util.sizeof(dtype)
ostride //= util.sizeof(dtype)

# nk is a runtime int in this function, while kernel() declares nk as
# Static[int]; each supported length therefore gets its own call with a
# literal value.
if nk == 1:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=1, ktype=ktype,
out=out, ostride=ostride)
elif nk == 2:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=2, ktype=ktype,
out=out, ostride=ostride)
elif nk == 3:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=3, ktype=ktype,
out=out, ostride=ostride)
elif nk == 4:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=4, ktype=ktype,
out=out, ostride=ostride)
elif nk == 5:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=5, ktype=ktype,
out=out, ostride=ostride)
elif nk == 6:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=6, ktype=ktype,
out=out, ostride=ostride)
elif nk == 7:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=7, ktype=ktype,
out=out, ostride=ostride)
elif nk == 8:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=8, ktype=ktype,
out=out, ostride=ostride)
elif nk == 9:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=9, ktype=ktype,
out=out, ostride=ostride)
elif nk == 10:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=10, ktype=ktype,
out=out, ostride=ostride)
elif nk == 11:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=11, ktype=ktype,
out=out, ostride=ostride)
else:
# Kernel length outside 1..11: tell the caller to use its generic loop.
return False

# NOTE(review): the next two lines reference `a` and `b`, which are not
# parameters of small_correlate. They look like removed lines of the previous
# correlate() (the same check appears in the new correlate()) that the diff
# rendering interleaved here -- confirm against the repository file.
if a.ndim != 1 or b.ndim != 1:
compile_error('object too deep for desired array')
return True

n1 = len(a)
n2 = len(b)
if n1 < n2:
inverted = 1
inv = n1
n1 = n2
n2 = inv
correlate(b, a, mode)
else:
inverted = 0
def dot(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
        T1: type, T2: type, T3: type):
    # Strided dot product of n elements.
    #
    # Walks both inputs as byte pointers, advancing by is1 / is2 after every
    # element, casts each operand to T3 before multiplying, and writes the
    # accumulated sum to op[0].
    lhs = _ip1.as_byte()
    rhs = _ip2.as_byte()
    total = util.zero(T3)

    for _ in range(n):
        total += util.cast(Ptr[T1](lhs)[0], T3) * util.cast(Ptr[T2](rhs)[0], T3)
        lhs += is1
        rhs += is2

    op[0] = total

def incr(p: Ptr[T], s: int, T: type):
    # Return `p` moved by `s` (added to its byte-pointer form), retyped as Ptr[T].
    shifted = p.as_byte() + s
    return Ptr[T](shifted)

# Main body of _correlate: picks the output length and the number of
# partially-overlapping outputs on each side from `mode`, allocates the
# result, then fills it with strided dot products.
#
# NOTE(review): this listing comes from a rendered diff with indentation and
# +/- markers stripped. The complex-dtype zeros()/empty() allocations, the
# nested `for i` / `for j` loops using `inverted`, `full_length`, and the
# trailing `if inverted:` section look like removed lines of the previous
# implementation interleaved with the new code -- confirm against the
# repository file before relying on this listing.
n1 = a.size
n2 = b.size
length = n1
n = n2

if mode == 'valid':
length = length - n + 1
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
or b.dtype is complex64):
ret = zeros(length, dtype=complex)
else:
ret = empty(length)
for i in range(length):
for j in range(n):
if inverted == 0:
ret.data[i] += a._ptr(
(j + i, ))[0] * conjugate(b._ptr((j, ))[0])
else:
ret.data[i] += a._ptr((j, ))[0] * b._ptr((j + i, ))[0]
# The doubled `length =` is a chained assignment of one value; it has the
# same effect as a single assignment.
length = length = length - n + 1
# 'valid': no partially-overlapping outputs on either side.
n_left = 0
n_right = 0
elif mode == 'same':
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
or b.dtype is complex64):
ret = zeros(length, dtype=complex)
else:
ret = empty(length)
for i in range(length):
for j in range(n):
signal_index = i - int(n / 2) + j
if signal_index >= 0 and signal_index < length:
if inverted == 0:
ret.data[i] += a._ptr(
(signal_index, ))[0] * conjugate(b._ptr((j, ))[0])
else:
ret.data[i] += a._ptr((j, ))[0] * b._ptr(
(signal_index, ))[0]
# 'same': output keeps length n1; the n - 1 partial outputs are split between
# the two sides (n >> 1 on the left, the remainder on the right).
n_left = n >> 1
n_right = n - n_left - 1
elif mode == 'full':
full_length = length + n - 1
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
or b.dtype is complex64):
ret = zeros(full_length, dtype=complex)
else:
ret = empty(full_length)
for i in range(full_length):
for j in range(n):
signal_index = i + j - 2
if signal_index >= 0 and signal_index < length:
if inverted == 0:
ret.data[i] += a._ptr(
(signal_index, ))[0] * conjugate(b._ptr((j, ))[0])
else:
ret.data[i] += a._ptr((j, ))[0] * b._ptr(
(signal_index, ))[0]
# 'full': n - 1 partial outputs on each side, so the output grows by n - 1.
n_right = n - 1
n_left = n - 1
length = length + n - 1
else:
raise ValueError(
f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})"
)

if inverted:
ret = ret[::-1]
if ret.dtype is complex or ret.dtype is complex64:
ret.map(conjugate, inplace=True)
# Result element type is derived from both input dtypes via util.coerce
# (presumably their common type -- TODO confirm).
dt = type(util.coerce(a.dtype, b.dtype))
ret = empty(length, dtype=dt)

# is1/is2 are the inputs' first-axis strides; `os` is the output step, taken
# from ret.itemsize. ip2 starts n_left elements into b and the working length
# n is shortened by the same amount, so the first output only uses the tail
# of b.
is1 = a.strides[0]
is2 = b.strides[0]
op = ret.data
os = ret.itemsize
ip1 = a.data
ip2 = Ptr[b.dtype](b.data.as_byte() + n_left * is2)
n = n - n_left

# Left edge: each step lengthens the dot product by one element and moves the
# start of b's window one element back, while a's start stays put.
for i in range(n_left):
dot(ip1, is1, ip2, is2, op, n)
n += 1
ip2 = incr(ip2, -is2)
op = incr(op, os)

# Middle: n1 - n2 + 1 outputs. small_correlate() is tried first; when it
# returns True the pointers are advanced past everything it wrote, otherwise
# the generic per-output dot() loop runs.
if small_correlate(ip1, is1, n1 - n2 + 1, a.dtype,
ip2, is2, n, b.dtype,
op, os):
ip1 = incr(ip1, is1 * (n1 - n2 + 1))
op = incr(op, os * (n1 - n2 + 1))
else:
for i in range(n1 - n2 + 1):
dot(ip1, is1, ip2, is2, op, n)
ip1 = incr(ip1, is1)
op = incr(op, os)

# Right edge: each step shortens the dot product by one element while the
# window over a keeps sliding forward.
for i in range(n_right):
n -= 1
dot(ip1, is1, ip2, is2, op, n)
ip1 = incr(ip1, is1)
op = incr(op, os)

return ret

def correlate(a, b, mode: str = 'valid'):
    # Cross-correlation of two 1-D inputs.
    #
    # Both arguments are converted with asarray(). Inputs that are not 1-D are
    # rejected at compile time; an empty input raises ValueError. A complex
    # second argument is conjugated before the computation. _correlate() is
    # always called with the longer input first; when the arguments had to be
    # swapped, the result is returned reversed.
    a = asarray(a)
    b = asarray(b)

    if a.ndim != 1 or b.ndim != 1:
        compile_error('object too deep for desired array')

    if a.size == 0:
        raise ValueError("first argument cannot be empty")
    if b.size == 0:
        raise ValueError("second argument cannot be empty")

    swap = a.size < b.size

    if b.dtype is complex or b.dtype is complex64:
        b = b.conjugate()

    if swap:
        return _correlate(b, a, mode=mode)[::-1]
    return _correlate(a, b, mode=mode)

def bincount(x, weights=None, minlength: int = 0):
x = asarray(x).astype(int)

Expand Down
61 changes: 61 additions & 0 deletions test/numpy/test_statistics.codon
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,67 @@ test_correlate(np.array([1 + 0j, 2 + 0j, 3 + 0j, 4 + 1j]),
np.array([-1 + 0j, -2j, 3 + 1j]),
np.array([8. + 1.j, 11. + 5.j]))

@test
def test_correlate2():
    # Exercises np.correlate in all three modes ('valid', 'same', 'full')
    # across integer, float, complex, mixed-dtype, unequal-length and
    # single-element inputs. Every expected value is stated or computed
    # independently of np.correlate.

    # Integer inputs
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])
    assert np.allclose(np.correlate(a, b, mode="valid"), [32])
    assert np.allclose(np.correlate(a, b, mode="same"), [17, 32, 23])
    assert np.allclose(np.correlate(a, b, mode="full"), [6, 17, 32, 23, 12])

    # Floating-point inputs
    a = np.array([1.5, 2.5, 3.5])
    b = np.array([4.0, 5.0, 6.0])
    assert np.allclose(np.correlate(a, b, mode="valid"), [39.5])
    assert np.allclose(np.correlate(a, b, mode="same"), [22.5, 39.5, 27.5])
    assert np.allclose(np.correlate(a, b, mode="full"), [9.0, 22.5, 39.5, 27.5, 14.0])

    # Complex numbers
    a = np.array([1+2j, 3+4j])
    b = np.array([5+6j, 7+8j])
    assert np.allclose(np.correlate(a, b, mode="valid"), [70+8j])
    assert np.allclose(np.correlate(a, b, mode="same"), [23+6j, 70+8j])
    assert np.allclose(np.correlate(a, b, mode="full"), [23+6j, 70+8j, 39+2j])

    # Different-length arrays
    a = np.array([1, 2, 3, 4])
    b = np.array([0, 1])
    assert np.allclose(np.correlate(a, b, mode="valid"), [2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="same"), [1, 2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="full"), [1, 2, 3, 4, 0])
    # With the shorter array first the (real-valued) result is the reverse of
    # the swapped call, so reversing it must give the values above.
    a = np.array([0, 1])
    b = np.array([1, 2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="valid")[::-1], [2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="same")[::-1], [1, 2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="full")[::-1], [1, 2, 3, 4, 0])

    # Large array test
    a = np.arange(20)
    b = np.arange(10)
    na = len(a)
    nb = len(b)
    expected_valid = np.array([np.sum(a[i : i + nb] * b) for i in range(na - nb + 1)])
    # Reference for 'full' computed directly from the definition: output i
    # pairs b[j] with a[i + j - (nb - 1)], and terms that fall outside `a`
    # contribute nothing. Previously this was obtained from np.correlate
    # itself, which made the assertion unable to fail.
    expected_full = np.array([
        sum(a[i + j - (nb - 1)] * b[j]
            for j in range(nb)
            if 0 <= i + j - (nb - 1) < na)
        for i in range(na + nb - 1)
    ])
    # 'same' keeps nb // 2 partial outputs on the left instead of nb - 1, so
    # it is the length-na window of 'full' starting at their difference.
    same_start = (nb - 1) - nb // 2
    expected_same = expected_full[same_start : same_start + na]
    assert np.allclose(np.correlate(a, b, mode="valid"), expected_valid)
    assert np.allclose(np.correlate(a, b, mode="full"), expected_full)
    assert np.allclose(np.correlate(a, b, mode="same"), expected_same)

    # Different dtypes (int and float)
    a = np.array([1, 2, 3], dtype=int)
    b = np.array([1.5, 2.5, 3.5], dtype=float)
    assert np.allclose(np.correlate(a, b, mode="valid"), [17.0])
    assert np.allclose(np.correlate(a, b, mode="same"), [9.5, 17.0, 10.5])
    assert np.allclose(np.correlate(a, b, mode="full"), [3.5, 9.5, 17.0, 10.5, 4.5])

    # Edge case: Single-element arrays
    a = np.array([5])
    b = np.array([10])
    assert np.allclose(np.correlate(a, b, mode="valid"), [50])
    assert np.allclose(np.correlate(a, b, mode="same"), [50])
    assert np.allclose(np.correlate(a, b, mode="full"), [50])

test_correlate2()

@test
def test_bincount(x, expected, weights=None, minlength=0):
assert (np.bincount(x, weights=weights,
Expand Down

0 comments on commit 4521182

Please sign in to comment.