Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fasthmath precision issue #1048

Merged
merged 69 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
0d4abf0
removed reassoc flag from fastmath
NimaSarajpoor Dec 5, 2024
7b9a2a3
Add reset feature to config
NimaSarajpoor Dec 22, 2024
95e7d8f
Revised config value
NimaSarajpoor Dec 22, 2024
1f95f77
replaced fastmath flags with config var
NimaSarajpoor Dec 22, 2024
37add84
fixed format
NimaSarajpoor Dec 22, 2024
c795b9d
Removed bad f-string
NimaSarajpoor Dec 22, 2024
2774302
Replaced Raised with Returns in docstring
NimaSarajpoor Dec 22, 2024
80544c7
Add second attempt for assertion
NimaSarajpoor Dec 22, 2024
a080495
minor change
NimaSarajpoor Dec 22, 2024
976ec13
Add condition to avoid revising fastmath when JIT is disabled
NimaSarajpoor Dec 22, 2024
b24ff9d
Removed support for input with type list to simplify function
NimaSarajpoor Dec 23, 2024
d13a32b
Refactored the recompile process
NimaSarajpoor Dec 24, 2024
4256839
removed blank lines
NimaSarajpoor Dec 24, 2024
f17f92d
fixed typo
NimaSarajpoor Dec 26, 2024
e543379
replaced hardcoded fastmath value with config var
NimaSarajpoor Dec 26, 2024
71fe3aa
revised function
NimaSarajpoor Dec 26, 2024
6854b0e
renamed variable to improve readability
NimaSarajpoor Dec 26, 2024
aeea9b4
fixed bug
NimaSarajpoor Dec 26, 2024
b41cc66
rename config to improve readability
NimaSarajpoor Dec 28, 2024
02b1fb4
revise func clear
NimaSarajpoor Dec 28, 2024
9647113
revise func to recompile all njit functions
NimaSarajpoor Dec 28, 2024
245bfc4
Adapt to changes in test function
NimaSarajpoor Dec 28, 2024
7dd15eb
add test
NimaSarajpoor Dec 28, 2024
4b52ff5
resolve coverage
NimaSarajpoor Dec 28, 2024
aabd41a
resolve missing lines in coverage
NimaSarajpoor Dec 28, 2024
47c9932
Add test function to improve coverage
NimaSarajpoor Dec 28, 2024
816e8d4
add fastmath module
NimaSarajpoor Dec 29, 2024
562550e
revise test function to use fastmath module
NimaSarajpoor Dec 29, 2024
d99148a
fix minor issue
NimaSarajpoor Dec 29, 2024
8a45f6e
minor change to improve readability
NimaSarajpoor Dec 29, 2024
7eb52c1
Add fastmath default flags to config default
NimaSarajpoor Jan 4, 2025
b130fa7
add reset function
NimaSarajpoor Jan 4, 2025
6264e94
rename function
NimaSarajpoor Jan 4, 2025
1d873bd
adapt recent changes in test function
NimaSarajpoor Jan 4, 2025
127e61c
minor fixes
NimaSarajpoor Jan 4, 2025
447c006
Check if DISABLE_JIT before getting fastmath
NimaSarajpoor Jan 4, 2025
14c2267
ignore lines for coverage check
NimaSarajpoor Jan 4, 2025
ec960bd
Merge branch 'main' into investigate_precision_failure
NimaSarajpoor Jan 12, 2025
2469458
Editorial fix
NimaSarajpoor Jan 12, 2025
2624445
avoid .get(key) to get KeyError if it does not exist
NimaSarajpoor Jan 12, 2025
4b37b0a
add function to save cache
NimaSarajpoor Jan 13, 2025
baf3fea
Add note to function
NimaSarajpoor Jan 13, 2025
7d02173
fix format
NimaSarajpoor Jan 13, 2025
24bc232
replace fastmath flag with config variable
NimaSarajpoor Jan 13, 2025
f5c2718
add test function to check backward compatibility
NimaSarajpoor Jan 15, 2025
1a17346
skip test when JIT is disabled
NimaSarajpoor Jan 15, 2025
995a6c2
rename test function
NimaSarajpoor Jan 16, 2025
0097953
add conditional deprecation warning
NimaSarajpoor Jan 16, 2025
bee3b63
add test function to check if cache can be saved after cache._clear()
NimaSarajpoor Jan 16, 2025
8d29f91
remove old warning
NimaSarajpoor Jan 17, 2025
18dd4b9
add test for cache._clear
NimaSarajpoor Jan 17, 2025
d7b21a7
add wrapper around private functions
NimaSarajpoor Jan 17, 2025
cf4b183
Raise OSError when NUMBA JIT is disabled during cache save
NimaSarajpoor Jan 17, 2025
89825c9
move warnings to public API
NimaSarajpoor Jan 17, 2025
6a61483
fix warning message
NimaSarajpoor Jan 17, 2025
fd7eb7d
improved warning message
NimaSarajpoor Jan 18, 2025
9ba634b
Add commit about addition config variables that are defined in __init__
NimaSarajpoor Jan 18, 2025
1a48f3d
Revise test function to improve readability
NimaSarajpoor Jan 19, 2025
02115b2
Add test function for fastmath
NimaSarajpoor Jan 19, 2025
ecaead2
Revised test functions
NimaSarajpoor Jan 19, 2025
5c1a7a7
skip test if numba JIT is disabled
NimaSarajpoor Jan 19, 2025
89db05c
omit test functions that require NUMBA JIT
NimaSarajpoor Jan 19, 2025
7e72a7f
Removed the trivial test function
NimaSarajpoor Jan 20, 2025
b82b3c9
Raise warning instead of error to avoid interrupting the program
NimaSarajpoor Jan 20, 2025
9a69593
improve readability
NimaSarajpoor Jan 21, 2025
7e5985d
remove intermediate variable
NimaSarajpoor Jan 21, 2025
2369e33
minor fixes
NimaSarajpoor Jan 21, 2025
04eac83
Add shell script code to check for harcoded fastmath flags
NimaSarajpoor Jan 23, 2025
f5186a2
minor fix on indention
NimaSarajpoor Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
replaced hardcoded fastmath value with config var
  • Loading branch information
