Skip to content

Commit cea39a5

Browse files
committed
bug fix
1 parent f1b04ce commit cea39a5

File tree

1 file changed

+25
-14
lines changed
  • aeon/transformations/collection/dictionary_based

1 file changed

+25
-14
lines changed

aeon/transformations/collection/dictionary_based/_sax.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,28 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
299299

300300

301301
@njit(fastmath=True, cache=True, parallel=True)
302-
def _parallel_get_sax_symbols(X, breakpoints):
303-
n_cases, n_channels, n_timepoints = X.shape
304-
X_new = np.zeros((n_cases, n_channels, n_timepoints), dtype=np.intp)
305-
n_break = breakpoints.shape[0] - 1
306-
for i_x in prange(n_cases):
307-
for i_c in prange(n_channels):
308-
for i_b in prange(n_break):
309-
mask = np.where(
310-
(X[i_x, i_c] >= breakpoints[i_b])
311-
& (X[i_x, i_c] < breakpoints[i_b + 1])
312-
)[0]
313-
X_new[i_x, i_c, mask] += np.array(i_b).astype(np.intp)
314-
315-
return X_new
302+
def _parallel_get_sax_symbols(x, bins, right=False):
303+
"""Parallel version of `np.digitize`."""
304+
x_flat = x.flatten()
305+
result = np.empty(x_flat.shape[0], dtype=np.intp)
306+
307+
for i in prange(x_flat.shape[0]):
308+
val = x_flat[i]
309+
bin_idx = 0
310+
311+
if right:
312+
for j in range(len(bins)):
313+
if val <= bins[j]:
314+
bin_idx = j
315+
break
316+
bin_idx = j + 1
317+
else:
318+
for j in range(len(bins)):
319+
if val < bins[j]:
320+
bin_idx = j
321+
break
322+
bin_idx = j + 1
323+
324+
result[i] = bin_idx
325+
326+
return result.reshape(x.shape)

0 commit comments

Comments
 (0)