
Fix fastmath precision issue #1048

Merged · 69 commits merged into TDAmeritrade:main from investigate_precision_failure on Jan 23, 2025

Conversation

@NimaSarajpoor (Collaborator)

See: stumpy-dev/automate#3
See MWE: stumpy-dev/automate#3 (comment)

The tests are passing for both numba versions 0.60 and 0.61.0rc1.

@NimaSarajpoor (Collaborator Author) commented Dec 5, 2024

@seanlaw
I'm not sure what the better approach is here. Should I have pushed an empty commit first, just to trigger the GitHub Actions in this repo and see the error? I also noticed that version 0.61.0rc1 is not being installed in the GitHub Actions, so I suspect I am doing something incorrectly.

codecov bot commented Dec 5, 2024

Codecov Report

Attention: Patch coverage is 55.22388% with 60 lines in your changes missing coverage. Please review.

Project coverage is 96.70%. Comparing base (70e4e70) to head (f5186a2).
Report is 3 commits behind head on main.

Files with missing lines    Patch %   Lines
stumpy/cache.py             21.42%    22 Missing ⚠️
tests/test_fastmath.py      26.08%    17 Missing ⚠️
tests/test_cache.py         35.29%    11 Missing ⚠️
stumpy/fastmath.py          33.33%    10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1048      +/-   ##
==========================================
- Coverage   96.97%   96.70%   -0.28%     
==========================================
  Files          90       93       +3     
  Lines       15104    15214     +110     
==========================================
+ Hits        14647    14712      +65     
- Misses        457      502      +45     


@NimaSarajpoor (Collaborator Author)

@seanlaw
Please ignore my previous comment.

Can you please check out the change and merge the branch if it is all good? I tested the change locally and all unit tests pass with Numba==0.61.0rc1. To be on the safe side, I also ran the unit tests in Colab to check the GPU-required test functions as well.

@seanlaw (Contributor) commented Dec 6, 2024

@NimaSarajpoor As I think about it more, I am starting to question whether turning off reassoc is a good idea because we're getting divergence in how we use fastmath. Sometimes we have:

  1. fastmath=True
  2. fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}

and now we are adding a third variant, fastmath={"nsz", "arcp", "contract", "afn"}. This is bothersome to maintain and smells fishy IMHO. I feel like our North Star should probably be to match our fast implementations with our naive implementations (i.e., what would a user reasonably expect when they try to calculate things themselves using numpy or base Python). What do you think?

@seanlaw (Contributor) commented Dec 6, 2024

A small part of me is questioning whether we should/could update the unit test but I want to avoid it as it may hide other problems. This is a tough one!

@NimaSarajpoor (Collaborator Author) commented Dec 7, 2024

And now we are adding a third variant fastmath={"nsz", "arcp", "contract", "afn"}. This is bothersome to maintain and smells fishy IMHO

Correct. What if we remove reassoc entirely whenever we use fastmath, i.e., always use fastmath={"nsz", "arcp", "contract", "afn"}?

We can measure how much of a performance hit that causes. I was looking at the description of the flags; while other flags may also affect floating-point values, reassoc is the only one whose description explicitly says:

This may dramatically change results in floating-point.

Since STUMPY strives for good precision, I was wondering if we should just remove the reassoc flag from the codebase entirely.
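As a quick plain-Python illustration of why reassociation is the risky one (a minimal sketch; no numba needed and the values are arbitrary): floating-point addition is not associative, so letting the compiler regroup a sum can change the result.

# (a + b) + c versus a + (b + c): reassociation changes the answer
a, b, c = 1e16, -1e16, 1.0
print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0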


A small part of me is questioning whether we should/could update the unit test

This triggered me to look into STUMPY's unit tests. I noticed we use stumpy.config.STUMPY_TEST_PRECISION when we compare values like the matrix profile against the ones obtained via the naive approach. I also noticed that the current issue in snippets comes from a slight loss of precision in the mpdist value between two subsequences. IIRC, the mpdist between two subsequences A and B is computed based on the AB-join and BA-join matrix profiles. So, the values of the numpy array D computed here

stumpy/stumpy/snippets.py

Lines 288 to 296 in 3165d1c

D = _get_all_profiles(
    T,
    m,
    percentage=percentage,
    s=s,
    mpdist_percentage=mpdist_percentage,
    mpdist_k=mpdist_k,
    mpdist_T_subseq_isconstant=mpdist_T_subseq_isconstant,
)

should only be trusted up to stumpy.config.STUMPY_TEST_PRECISION decimal places.

So, we can create a private function _snippet that takes D as input. Then, in unit testing, we can assert D as well as the output of _snippet. I am not sure, though, how to test the function snippets as a whole.

(Btw, this approach sounds weird to me!)

@seanlaw (Contributor) commented Dec 8, 2024

We can try to measure how much performance is hit by that. I was looking at the description of flags. While there might be other flags that may affect a floating-point value, reassoc is the only one whose description clearly says:
Since STUMPY tries to have a good precision, I was wondering if we should just remove the flag reassoc entirely from codebase.

Yes, I think we should look at the performance and see how much of a hit we take. If it's less than a 5% difference, then it might be a good idea to turn it off. We might be able to give power users the option of making things faster by setting a global configuration variable.

(Btw, this approach sounds weird to me!)

Agreed. Let's check the performance first and at least be armed with data to make a decision

@NimaSarajpoor (Collaborator Author) commented Dec 10, 2024

Is fastmath=True equivalent to fastmath={"nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc"}?


I measured the impact of removing reassoc on the running time of stumpy.stump. Currently, there are two njit-decorated functions in that module, both with fastmath=True. I considered three different versions of stumpy.stump:

(1) fastmath is True: Used the current version of STUMPY, i.e. v1.13.0

(2) fastmath 7 flags: Replaced fastmath=True with fastmath={"nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc"}

(3) fastmath 6 flags: Replaced fastmath=True with fastmath={"nnan", "ninf", "nsz", "arcp", "contract", "afn"}

It turns out that there is a considerable impact on the running time when we switch from (1) to (2).

[figure: fastmath_impact]

[figure: fastmath_impact_ratio]


# code for measuring running time

import sys
import time

import stumpy
import numpy as np

seed = 0
np.random.seed(seed)


N = 10
n_values = 10000 * np.arange(1, N + 1)
m = 50

n_iter = 5

# dummy run
stumpy.stump(np.random.rand(100), m)

out = np.full(len(n_values), -1.0, dtype=np.float64)
for i, n in enumerate(n_values):
    print(f'{i + 1}/{len(n_values)}')

    T = np.random.rand(n)
    
    lst = []
    for _ in range(n_iter):
        tic = time.time() 
        stumpy.stump(T, m)     
        toc = time.time()
        lst.append(toc - tic)
   
    out[i] = np.median(lst)

name = sys.argv[1]
np.save(f"{name}.npy", out)
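(For reference, the script takes the output name as a command-line argument, so each variant can be timed in a separate process, e.g. python measure.py with_reassoc, where measure.py is a placeholder name for the script above, and the saved .npy files can be compared afterwards.)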

@seanlaw (Contributor) commented Dec 11, 2024

According to numba's fastmath docs:

In certain classes of applications strict IEEE 754 compliance is less important. As a result it is possible to relax some numerical rigour with view of gaining additional performance.

@NimaSarajpoor I'm struggling a little to think through what our goal is so maybe you can help me. It's clear that ANY use of fastmath flags will:

  1. Make STUMPY less compliant with strict IEEE 754
  2. Make STUMPY faster

reassoc
Allow reassociation transformations for floating-point instructions. This may dramatically change results in floating-point.

Additionally, the use of reassoc may affect the FP results. These are the basic facts.

So, if I think about it from a user's perspective, 99.9% of users won't care about the imprecision/IEEE 754 compliance as they likely care about performance way more. However, for the 0.1% of users that care, we should have a way to allow them to easily set/unset custom fastmath flags. Again, just thinking out loud here. This same mechanism should allow us to set/unset custom fastmath flags for precision unit-testing purposes too!

In other words, I think the default fastmath flags should ALWAYS be (where possible):

  1. fastmath=True
  2. fastmath={"nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc"}
  3. fastmath=SOME_CUSTOM_FLAGS

However, the issue is that we currently have a mixture of fastmath=True and fastmath={"nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc"}. Some functions use the first flag and others use the second (verbose) set of flags. So, it's not straightforward to, say, temporarily control both of them (in the case of our unit test) and then be able to switch back to the default flags that correspond to each function. If we can solve that gracefully, then I think we are moving toward the "right" solution. My gut tells me that there has to be a nice simple way to do this...

What do you think @NimaSarajpoor?

@seanlaw (Contributor) commented Dec 11, 2024

One possible solution would be to add default configurations to the config.py module. For example

# config.py
import numpy as np
import warnings

_STUMPY_DEFAULTS = {
    "STUMPY_THREADS_PER_BLOCK": 512,
    "STUMPY_MEAN_STD_NUM_CHUNKS": 1,
    "STUMPY_MEAN_STD_MAX_ITER": 10,
    "STUMPY_DENOM_THRESHOLD": 1e-14,
    "STUMPY_STDDEV_THRESHOLD": 1e-7,
    "STUMPY_P_NORM_THRESHOLD": 1e-14,
    "STUMPY_TEST_PRECISION": 5,
    "STUMPY_MAX_P_NORM_DISTANCE": np.finfo(np.float64).max,
    "STUMPY_MAX_DISTANCE": np.sqrt(np.finfo(np.float64).max),
    "STUMPY_EXCL_ZONE_DENOM": 4,
    "STUMPY_FASTMATH": True,
    "STUMPY_FASTMATH_FLAGS": {"nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc"},
}

STUMPY_THREADS_PER_BLOCK = _STUMPY_DEFAULTS["STUMPY_THREADS_PER_BLOCK"]
STUMPY_MEAN_STD_NUM_CHUNKS = _STUMPY_DEFAULTS["STUMPY_MEAN_STD_NUM_CHUNKS"]
STUMPY_MEAN_STD_MAX_ITER = _STUMPY_DEFAULTS["STUMPY_MEAN_STD_MAX_ITER"]
STUMPY_DENOM_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_DENOM_THRESHOLD"]
STUMPY_STDDEV_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_STDDEV_THRESHOLD"]
STUMPY_P_NORM_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_P_NORM_THRESHOLD"]
STUMPY_TEST_PRECISION = _STUMPY_DEFAULTS["STUMPY_TEST_PRECISION"]
STUMPY_MAX_P_NORM_DISTANCE = _STUMPY_DEFAULTS["STUMPY_MAX_P_NORM_DISTANCE"]
STUMPY_MAX_DISTANCE = _STUMPY_DEFAULTS["STUMPY_MAX_DISTANCE"]
STUMPY_EXCL_ZONE_DENOM = _STUMPY_DEFAULTS["STUMPY_EXCL_ZONE_DENOM"]
STUMPY_FASTMATH = _STUMPY_DEFAULTS["STUMPY_FASTMATH"]
STUMPY_FASTMATH_FLAGS = _STUMPY_DEFAULTS["STUMPY_FASTMATH_FLAGS"]


def _reset(var=None):
    config_vars = [
            k
            for k, v in globals().items()
            if k.isupper() and k.startswith("STUMPY")
    ]

    if var is None:
        # Reset all config variables back to default values
        for var in config_vars:
            globals()[var] = _STUMPY_DEFAULTS[var]
    else:
        if var in config_vars:
            globals()[var] = _STUMPY_DEFAULTS[var]
        else:
            msg = f"Could not reset unrecognized configuration variable \"{var}\""
            warnings.warn(msg)

Then, when you do something like:

import config

config.STUMPY_EXCL_ZONE_DENOM = 17
print(config.STUMPY_EXCL_ZONE_DENOM)  # Prints "17"
config._reset()
print(config.STUMPY_EXCL_ZONE_DENOM)  # Prints "4"

config.STUMPY_EXCL_ZONE_DENOM = 17
print(config.STUMPY_EXCL_ZONE_DENOM)  # Prints "17"
config._reset("STUMPY_EXCL_ZONE")  # Causes warning due to incorrect spelling
# /Users/slaw/config.py:53: UserWarning: Could not reset unrecognized configuration variable "STUMPY_EXCL_ZONE"
#   warnings.warn(msg)
config._reset("STUMPY_EXCL_ZONE_DENOM")
print(config.STUMPY_EXCL_ZONE_DENOM)  # Prints "4"

Note:

  1. I have not tested this thoroughly (or in the context of the full STUMPY package)
  2. Messing with globals() is usually a very bad idea and so I hope there is a "safer" approach
  3. Then, for our snippets precision unit test, we can temporarily change config.STUMPY_FASTMATH_FLAGS = {"nnan", "ninf", "nsz", "arcp", "contract", "afn"} (no reassoc), perform the test, and then config._reset(). Though, I don't know if that is what we want to do OR if we should simply change our unit test cmp values. Maybe we add another test for with/without reassoc? I feel like this reassoc thing is going to come back in the future... Of course, the performance difference between having reassoc present vs. absent doesn't seem significant (this compares the two verbose flag sets, not fastmath=True).

I'd like to hear your thoughts on this @NimaSarajpoor.

@NimaSarajpoor (Collaborator Author)

I feel like our North Star should probably be to match our fast implementation with our naive implementations (i.e., what would a user reasonably expect when they try to calculate things themselves using numpy or base Python)? What do you think?

So, if I think about it from a user's perspective, 99.9% of users won't care about the imprecision/IEEE 754 compliance as they likely care about performance way more. However, for the 0.1% of users that care, we should have a way to allow them to easily set/unset custom fastmath flags.

Then, for our snippets precision unit test, we can temporarily change config.STUMPY_FASTMATH_FLAGS = {"nnan", "ninf", "nsz", "arcp", "contract", "afn"} (no reassoc), perform the test, and then config._reset(). Though, I don't know if that is what we want to do OR to simply change our unit test cmp values?

Thanks for the detailed explanation! I agree for the most part, and was wondering if such manipulation could happen in stumpy.snippets rather than in its unit tests. I totally agree that most users do not care about the slight loss of precision. Therefore: it is better to use fastmath=True and, if that is not possible, to use fastmath={"nsz", "arcp", "contract", "afn", "reassoc"} as the next best option, which is what STUMPY does. (Not sure, but I believe this is what you meant.) And use some custom flags if needed.

Now here is the tricky part. On one hand, we understand that there can be a slight loss of precision and, logically speaking, we should not compare the results with the naive version. On the other hand, we cannot just manipulate the flags in the unit tests since, at the end of the day, users run code that has those flags by default. This is like a paradox. We currently ASSUME that the loss of precision is so small that it does not affect the unit tests, and we keep the flags enabled during unit testing, which is what STUMPY does. Now we should ACCEPT that this ASSUMPTION can sometimes be wrong, as in the snippets case. So:

  • Option 1: Revise snippets? (change the flag at the beginning, and reset it at the end)
  • Option 2: Revise the snippets test functions?

I was wondering if we should go with Option 1 because the goal should be to have a unit test that checks the reliability of the code that users actually run. But I think that if the flag is changed via config within snippets, it will affect all other processes happening at that time. Correct? If yes, then I think we need to go with Option 2, which is related to what you mentioned in your previous comment.


@seanlaw (Contributor) commented Dec 13, 2024

This is like a paradox.

Yes, and it's annoying as hell! Maybe we can do something like

import pytest
import numpy as np
import numpy.testing as npt

import naive  # the naive reference implementations in tests/
from stumpy import config

def test_my_function():
    cmp_with_reassoc = my_function()
    config.STUMPY_FASTMATH_FLAGS = {"nnan", "ninf", "nsz", "arcp", "contract", "afn"}
    cmp_no_reassoc = my_function()
    config._reset("STUMPY_FASTMATH_FLAGS")

    ref = naive.my_function()

    if np.isclose(ref, cmp_with_reassoc):
        npt.assert_almost_equal(ref, cmp_with_reassoc)
    else:
        npt.assert_almost_equal(ref, cmp_no_reassoc)

It's a little bit of cheating in that we "compare" the computed values BUT we don't assert until we know they match. Thoughts?

What I don't know is whether or not my_function will actually get recompiled as a result of the flags changing.

@NimaSarajpoor (Collaborator Author)

It's a little bit of cheating in that we "compare" the computed values BUT we don't assert until we know they match. Thoughts?

This is a nice idea! This is like trying the next best assertion.

What I don't know is whether or not my_function will actually get recompiled as a result of the flags changing.

I tried it and it failed. I am reading some relevant discussions on the numba Discourse. One suggested solution was to take the original python function and create a new njit function from it:

njit(fastmath={"nsz", "arcp", "contract", "afn"})(core._calculate_squared_distance.py_func)

However, this does not solve our problem because, I think, core._calculate_squared_distance is a callee, and the callers need to be recompiled as well. Now I am trying to find a way to "(re)compile all".

As an aside, there is one hacky solution here, but I prefer to avoid it for now as it creates code branching.

@seanlaw (Contributor) commented Dec 14, 2024

@NimaSarajpoor What about clearing the cache via cache._clear()? You'd need to call it again at the end after calling config._reset(). I wonder if that would work?

@NimaSarajpoor (Collaborator Author)

@seanlaw
cache._clear did not work.

import stumpy
from stumpy import cache, config

config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}

cache._clear()
... = stumpy.snippets(...)

Going to revert my changes and check again to make sure I did not miss anything.

@NimaSarajpoor (Collaborator Author)

@seanlaw
I created a simpler test function to see if cache._clear() can help us with (re)compiling. Strangely it didn't.

Note: core._calculate_squared_distance is revised to have the decorator @njit(fastmath=config.STUMPY_FASTMATH_FLAGS)

def test_temp():
    s = 3
    QT = 1453258.774722079
    μ_Q = -55.35781464461123
    σ_Q = 650.912209452633
    μ_T = -264.9540642227162
    σ_T = 722.0717285148526
    isconstant_Q = False
    isconstant_T = False

    ref = core._calculate_squared_distance(
        s, QT, μ_Q, σ_Q, μ_T, σ_T, isconstant_Q, isconstant_T
    )

    comp = core._calculate_squared_distance(
        s, QT, μ_T, σ_T, μ_Q, σ_Q, isconstant_T, isconstant_Q
    )

    assert ref - comp == 0.0

This assertion fails. And this test:

....
def test_temp():
  ...
  ref = core._calculate_squared_distance(...)
  
  config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}  
  cache._clear()
  
  comp = core._calculate_squared_distance(...)
  assert ref - comp == 0.0

also fails.

@seanlaw (Contributor) commented Dec 14, 2024

Note: core._calculate_squared_distance is revised to have the decorator @njit(fastmath=config.STUMPY_FASTMATH_FLAGS)

Great!

I noticed that each numba decorated function has an attribute called targetoptions and a method called recompile()

from stumpy.stump import _stump
 
print(_stump.targetoptions)
# {'parallel': True, 'fastmath': True, 'nopython': True, 'boundscheck': None}

Not sure if this is helpful.

@NimaSarajpoor (Collaborator Author) commented Dec 15, 2024

@seanlaw
Thanks for sharing!! I got some interesting results!

(1) .recompile() helps when it comes to testing core._calculate_squared_distance. For instance, the assertion in the following FAILs:

s = 3
QT = 1453258.774722079
μ_Q = -55.35781464461123
σ_Q = 650.912209452633
μ_T = -264.9540642227162
σ_T = 722.0717285148526
isconstant_Q = False
isconstant_T = False

args = (s, QT, μ_Q, σ_Q, μ_T, σ_T, isconstant_Q, isconstant_T)
args_reversed = (s, QT, μ_T, σ_T, μ_Q, σ_Q, isconstant_T, isconstant_Q)

ref = core._calculate_squared_distance(*args)
comp = core._calculate_squared_distance(*args_reversed)

assert ref == comp

Now, if I ADD the following two lines just before computing ref and comp, the assertion will PASS.

core._calculate_squared_distance.targetoptions["fastmath"] = {"nsz", "arcp", "contract", "afn"}
core._calculate_squared_distance.recompile()

(2) Unfortunately, adding those two lines to test_precision.py::test_snippets is NOT enough to resolve the issue. It turns out that all of the njit-decorated caller functions need to be recompiled as well. So, the following code seems to work for me!

# in test_precision.py::test_snippets

...

core._calculate_squared_distance.targetoptions["fastmath"] = {
    "nsz",
    "arcp",
    "contract",
    "afn",
}
core._calculate_squared_distance.recompile()
core._calculate_squared_distance_profile.recompile()
core.calculate_distance_profile.recompile()
core._mass.recompile()
...
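To keep the rest of the test suite unaffected, the same steps presumably need to be undone at the end of the test. A minimal sketch, assuming the default flag set lives in config.STUMPY_FASTMATH_FLAGS:

# Restore the default fastmath flags and recompile again so that
# subsequent tests see the original (reassoc-enabled) behavior
core._calculate_squared_distance.targetoptions["fastmath"] = config.STUMPY_FASTMATH_FLAGS
core._calculate_squared_distance.recompile()
core._calculate_squared_distance_profile.recompile()
core.calculate_distance_profile.recompile()
core._mass.recompile()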

@seanlaw (Contributor) commented Dec 15, 2024

@NimaSarajpoor In case it may be helpful, you could possibly use this function to help recompile everything:

print(cache.get_njit_funcs())

[('aamp', '_compute_diagonal'),
('aamp', '_aamp'),
('core', '_sliding_dot_product'),
('core', '_welford_nanvar'),
('core', '_rolling_nanstd_1d'),
('core', '_calculate_squared_distance'),
('core', '_calculate_squared_distance_profile'),
('core', 'calculate_distance_profile'),
('core', '_p_norm_distance_profile'),
('core', '_mass'),
('core', '_apply_exclusion_zone'),
('core', '_count_diagonal_ndist'),
('core', '_get_array_ranges'),
('core', '_get_ranges'),
('core', '_rolling_isconstant'),
('core', '_total_diagonal_ndists'),
('core', '_merge_topk_PI'),
('core', '_merge_topk_ρI'),
('core', '_shift_insert_at_index'),
('core', '_compute_multi_PI'),
('maamp', '_compute_multi_p_norm'),
('mstump', '_compute_multi_D'),
('scraamp', '_compute_PI'),
('scraamp', '_prescraamp'),
('scrump', '_compute_PI'),
('scrump', '_prescrump'),
('stump', '_compute_diagonal'),
('stump', '_stump')]

@seanlaw (Contributor) commented Dec 15, 2024

cache._clear did not work.

It's really strange. I did some tests:

import time

import numpy as np

import stumpy
from stumpy import cache

Q = np.random.rand(50)
T = np.random.rand(1000)
m = 50

cache._enable()

start = time.time()
distance_profile = stumpy.mass(Q, T)
print(time.time() - start)

print(cache._get_cache())

cache._clear()

print(cache._get_cache())

start = time.time()
distance_profile = stumpy.mass(Q, T)
print(time.time() - start)

Somehow, the final stumpy.mass call is still being cached somewhere (note that I am using time.time as a proxy for whether or not a function is being cached, and it may not be a good idea). Initially, I thought that it might be because I was using a Jupyter notebook, so I moved it into a test.py file, and it was still being cached somewhere. So, things are not so simple.

Even if I try to force numba to flush the cache for each njit function via:

def _clear():
    """
    Clear numba cache

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    warnings.warn(CACHE_WARNING)
    site_pkg_dir = site.getsitepackages()[0]
    numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
    [f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
    njit_funcs = get_njit_funcs()
    for module_name, func_name in njit_funcs:
        module = importlib.import_module(f".{module_name}", package="stumpy")
        func = getattr(module, func_name)
        func._cache.flush()

Note that func._cache.flush() does change the behavior and we can see that the .nbi/.nbc numba cache files are being regenerated in the NUMBA_CACHE_DIR (in site-packages). However, the timing is still "too fast". It's possible that the cache is being cleared but using "time.time" to gauge whether anything is being cached may be wrong?

Let's not worry about cache for now and go the .recompile() route as it is showing promise.

@NimaSarajpoor (Collaborator Author) commented Dec 16, 2024

@NimaSarajpoor In case it may be helpful, you could possibly use this function to help recompile everything:
print(cache.get_njit_funcs())

Recompiling all njit functions seems to take a negligible amount of time. I think it is safer to just recompile all functions. For instance, if something is changed in the future, we know that this part will still be okay because we simply recompile everything. I need to think about whether there is a more elegant way to do it, but I tried this and it worked:

for module, njit_func in cache.get_njit_funcs():
    code = f"from stumpy.{module} import {njit_func}; {njit_func}.recompile()"
    exec(code)

Going to follow these steps:
(1) Add config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn", "reassoc"} to config
(2) Replace hardcoded fastmath flags with config.STUMPY_FASTMATH_FLAGS
(3) Change the value of this config param in tests/test_precision.py::test_snippets, and recompile all njit-decorated functions. Change the param's value back to its default, and recompile all njit-decorated functions again.

We can stop here. Or, we can continue and add the changes that @seanlaw suggested here:
#1048 (comment)

(4) Revise config and refactor (3). As suggested, we can add a new function for resetting values (and recompiling everything). Similarly, we can add a function that sets config vars to new values and recompiles everything:

from stumpy import cache  # assumed import so that cache.get_njit_funcs() below resolves


def _get_config_vars():
    """
    Get config variables.

    Parameters
    ----------
    None

    Returns
    -------
    config_vars : list
        A list of config variables
    """
    config_vars = [
        k for k, _ in globals().items() if k.isupper() and k.startswith("STUMPY")
    ]

    return config_vars


def _set(vars):
    """
    Set config variables to new values.

    Parameters
    ----------
    vars : dict
        A dictionary where key is param name, and value is the desired value.

    Returns
    -------
    None
    """
    config_vars = _get_config_vars()

    # revise config values
    s = set(vars.keys())
    if not s.issubset(config_vars):
        extra_vars = s - set(config_vars)
        msg = f"Found invalid config variables: {extra_vars}"
        raise ValueError(msg)

    # set values
    for param, val in vars.items():
        globals()[param] = val

    # recompile everything
    for module, njit_func in cache.get_njit_funcs():
        code = f"from stumpy.{module} import {njit_func}; {njit_func}.recompile()"
        exec(code)

    return
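For illustration, the snippets precision test could then read roughly as follows (hypothetical usage; note that _set, as sketched above, recompiles all njit functions after updating the values):

# temporarily drop "reassoc"; _set recompiles all njit functions
config._set({"STUMPY_FASTMATH_FLAGS": {"nsz", "arcp", "contract", "afn"}})
comp = stumpy.snippets(...)

# restore the default flags (recompiles everything once more)
config._set({"STUMPY_FASTMATH_FLAGS": config._STUMPY_DEFAULTS["STUMPY_FASTMATH_FLAGS"]})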

@seanlaw
Any suggestions regarding the steps above?

@seanlaw (Contributor) commented Dec 16, 2024

Recompiling all njit functions seems to take a negligible time. I think it is safer to just recompile all functions. For instance, in the future, if something is changed, we know that this part will be still okay because we just recompile everything. Need to think if there is a more elegant way to do it but I tried this and it worked:

The cache._enable() uses:

stumpy/stumpy/cache.py

Lines 68 to 72 in 3165d1c

njit_funcs = get_njit_funcs()
for module_name, func_name in njit_funcs:
    module = importlib.import_module(f".{module_name}", package="stumpy")
    func = getattr(module, func_name)
    func.enable_caching()

So, we should probably reuse/refactor/share some of the code by doing:

warnings.warn(CACHE_WARNING)
njit_funcs = get_njit_funcs()
for module_name, func_name in njit_funcs:
    module = importlib.import_module(f".{module_name}", package="stumpy")
    func = getattr(module, func_name)
    func.recompile()

Using exec is usually dangerous and encourages bad actors to hack our code.

Going to follow these steps:
(1) Add config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn", "reassoc"} to config
(2) Replace hardcoded fastmath flags with config.STUMPY_FASTMATH_FLAGS
(3) Change the value of this config param in tests/test_precision.py::test_snippets, and recompile all njit-decorated functions. Change the param's value back to its default, and recompile all njit-decorated functions again.

We can stop here. Or, we can continue and add the changes that @seanlaw suggested here:
#1048 (comment)

(4) Revise config and refactor (3). As suggested, we can add new function for reseting values (and recompile everything). Similarly, we can add a function that set config vars to new values, and recompile everything.

Personally, I think that following my previous comment is my preference because it really anchors (and pushes the user) back to a clearly defined set of default values.

More importantly, I think that the .recompile() step should be kept separate from, or outside of, the config.py file. I actually think that it should go in cache._recompile() but I am open to discussing it further. Normally, in these situations, I like to think about WHAT the unit test code will look like (i.e., how it will change) and HOW the user will express themselves in their code. In this case, there is the existing unit test code (that doesn't quite work). Then, we need to add the if/else np.isclose checks, followed by changing config.STUMPY_FASTMATH_FLAGS and recompiling, and then seeing if the tests pass. How would this code look? Would it be obvious what is happening and why we need to do it? Considering that this is a rare situation, is the code too verbose, overly cumbersome, or "just painful enough" to dissuade users from using config.py or cache.py? Once you see the unit test (and how ugly/clean it is), I think you'll better understand what the config.py/cache.py API needs to look like to enable the work. So I would suggest writing additional parts of the unit test so that it feels "natural" and that will help mould the API. How does that sound?
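For concreteness, a rough sketch of how such a test might read, assuming config.STUMPY_FASTMATH_FLAGS plus a (still hypothetical) cache._recompile() helper; all names here are provisional:

def test_snippets():
    ref = naive.snippets(...)                # naive reference
    cmp_with_reassoc = stumpy.snippets(...)

    if np.allclose(ref, cmp_with_reassoc):
        npt.assert_almost_equal(ref, cmp_with_reassoc)
    else:
        # drop "reassoc", recompile everything, and retry
        config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
        cache._recompile()
        cmp_no_reassoc = stumpy.snippets(...)

        # restore the defaults and recompile again for subsequent tests
        config._reset("STUMPY_FASTMATH_FLAGS")
        cache._recompile()
        npt.assert_almost_equal(ref, cmp_no_reassoc)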

Recompiling all njit functions seems to take a negligible time.

Hmm, I am really questioning WHY the initial/first compilation takes SO long but recompilation is SO fast?!

@NimaSarajpoor (Collaborator Author)

Using exec is usually dangerous and encourages bad actors to hack our code.

👍

Once you see the unit test (and how ugly/clean it is), I think you'll better understand what the config.py/cache.py API needs to look like to enable the work. So I would suggest with writing additional parts of the unit test so that it feels "natural" and that will help mould the API. How does that sound?

In fact, I did change it locally and noticed that the test function looked dirty, which is why I created the function _set(vars) in config.py. I get your point. Let me make some changes and push them here, and then we can move on from there as you suggested.

Hmm, I am really questioning WHY the initial/first compilation takes SO long but recompilation is SO fast?!

I am curious to know the answer too! I created a small njit-decorated function func and got the func.metadata() and looked at the pipeline times, but that did not help me find the answer!

@NimaSarajpoor (Collaborator Author) commented Dec 18, 2024

@seanlaw
I tried a trick provided in this Numba Discourse comment to take a closer look at the compilation.

import numba
from stumpy import core

oldcomp = numba.core.registry.CPUDispatcher.compile                                                                                                  

def newcomp(*args, **kwargs): 
    print('arges: ', args)
    print('kwargs: ', kwargs) 
    return oldcomp(*args, **kwargs) 
                                                                                                                                                      
numba.core.registry.CPUDispatcher.compile = newcomp


s = 3
QT = 1453258.774722079
μ_Q = -55.35781464461123
σ_Q = 650.912209452633
μ_T = -264.9540642227162
σ_T = 722.0717285148526
isconstant_Q = False
isconstant_T = False


ref = core._calculate_squared_distance(s, QT, μ_Q, σ_Q, μ_T, σ_T, isconstant_Q, isconstant_T)
print('=' * 50)
core._calculate_squared_distance.targetoptions["fastmath"] = {"nsz", "arcp", "contract", "afn"}
print('=' * 50)
core._calculate_squared_distance.recompile()
print('=' * 50)

And this gives me:

arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b2560>), (float32,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b2560>), (float32,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b2560>), (float32,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b3c70>), (float64,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b3c70>), (float64,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b3c70>), (float64,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b3760>), (complex64,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b3760>), (complex64,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b3760>), (complex64,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b0820>), (complex128,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b0820>), (complex128,))
kwargs:  {}
arges:  (CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b0820>), (complex128,))
kwargs:  {}
arges:  (CPUDispatcher(<function _calculate_squared_distance at 0x12a851480>), (int64, float64, float64, float64, float64, float64, bool, bool))
kwargs:  {}
==================================================
arges:  (CPUDispatcher(<function _calculate_squared_distance at 0x12a851480>), (int64, float64, float64, float64, float64, float64, bool, bool))
kwargs:  {}
==================================================

And I tried to go deeper by printing out args[0].__dict__ and I noticed the following
for CPUDispatcher(<function _loggamma_impl.<locals>.<lambda> at 0x12d0b2560>):

file ".../miniconda3/envs/py310/lib/python3.10/site-packages/rocket_fft/special.py line 88

but why do I see rocket_fft here? It is pointing to this line:

https://github.com/styfenschaer/rocket-fft/blob/a9890cd658d4ce6e99b8286e5f70d498adff2a84/rocket_fft/special.py#L88


[Update]
Not sure what was wrong in my env. Created a new conda env. Those extra compilations are now gone. Still see a difference in the running time though.

[Update2]

I checked the metadata and noticed that the sum of the times reported in the pipeline of the first compilation is almost the same as the recompile time. We can also see the compiler_lock time: in the first compilation, it is considerably higher than the sum of the pipeline times; in the recompile, however, it is close to the sum of the pipeline times.

@seanlaw (Contributor) commented Dec 20, 2024

@NimaSarajpoor Did you need anything from me?

@NimaSarajpoor (Collaborator Author)

@seanlaw
No, all good. Planning to work on it over the weekend.

@NimaSarajpoor (Collaborator Author) left a review comment


@seanlaw
I made some changes to address the recently-raised concerns. Also, we now have tests for the fastmath.py module, thanks to your suggestion. I added a couple of comments for now to seek your advice.

@seanlaw (Contributor) commented Jan 21, 2025

@NimaSarajpoor Do you think we're ready to merge? Everything looks good to me. After it has merged, I can trigger the "other" workflows (https://github.com/stumpy-dev/automate/actions) to see if things fail/pass

@NimaSarajpoor (Collaborator Author) commented Jan 21, 2025

@seanlaw
I took a quick look. I think all is good.

Just one question: do you think we should add a test that checks for hardcoded fastmath=False, fastmath=True, or fastmath={...}? Because they should all be in the fastmath=config.<CONFIG_VAR> format.

@seanlaw (Contributor) commented Jan 21, 2025

Just one question: do you think we should add a test that checks for hardcoded fastmath=False, fastmath=True, or fastmath={...}? Because they should all be in the fastmath=config.<CONFIG_VAR> format.

Yes, let's do it! Maybe we should include it inside of the call to fastmath.py --check stumpy? So, have that check EVERYTHING.

@NimaSarajpoor (Collaborator Author) left a review comment


@seanlaw
I added a new function to ./fastmath.py. I left some comments. Could you please take a look and let me know what you think?

@seanlaw (Contributor) commented Jan 22, 2025

I added a new function to ./fastmath.py. I left some comments. Could you please take a look and let me know what you think?

@NimaSarajpoor After staring at the code and trying to convince myself that "it's okay", I came to the conclusion that it's way too complicated. I tried to understand the approach but it's too convoluted (I came to that conclusion when you commented about using continue; it was an indication that this was the wrong approach). I say all of this because most of the work can be accomplished in a few lines of shell script code:

if [[ `grep -n fastmath stumpy/*py | grep -vE 'fastmath=config' | wc -l` -gt "0" ]]; then
    grep -n fastmath stumpy/*py | grep -vE 'fastmath=config'
    echo "Found one or more `@njit()` functions with a hardcoded `fastmath` flag."
    exit 1
fi
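In plain terms: list every line in stumpy/*py that mentions fastmath, filter out the ones already written as fastmath=config.<CONFIG_VAR>, and, if anything is left over, print the offending file/line numbers and exit with an error.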

For 5 lines of code, I think this is "good enough" as it:

  1. Returns the file/line number
  2. Errors out
  3. It's fairly easy to maintain and I can keep it all in my head without much mental effort

Apologies if I steered you in the wrong way and made it feel like adding it to fastmath.py was the only option. In my head, the shell script approach was simple and so I (foolishly) expected the Python code to be no more than 10 lines but, alas, I was wrong. Sorry about that!

@NimaSarajpoor (Collaborator Author)

I came to the conclusion that it's way too complicated. I tried to understand the approach but it's too convoluted (I came to that result when you commented about using continue and it was an indication that it was the wrong approach)

Thanks for the explanation! As I was adding more nested for-loops/if-blocks, I was smelling it but then tried to avoid it to some extent 😅

can be accomplished in a few lines of shell script code

Right.... The regex pattern is also simple and can be read easily.

Apologies if I steered you in the wrong way and made it feel like adding it to fastmath.py was the only option

All good, no apology needed! It actually helps me improve my thinking process!

@NimaSarajpoor force-pushed the investigate_precision_failure branch from dd5cc64 to 2369e33 on January 23, 2025
@NimaSarajpoor (Collaborator Author)

@seanlaw
I addressed the concerns. I left one comment for your consideration. If you are okay with that and other changes, please feel free to merge the PR.

@seanlaw merged commit 80dc3e9 into TDAmeritrade:main on Jan 23, 2025 (27 checks passed)
@seanlaw (Contributor) commented Jan 23, 2025

@NimaSarajpoor Thanks again for your time and effort here!

@seanlaw (Contributor) commented Jan 25, 2025

@NimaSarajpoor commit db1958a now enables coverage testing of cache.py. It's nothing special, but it works as the functions are actually called and "attempted". stumpy/fastmath.py and tests/test_fastmath.py are next!

@seanlaw (Contributor) commented Jan 26, 2025

Okay, we now have proper coverage for stumpy/fastmath.py and tests/test_fastmath.py in commit 92a2b41.
