Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fasthmath precision issue #1048

Merged
merged 69 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 64 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
0d4abf0
removed reassoc flag from fastmath
NimaSarajpoor Dec 5, 2024
7b9a2a3
Add reset feature to config
NimaSarajpoor Dec 22, 2024
95e7d8f
Revised config value
NimaSarajpoor Dec 22, 2024
1f95f77
replaced fastmath flags with config var
NimaSarajpoor Dec 22, 2024
37add84
fixed format
NimaSarajpoor Dec 22, 2024
c795b9d
Removed bad f-string
NimaSarajpoor Dec 22, 2024
2774302
Replaced Raised with Returns in docstring
NimaSarajpoor Dec 22, 2024
80544c7
Add second attempt for assertion
NimaSarajpoor Dec 22, 2024
a080495
minor change
NimaSarajpoor Dec 22, 2024
976ec13
Add condition to avoid revising fastmath when JIT is disabled
NimaSarajpoor Dec 22, 2024
b24ff9d
Removed support for input with type list to simplify function
NimaSarajpoor Dec 23, 2024
d13a32b
Refactored the recompile process
NimaSarajpoor Dec 24, 2024
4256839
removed blank lines
NimaSarajpoor Dec 24, 2024
f17f92d
fixed typo
NimaSarajpoor Dec 26, 2024
e543379
replaced hardcoded fastmath value with config var
NimaSarajpoor Dec 26, 2024
71fe3aa
revised function
NimaSarajpoor Dec 26, 2024
6854b0e
renamed variable to improve readability
NimaSarajpoor Dec 26, 2024
aeea9b4
fixed bug
NimaSarajpoor Dec 26, 2024
b41cc66
rename config to improve readability
NimaSarajpoor Dec 28, 2024
02b1fb4
revise func clear
NimaSarajpoor Dec 28, 2024
9647113
revise func to recompile all njit functions
NimaSarajpoor Dec 28, 2024
245bfc4
Adapt to changes in test function
NimaSarajpoor Dec 28, 2024
7dd15eb
add test
NimaSarajpoor Dec 28, 2024
4b52ff5
resolve coverage
NimaSarajpoor Dec 28, 2024
aabd41a
resolve missing lines in coverage
NimaSarajpoor Dec 28, 2024
47c9932
Add test function to improve coverage
NimaSarajpoor Dec 28, 2024
816e8d4
add fastmath module
NimaSarajpoor Dec 29, 2024
562550e
revise test function to use fastmath module
NimaSarajpoor Dec 29, 2024
d99148a
fix minor issue
NimaSarajpoor Dec 29, 2024
8a45f6e
minor change to improve readability
NimaSarajpoor Dec 29, 2024
7eb52c1
Add fastmath default flags to config default
NimaSarajpoor Jan 4, 2025
b130fa7
add reset function
NimaSarajpoor Jan 4, 2025
6264e94
rename function
NimaSarajpoor Jan 4, 2025
1d873bd
adapt recent changes in test function
NimaSarajpoor Jan 4, 2025
127e61c
minor fixes
NimaSarajpoor Jan 4, 2025
447c006
Check if DISABLE_JIT before getting fastmath
NimaSarajpoor Jan 4, 2025
14c2267
ignore lines for coverage check
NimaSarajpoor Jan 4, 2025
ec960bd
Merge branch 'main' into investigate_precision_failure
NimaSarajpoor Jan 12, 2025
2469458
Editorial fix
NimaSarajpoor Jan 12, 2025
2624445
avoid .get(key) to get KeyError if it does not exist
NimaSarajpoor Jan 12, 2025
4b37b0a
add function to save cache
NimaSarajpoor Jan 13, 2025
baf3fea
Add note to function
NimaSarajpoor Jan 13, 2025
7d02173
fix format
NimaSarajpoor Jan 13, 2025
24bc232
replace fastmath flag with config variable
NimaSarajpoor Jan 13, 2025
f5c2718
add test function to check backward compatibility
NimaSarajpoor Jan 15, 2025
1a17346
skip test when JIT is disabled
NimaSarajpoor Jan 15, 2025
995a6c2
rename test function
NimaSarajpoor Jan 16, 2025
0097953
add conditional deprecation warning
NimaSarajpoor Jan 16, 2025
bee3b63
add test function to check if cache can be saved after cache._clear()
NimaSarajpoor Jan 16, 2025
8d29f91
remove old warning
NimaSarajpoor Jan 17, 2025
18dd4b9
add test for cache._clear
NimaSarajpoor Jan 17, 2025
d7b21a7
add wrapper around private functions
NimaSarajpoor Jan 17, 2025
cf4b183
Raise OSError when NUMBA JIT is disabled during cache save
NimaSarajpoor Jan 17, 2025
89825c9
move warnings to public API
NimaSarajpoor Jan 17, 2025
6a61483
fix warning message
NimaSarajpoor Jan 17, 2025
fd7eb7d
improved warning message
NimaSarajpoor Jan 18, 2025
9ba634b
Add commit about addition config variables that are defined in __init__
NimaSarajpoor Jan 18, 2025
1a48f3d
Revise test function to improve readability
NimaSarajpoor Jan 19, 2025
02115b2
Add test function for fastmath
NimaSarajpoor Jan 19, 2025
ecaead2
Revised test functions
NimaSarajpoor Jan 19, 2025
5c1a7a7
skip test if numba JIT is disabled
NimaSarajpoor Jan 19, 2025
89db05c
omit test functions that require NUMBA JIT
NimaSarajpoor Jan 19, 2025
7e72a7f
Removed the trivial test function
NimaSarajpoor Jan 20, 2025
b82b3c9
Raise warning instead of error to avoid interrupting the program
NimaSarajpoor Jan 20, 2025
9a69593
improve readability
NimaSarajpoor Jan 21, 2025
7e5985d
remove intermediate variable
NimaSarajpoor Jan 21, 2025
2369e33
minor fixes
NimaSarajpoor Jan 21, 2025
04eac83
Add shell script code to check for harcoded fastmath flags
NimaSarajpoor Jan 23, 2025
f5186a2
minor fix on indention
NimaSarajpoor Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions stumpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import importlib
import os.path
from importlib.metadata import distribution
from site import getsitepackages