NimaSarajpoor committed Dec 26, 2024
commit e543379ae901a904c99fa49955c9276c8ac0748c
4 changes: 2 additions & 2 deletions stumpy/aamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@njit(
# "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8[:], i8, i8, i8, f8[:, :, :],"
# "f8[:, :], f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)",
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved
)
def _compute_diagonal(
T_A,
Expand Down Expand Up @@ -186,7 +186,7 @@ def _compute_diagonal(
@njit(
# "(f8[:], f8[:], i8, b1[:], b1[:], i8[:], b1, i8)",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _aamp(
T_A,
Expand Down
24 changes: 12 additions & 12 deletions stumpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def check_window_size(m, max_size=None):
raise ValueError(f"The window size must be less than or equal to {max_size}")


@njit(fastmath=True)
@njit(fastmath=config.STUMPY_FASTMATH)
def _sliding_dot_product(Q, T):
"""
A Numba JIT-compiled implementation of the sliding window dot product.
Expand Down Expand Up @@ -657,7 +657,7 @@ def sliding_dot_product(Q, T):

@njit(
# "f8[:](f8[:], i8, b1[:])",
fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}
fastmath=config.STUMPY_FASTMATH_FLAGS
)
def _welford_nanvar(a, w, a_subseq_isfinite):
"""
Expand Down Expand Up @@ -771,7 +771,7 @@ def welford_nanstd(a, w=None):
return np.sqrt(np.clip(welford_nanvar(a, w), a_min=0, a_max=None))


@njit(parallel=True, fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
@njit(parallel=True, fastmath=config.STUMPY_FASTMATH_FLAGS)
def _rolling_nanstd_1d(a, w):
"""
A Numba JIT-compiled and parallelized function for computing the rolling standard
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def _calculate_squared_distance(

@njit(
# "f8[:](i8, f8[:], f8, f8, f8[:], f8[:])",
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _calculate_squared_distance_profile(
m, QT, μ_Q, σ_Q, M_T, Σ_T, Q_subseq_isconstant, T_subseq_isconstant
Expand Down Expand Up @@ -1176,7 +1176,7 @@ def _calculate_squared_distance_profile(

@njit(
# "f8[:](i8, f8[:], f8, f8, f8[:], f8[:])",
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def calculate_distance_profile(
m, QT, μ_Q, σ_Q, M_T, Σ_T, Q_subseq_isconstant, T_subseq_isconstant
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def calculate_distance_profile(
return np.sqrt(D_squared)


@njit(fastmath=True)
@njit(fastmath=config.STUMPY_FASTMATH)
def _p_norm_distance_profile(Q, T, p=2.0):
"""
A Numba JIT-compiled and parallelized function for computing the p-normalized
Expand Down Expand Up @@ -1505,7 +1505,7 @@ def mueen_calculate_distance_profile(Q, T):

@njit(
# "f8[:](f8[:], f8[:], f8[:], f8, f8, f8[:], f8[:])",
fastmath=True
fastmath=config.STUMPY_FASTMATH
)
def _mass(Q, T, QT, μ_Q, σ_Q, M_T, Σ_T, Q_subseq_isconstant, T_subseq_isconstant):
"""
Expand Down Expand Up @@ -1978,7 +1978,7 @@ def _get_QT(start, T_A, T_B, m):

@njit(
# ["(f8[:], i8, i8)", "(f8[:, :], i8, i8)"],
fastmath=True
fastmath=config.STUMPY_FASTMATH
)
def _apply_exclusion_zone(a, idx, excl_zone, val):
"""
Expand Down Expand Up @@ -2308,7 +2308,7 @@ def array_to_temp_file(a):

@njit(
# "i8[:](i8[:], i8, i8, i8)",
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _count_diagonal_ndist(diags, m, n_A, n_B):
"""
Expand Down Expand Up @@ -2505,7 +2505,7 @@ def rolling_isfinite(a, w):
)


@njit(parallel=True, fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
@njit(parallel=True, fastmath=config.STUMPY_FASTMATH_FLAGS)
def _rolling_isconstant(a, w):
"""
Compute the rolling isconstant for 1-D array.
Expand Down Expand Up @@ -2842,7 +2842,7 @@ def _idx_to_mp(
return P


@njit(fastmath=True)
@njit(fastmath=config.STUMPY_FASTMATH)
def _total_diagonal_ndists(tile_lower_diag, tile_upper_diag, tile_height, tile_width):
"""
Count the total number of distances covered by a range of diagonals
Expand Down Expand Up @@ -3970,7 +3970,7 @@ def _mdl(disc_subseqs, disc_neighbors, S, n_bit=8):

@njit(
# "(i8, i8, f8[:, :], f8[:], i8, f8[:, :], i8[:, :], f8)",
fastmath={"nsz", "arcp", "contract", "afn", "reassoc"},
fastmath=config.STUMPY_FASTMATH_FLAGS,
)
def _compute_multi_PI(d, idx, D, D_prime, range_start, P, I, p=2.0):
"""
Expand Down
2 changes: 1 addition & 1 deletion stumpy/maamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def _get_multi_p_norm(start, T, m, p=2.0):
# "(i8, i8, i8, f8[:, :], f8[:, :], i8, i8, b1[:, :], b1[:, :], f8,"
# "f8[:, :], f8[:, :], f8[:, :])",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _compute_multi_p_norm(
d,
Expand Down
2 changes: 1 addition & 1 deletion stumpy/mstump.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def _get_multi_QT(start, T, m):
# "(i8, i8, i8, f8[:, :], f8[:, :], i8, i8, f8[:, :], f8[:, :], f8[:, :],"
# "f8[:, :], f8[:, :], f8[:, :], f8[:, :])",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _compute_multi_D(
d,
Expand Down
4 changes: 2 additions & 2 deletions stumpy/scraamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _preprocess_prescraamp(T_A, m, T_B=None, s=None):
return (T_A, T_B, T_A_subseq_isfinite, T_B_subseq_isfinite, indices, s, excl_zone)


@njit(fastmath=True)
@njit(fastmath=config.STUMPY_FASTMATH)
def _compute_PI(
T_A,
T_B,
Expand Down Expand Up @@ -286,7 +286,7 @@ def _compute_PI(
# "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8, i8, f8[:], f8[:],"
# "i8[:], optional(i8))",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _prescraamp(
T_A,
Expand Down
4 changes: 2 additions & 2 deletions stumpy/scrump.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _preprocess_prescrump(
)


@njit(fastmath=True)
@njit(fastmath=config.STUMPY_FASTMATH)
def _compute_PI(
T_A,
T_B,
Expand Down Expand Up @@ -384,7 +384,7 @@ def _compute_PI(
# "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], i8, i8, f8[:], f8[:],"
# "i8[:], optional(i8))",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _prescrump(
T_A,
Expand Down
4 changes: 2 additions & 2 deletions stumpy/stump.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], f8[:],"
# "b1[:], b1[:], b1[:], b1[:], i8[:], i8, i8, i8, f8[:, :, :], f8[:, :],"
# "f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)",
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _compute_diagonal(
T_A,
Expand Down Expand Up @@ -247,7 +247,7 @@ def _compute_diagonal(
# "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], b1[:], b1[:],"
# "b1[:], b1[:], i8[:], b1, i8)",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH,
)
def _stump(
T_A,
Expand Down
Loading