@@ -299,17 +299,28 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
299
299
300
300
301
301
@njit (fastmath = True , cache = True , parallel = True )
302
- def _parallel_get_sax_symbols (X , breakpoints ):
303
- n_cases , n_channels , n_timepoints = X .shape
304
- X_new = np .zeros ((n_cases , n_channels , n_timepoints ), dtype = np .intp )
305
- n_break = breakpoints .shape [0 ] - 1
306
- for i_x in prange (n_cases ):
307
- for i_c in prange (n_channels ):
308
- for i_b in prange (n_break ):
309
- mask = np .where (
310
- (X [i_x , i_c ] >= breakpoints [i_b ])
311
- & (X [i_x , i_c ] < breakpoints [i_b + 1 ])
312
- )[0 ]
313
- X_new [i_x , i_c , mask ] += np .array (i_b ).astype (np .intp )
314
-
315
- return X_new
302
+ def _parallel_get_sax_symbols (x , bins , right = False ):
303
+ """Parallel version of `np.digitize`."""
304
+ x_flat = x .flatten ()
305
+ result = np .empty (x_flat .shape [0 ], dtype = np .intp )
306
+
307
+ for i in prange (x_flat .shape [0 ]):
308
+ val = x_flat [i ]
309
+ bin_idx = 0
310
+
311
+ if right :
312
+ for j in range (len (bins )):
313
+ if val <= bins [j ]:
314
+ bin_idx = j
315
+ break
316
+ bin_idx = j + 1
317
+ else :
318
+ for j in range (len (bins )):
319
+ if val < bins [j ]:
320
+ bin_idx = j
321
+ break
322
+ bin_idx = j + 1
323
+
324
+ result [i ] = bin_idx
325
+
326
+ return result .reshape (x .shape )
0 commit comments