import numba
from numba import cuda

from . import cache, config
from .aamp import aamp # noqa: F401
from .aamp_mmotifs import aamp_mmotifs # noqa: F401
from .aamp_motifs import aamp_match, aamp_motifs # noqa: F401
Expand Down Expand Up @@ -32,6 +35,18 @@
from .stumped import stumped # noqa: F401
from .stumpi import stumpi # noqa: F401

# Get the default fastmath flags for all njit functions
# and update the _STUMPY_DEFAULTS dictionary

if not numba.config.DISABLE_JIT: # pragma: no cover
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved
njit_funcs = cache.get_njit_funcs()
for module_name, func_name in njit_funcs:
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
key = module_name + "." + func_name # e.g., core._mass
key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS
config._STUMPY_DEFAULTS[key] = func.targetoptions["fastmath"]

if cuda.is_available():
from .gpu_aamp import gpu_aamp # noqa: F401
from .gpu_aamp_ostinato import gpu_aamp_ostinato # noqa: F401
Expand Down
4 changes: 2 additions & 2 deletions stumpy/aamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@njit(
# "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8[:], i8, i8, i8, f8[:, :, :],"
# "f8[:, :], f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)",
fastmath=True,
fastmath=config.STUMPY_FASTMATH_TRUE,
)
def _compute_diagonal(
T_A,
Expand Down Expand Up @@ -186,7 +186,7 @@ def _compute_diagonal(
@njit(
# "(f8[:], f8[:], i8, b1[:], b1[:], i8[:], b1, i8)",
parallel=True,
fastmath=True,
fastmath=config.STUMPY_FASTMATH_TRUE,
)
def _aamp(
T_A,
Expand Down
98 changes: 96 additions & 2 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

import ast
import importlib
import inspect
import pathlib
import site
import warnings

import numba

CACHE_WARNING = "Caching `numba` functions is purely for experimental purposes "
CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
Expand Down Expand Up @@ -74,7 +77,15 @@ def _enable():
-------
None
"""
warnings.warn(CACHE_WARNING)
frame = inspect.currentframe()
caller_name = inspect.getouterframes(frame)[1].function
if caller_name != "_save":
msg = (
"The 'cache._enable()' function is deprecated and no longer supported. "
+ "Please use 'cache.save()' instead"
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)

njit_funcs = get_njit_funcs()
for module_name, func_name in njit_funcs:
module = importlib.import_module(f".{module_name}", package="stumpy")
Expand All @@ -94,12 +105,29 @@ def _clear():
-------
None
"""
warnings.warn(CACHE_WARNING)
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]


def clear():
"""
Clear numba cache directory

Parameters
----------
None

Returns
-------
None
"""
warnings.warn(CACHE_WARNING)
_clear()

return


def _get_cache():
"""
Retrieve a list of cached numba functions
Expand All @@ -117,3 +145,69 @@ def _get_cache():
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]


def _recompile():
"""
Recompile all njit functions

Parameters
----------
None

Returns
-------
None

Notes
-----
If the `numba` cache is enabled, this results in saving (and/or overwriting)
the cached numba functions to disk.
"""
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved
for module_name, func_name in get_njit_funcs():
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
func.recompile()

return


def _save():
"""
Save all njit functions

Parameters
----------
None

Returns
-------
None
"""
_enable()
_recompile()

return


def save():
"""
Save/overwrite all the cache data files of
all-so-far compiled njit functions.

Parameters
----------
None

Returns
-------
None
"""
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved
if numba.config.DISABLE_JIT:
msg = "Could not save/cache function because NUMBA JIT is disabled"
warnings.warn(msg)
else:
warnings.warn(CACHE_WARNING)
_save()

return
77 changes: 67 additions & 10 deletions stumpy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,72 @@
# Copyright 2019 TD Ameritrade. Released under the terms of the 3-Clause BSD license.
# STUMPY is a trademark of TD Ameritrade IP Company, Inc. All rights reserved.

import warnings

import numpy as np

STUMPY_THREADS_PER_BLOCK = 512
STUMPY_MEAN_STD_NUM_CHUNKS = 1
STUMPY_MEAN_STD_MAX_ITER = 10
STUMPY_DENOM_THRESHOLD = 1e-14
STUMPY_STDDEV_THRESHOLD = 1e-7
STUMPY_P_NORM_THRESHOLD = 1e-14
STUMPY_TEST_PRECISION = 5
STUMPY_MAX_P_NORM_DISTANCE = np.finfo(np.float64).max
STUMPY_MAX_DISTANCE = np.sqrt(STUMPY_MAX_P_NORM_DISTANCE)
STUMPY_EXCL_ZONE_DENOM = 4
_STUMPY_DEFAULTS = {
"STUMPY_THREADS_PER_BLOCK": 512,
"STUMPY_MEAN_STD_NUM_CHUNKS": 1,
"STUMPY_MEAN_STD_MAX_ITER": 10,
"STUMPY_DENOM_THRESHOLD": 1e-14,
"STUMPY_STDDEV_THRESHOLD": 1e-7,
"STUMPY_P_NORM_THRESHOLD": 1e-14,
"STUMPY_TEST_PRECISION": 5,
"STUMPY_MAX_P_NORM_DISTANCE": np.finfo(np.float64).max,
"STUMPY_MAX_DISTANCE": np.sqrt(np.finfo(np.float64).max),
"STUMPY_EXCL_ZONE_DENOM": 4,
"STUMPY_FASTMATH_TRUE": True,
"STUMPY_FASTMATH_FLAGS": {"nsz", "arcp", "contract", "afn", "reassoc"},
}

# In addition to these configuration variables, there exist config variables
# that have the default value of the fastmath flag of the njit functions. The
# name of this config variable has the following format:
# STUMPY_FASTMATH_<module_name>.<function_name>
# See __init__.py for more details

STUMPY_THREADS_PER_BLOCK = _STUMPY_DEFAULTS["STUMPY_THREADS_PER_BLOCK"]
STUMPY_MEAN_STD_NUM_CHUNKS = _STUMPY_DEFAULTS["STUMPY_MEAN_STD_NUM_CHUNKS"]
STUMPY_MEAN_STD_MAX_ITER = _STUMPY_DEFAULTS["STUMPY_MEAN_STD_MAX_ITER"]
STUMPY_DENOM_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_DENOM_THRESHOLD"]
STUMPY_STDDEV_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_STDDEV_THRESHOLD"]
STUMPY_P_NORM_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_P_NORM_THRESHOLD"]
STUMPY_TEST_PRECISION = _STUMPY_DEFAULTS["STUMPY_TEST_PRECISION"]
STUMPY_MAX_P_NORM_DISTANCE = _STUMPY_DEFAULTS["STUMPY_MAX_P_NORM_DISTANCE"]
STUMPY_MAX_DISTANCE = _STUMPY_DEFAULTS["STUMPY_MAX_DISTANCE"]
STUMPY_EXCL_ZONE_DENOM = _STUMPY_DEFAULTS["STUMPY_EXCL_ZONE_DENOM"]
STUMPY_FASTMATH_TRUE = _STUMPY_DEFAULTS["STUMPY_FASTMATH_TRUE"]
STUMPY_FASTMATH_FLAGS = _STUMPY_DEFAULTS["STUMPY_FASTMATH_FLAGS"]
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved


def _reset(var=None):
"""
Reset the value of a configuration variable(s) to their default value(s)

Parameters
----------
var : str, default None
The name of the configuration variable. If None, then all
configuration variables are reset to their default values.

Returns
-------
None
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved
"""
config_vars = [
k for k, _ in globals().items() if k.isupper() and k.startswith("STUMPY")
]

if var is None:
for config_var in config_vars:
globals()[config_var] = _STUMPY_DEFAULTS[config_var]
elif var in config_vars:
globals()[var] = _STUMPY_DEFAULTS[var]
else: # pragma: no cover
msg = (
f"Configuration reset was skipped for unrecognized '_STUMPY_DEFAULT[{var}]'"
)
warnings.warn(msg)

return
Loading
Loading