
[ENH] Parallelize SAX and PAA transformers #2980

Open · wants to merge 7 commits into main

Conversation

aadya940
Contributor

Related to #2972

@aeon-actions-bot bot added the labels enhancement (New feature, improvement request or other non-bug code enhancement) and transformations (Transformations package) on Jul 31, 2025
@aeon-actions-bot
Contributor

Thank you for contributing to aeon

I have added the following labels to this PR based on the title: [ enhancement ].
I have added the following labels to this PR based on the changes made: [ transformations ]. Feel free to change these if they do not properly represent the PR.

The Checks tab will show the status of our automated tests. You can click on individual test runs in the tab or "Details" in the panel below to see more information if there is a failure.

If our pre-commit code quality check fails, any trivial fixes will automatically be pushed to your PR unless it is a draft.

Don't hesitate to ask questions on the aeon Slack channel if you have any.

PR CI actions

These checkboxes will add labels to enable/disable CI functionality for this PR. This may not take effect immediately, and a new commit may be required to run the new configuration.

  • Run pre-commit checks for all files
  • Run mypy typecheck tests
  • Run all pytest tests and configurations
  • Run all notebook example tests
  • Run numba-disabled codecov tests
  • Stop automatic pre-commit fixes (always disabled for drafts)
  • Disable numba cache loading
  • Push an empty commit to re-run CI checks

@hadifawaz1999
Member

Nice addition @aadya940!

Could you run examples comparing the old and new code in terms of output accuracy, to make sure the algorithm didn't change?
Also, can we benchmark the time gain? E.g. see how much faster it becomes, plotted as a function of the number of samples, the length of the series, etc.
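A minimal sketch of such a comparison, assuming SAX is importable from aeon.transformations.collection.dictionary_based and takes n_segments/alphabet_size parameters (the import path, parameters, and data shapes are assumptions); result_old and result_new stand in for the outputs of the old and new implementations:

```python
import time

import numpy as np

from aeon.transformations.collection.dictionary_based import SAX

rng = np.random.default_rng(0)

for n_samples in (100, 1_000, 10_000):
    # random collection with shape (n_cases, n_channels, n_timepoints)
    X = rng.standard_normal((n_samples, 1, 256))
    sax = SAX(n_segments=8, alphabet_size=4)

    start = time.perf_counter()
    result_new = sax.fit_transform(X)
    print(f"n_samples={n_samples}: {time.perf_counter() - start:.3f}s")

    # accuracy: the parallel code should reproduce the old output exactly
    # np.testing.assert_array_equal(result_old, result_new)
```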

@aadya940
Contributor Author

@hadifawaz1999 Yes sure

Contributor

@patrickzib patrickzib left a comment

Thank you for your additions. A few comments from me.

return X_paa


@njit(parallel=True, fastmath=True)
Contributor

please add cache=True
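For reference, the decorator would then read:

```python
@njit(parallel=True, fastmath=True, cache=True)
```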

@@ -143,3 +146,45 @@ def _get_test_params(cls, parameter_set="default"):
"""
params = {"n_segments": 10}
return params


@njit(parallel=True, fastmath=True)
Contributor

please add cache=True

for i in range(n_samples):
    for j in range(n_channels):
        acc = 0.0
        for k in range(seg_len):
Contributor

Why not use .mean() here, as in the original version?

Contributor Author

numba doesn't have an implementation for mean

Contributor

Not sure I understand. Both seem to work fine (I might be missing something):

        for i in range(n_samples):
            for j in range(n_channels):
                acc = X[i, j, segment].mean()
        for i in range(n_samples):
            for j in range(n_channels):
                acc = np.mean(X[i, j, segment])

Contributor Author

I meant that we used axis=-1 in the original implementation; however, Numba doesn't support that optional argument. Hence, I implemented it this way.
:))

Contributor

Thanks, I understand.

Yet, why not use one of the two alternatives presented above that do not require axis=-1?
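For context, a minimal sketch of the distinction (assuming current Numba behavior: np.mean over a slice compiles in nopython mode, while the axis keyword does not):

```python
import numpy as np
from numba import njit


@njit(fastmath=True)
def segment_means(X, start, end):
    # Supported: np.mean (or .mean()) on a 1D slice in nopython mode.
    n_samples, n_channels, _ = X.shape
    out = np.empty((n_samples, n_channels))
    for i in range(n_samples):
        for j in range(n_channels):
            out[i, j] = np.mean(X[i, j, start:end])
    return out


# Not supported: np.mean(X, axis=-1) inside an @njit function raises a
# TypingError, because Numba's overload of np.mean has no axis argument.
```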

    continue  # skip empty segment

for i in range(n_samples):
    for j in range(n_channels):
Contributor

prange?

Contributor Author

Wouldn't that spawn too many threads, given that the outer loop already uses a prange?

Contributor

You have used three nested prange loops in _parallel_inverse_paa_transform?

But, in fact, Numba seems to ignore nested prange loops. I did not know this before:

> Loop serialization
> Loop serialization occurs when any number of prange driven loops are present inside another prange driven loop. In this case the outermost of all the prange loops executes in parallel and any inner prange loops (nested or otherwise) are treated as standard range based loops. Essentially, nested parallelism does not occur.

https://numba.pydata.org/numba-doc/dev/user/parallel.html
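A small sketch illustrating the serialization behavior described in the quote (illustrative only, not code from this PR):

```python
import numpy as np
from numba import njit, prange


@njit(parallel=True)
def row_sums(X):
    out = np.zeros(X.shape[0])
    for i in prange(X.shape[0]):  # outermost prange: runs in parallel
        for j in prange(X.shape[1]):  # nested prange: treated as a plain range
            out[i] += X[i, j]
    return out


# Identical output and threading to using range(X.shape[1]) for the inner loop.
```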

Contributor Author

Yep, missed it. However, I'd also like to point out that the previously implemented Numba functions in SAX use multiple nested pranges as well, which is essentially dead code. Should I remove it in this PR?

Contributor

Not sure who wrote these. @hadifawaz1999?

Member

@hadifawaz1999 hadifawaz1999 Jul 31, 2025

Yes, I wrote these. I don't see the issue though; from the discussion, SAX was working fine, so why is it dead code? @patrickzib

@hadifawaz1999
Member

Nothing is dead code @aadya940 :) just probably inconsistent usage of prange, but the code works fine.

@aadya940
Contributor Author

aadya940 commented Jul 31, 2025

@hadifawaz1999 By dead code I meant that it has no effect on the output: it gives results identical in speed and accuracy to a single prange, since Numba ignores the other nested pranges :)

Member

@hadifawaz1999 hadifawaz1999 left a comment

This should be compared in terms of accuracy before runtime, and we should see how worthwhile the gain is compared to the amount of new code.

@@ -292,3 +296,31 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
]

return sax_inverse


@njit(fastmath=True, cache=True, parallel=True)
Member

Not 100% sure what the gain of doing all this would be. If it's not significant, I'm not in favour of it: it would mean that instead of simply doing sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints), we have 30+ new lines of code with nested loops and prange.

Contributor Author

I have done a benchmark in the past; I think it performs a little better with n_jobs > 2 and significantly better beyond 4 threads, compared to np.digitize.

Member

Having dealt with MONSTER datasets lately, not being able to parallelize this kind of operation for such huge datasets is a big downside. It might not be significant for UCR/UEA, but it can help a lot as dataset sizes grow.

I'm not a fan of the flatten and reshape, though, but I guess you did it to avoid the nested-loop parallelism problem?

For clarity, you could add that breakpoints need to be sorted for the function to work.

Contributor

You might want to check _sfa_fast.py.

I use np.digitize within a parallel for-loop. This approach strikes a balance between using only explicit loops and relying solely on vectorized digitize operations:

        for a in prange(dfts.shape[0]):
            for i in range(word_length):  # range(dfts.shape[2]):
                words[a, : dfts.shape[1]] = (
                    words[a, : dfts.shape[1]] << letter_bits
                ) | np.digitize(dfts[a, :, i], breakpoints[i], right=True)
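Adapted to the PAA output in this PR, that hybrid pattern might look like the sketch below (the function name and shapes are illustrative, not the PR's actual code; breakpoints must be sorted, as np.digitize requires):

```python
import numpy as np
from numba import njit, prange


@njit(parallel=True, cache=True)
def _digitize_parallel(X_paa, breakpoints):
    # Hypothetical sketch: vectorized np.digitize per channel, with only
    # the outer sample loop parallelized; `breakpoints` must be sorted.
    n_samples, n_channels, _ = X_paa.shape
    out = np.zeros(X_paa.shape, dtype=np.int64)
    for i in prange(n_samples):
        for j in range(n_channels):
            out[i, j] = np.digitize(X_paa[i, j], breakpoints)
    return out
```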

@hadifawaz1999
Member

> @hadifawaz1999 By dead code I meant that it has no effect on the output: it gives results identical in speed and accuracy to a single prange, since Numba ignores the other nested pranges :)

If I'm not mistaken, Numba treats a prange as a normal range when it's nested, so yes, it's simply a useless usage of nested prange.
