[ENH] Parallelize SAX and PAA transformers #2980
Conversation
Nice addition @aadya940! Could you run examples comparing the old and new code in terms of output accuracy, to make sure the algorithm didn't change?
@hadifawaz1999 Yes, sure.
Thank you for your additions. Some comments of mine.
    return X_paa


@njit(parallel=True, fastmath=True)
please add cache=True
@@ -143,3 +146,45 @@ def _get_test_params(cls, parameter_set="default"):
        """
        params = {"n_segments": 10}
        return params


@njit(parallel=True, fastmath=True)
please add cache=True
for i in range(n_samples):
    for j in range(n_channels):
        acc = 0.0
        for k in range(seg_len):
Why not use the .mean() here, as in the original version?
numba doesn't have an implementation for mean
Not sure I understand. Both seem to work fine (I might be missing something):

for i in range(n_samples):
    for j in range(n_channels):
        acc = X[i, j, segment].mean()

for i in range(n_samples):
    for j in range(n_channels):
        acc = np.mean(X[i, j, segment])
I meant that we used axis=-1 in the original implementation; however, numba's np.mean doesn't support the optional axis argument. Hence, I implemented it this way. :))
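As a quick sanity check in plain NumPy (hypothetical shapes and names): the per-slice mean that the numba-compatible loop computes is numerically identical to the vectorized axis=-1 call from the original implementation, so the two alternatives above are equivalent in output.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3, 12))  # (n_samples, n_channels, n_timepoints)
segment = slice(0, 6)                # hypothetical segment of the series

# Loop form, compatible with numba's np.mean (which takes no axis keyword)...
loop_means = np.empty((4, 3))
for i in range(4):
    for j in range(3):
        loop_means[i, j] = np.mean(X[i, j, segment])

# ...equals the vectorized axis=-1 form used in the original implementation.
vec_means = X[:, :, segment].mean(axis=-1)
assert np.allclose(loop_means, vec_means)
```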
Thanks, I understand. Yet, why not use one of the two alternatives presented above, which do not require axis=-1?
        continue  # skip empty segment

for i in range(n_samples):
    for j in range(n_channels):
prange?
Wouldn't that spawn too many threads, given that the outer loop already uses a prange?
You have used three nested prange loops in _parallel_inverse_paa_transform? But, in fact, numba seems to ignore nested prange loops. I did not know this before:

Loop serialization: "Loop serialization occurs when any number of prange driven loops are present inside another prange driven loop. In this case the outermost of all the prange loops executes in parallel and any inner prange loops (nested or otherwise) are treated as standard range based loops. Essentially, nested parallelism does not occur."
Yep, missed it. However, I'd also like to point out that the previously implemented numba functions in SAX use multiple nested pranges as well, which is essentially dead code. I can remove it in this PR?
Not sure who wrote these. @hadifawaz1999?
Yes, I wrote these. I don't see the issue though, from the discussion; SAX was working fine, so why is it dead code? @patrickzib
Nothing is dead code @aadya940 :) just probably inconsistent usage of prange, but the code works fine.
@hadifawaz1999 By dead code I meant it has no effect on the output and gives identical results, in terms of speed and accuracy, to if there was a single prange.
They should be compared in terms of accuracy before runtime, and we should see whether the gain is worth the new amount of code.
@@ -292,3 +296,31 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
    ]

    return sax_inverse


@njit(fastmath=True, cache=True, parallel=True)
Not 100% sure what the gain would be of doing all this; if it's not significant, I am not for doing it. It would mean that instead of simply doing sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints), we have 30+ new lines of code with nested loops and prange.
I have done a benchmark in the past; I think it performs a little better with n_jobs > 2, and significantly better past 4 threads, as compared to np.digitize.
Having dealt with MONSTER datasets lately, not being able to parallelize stuff for such huge datasets is a big downside.
It might not be significant for UCR/UEA, but it can help a lot when dataset size grows.
I'm not a fan of the flatten and reshape, though, but I guess you did it to avoid the nested loop parallelism problem?
For clarity, you could add that breakpoints need to be sorted for the function to work.
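To illustrate the sorted-breakpoints requirement in plain NumPy (hypothetical values, not the PR's kernel): np.digitize and the kind of explicit counting loop a numba kernel would use only agree when the bins are in ascending order, since both effectively count how many breakpoints each value meets or exceeds.

```python
import numpy as np

breakpoints = np.array([-0.67, 0.0, 0.67])  # must be ascending for both paths
x = np.array([-1.5, -0.3, 0.1, 2.0])

# Vectorized symbol assignment, as in the original implementation.
symbols_digitize = np.digitize(x, breakpoints)

# Explicit-loop equivalent, the kind of kernel numba can parallelize:
# count how many breakpoints each value reaches (valid only if sorted).
symbols_loop = np.empty(x.shape[0], dtype=np.int64)
for i in range(x.shape[0]):
    s = 0
    for bp in breakpoints:
        if x[i] >= bp:
            s += 1
    symbols_loop[i] = s

assert np.array_equal(symbols_digitize, symbols_loop)
```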
You might want to check _sfa_fast.py. I use np.digitize within a parallel for-loop. This approach strikes a balance between using only explicit loops and relying solely on vectorized digitize operations.

for a in prange(dfts.shape[0]):
    for i in range(word_length):  # range(dfts.shape[2]):
        words[a, : dfts.shape[1]] = (
            words[a, : dfts.shape[1]] << letter_bits
        ) | np.digitize(dfts[a, :, i], breakpoints[i], right=True)
If I am not mistaken, numba considers a prange as a normal range when it's nested, so yes, it's simply useless usage of nested prange.
Related to #2972