
Fix: don't sample elements with weight 0 #983

Open
graidl wants to merge 13 commits into JuliaStats:master from graidl:fix-weighted-sampling

Conversation

@graidl commented Dec 15, 2025

This PR fixes an issue in sample(::AbstractRNG, ::AbstractWeights) as described in #982:
When using floating point weights, the last element could be chosen even if it has weight zero.
This could happen because sum(wv) is in general not exactly the value obtained by sequentially adding up all individual weights in wv, due to numerical imprecision.
This PR catches this corner case and returns the index of the last non-zero weight instead.
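
A minimal sketch of the idea behind the fix (illustrative only, not the exact code in this PR; the function name sample_weighted is hypothetical):

    # Sketch, assuming 1-based indexing: if rand() * sum(wv) ends up larger than
    # the running sum of all weights (possible with Float32 rounding), fall back
    # to the index of the last non-zero weight instead of returning the last index.
    function sample_weighted(rng, wv)
        t = rand(rng) * sum(wv)
        cw = zero(float(eltype(wv)))
        for i in eachindex(wv)
            cw += wv[i]
            cw > t && return i
        end
        return findlast(!iszero, wv)   # corner case: t exceeds the cumulative sum
    end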

@nalimilan (Member)

Good catch. Could you add tests?

@graidl (Author) commented Dec 16, 2025

Concerning tests: there are already several tests for weighted sampling, and in particular if we keep the latter solution, in which just sum(wv) is replaced, I don't see what additional tests would make sense.

src/sampling.jl Outdated
1 == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
wsum = sum(wv)
wsum = foldl(+, wv) # instead of sum(wv) for avoiding numerical discrepancies with cw
Member

sum(wv) simply retrieves the value that is stored in the AbstractWeights object. The problem with this solution is that it would go through the weights one additional time, making the function (more than) twice as slow AFAICT.
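
For context, a minimal illustration of that behavior (the two-argument Weights constructor also appears later in this PR):

    using StatsBase

    w = Weights([0.2, 0.3, 0.5])        # the total is computed once at construction and stored
    sum(w)                              # returns the stored total in O(1), no pass over the data
    w2 = Weights([1, 2, 0, 0], 10000)   # a (possibly wrong) total can also be supplied explicitly
    sum(w2)                             # 10000, regardless of the actual element sum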

Author

Right! I didn't know that. Then I suggest staying with the first version; I have now also included a test covering the former issue.

@nalimilan (Member)

If we decide on the version with an if, we need at least one test that goes through the new code under the if.

@nalimilan (Member)

Thanks. Unfortunately, according to Codecov the new branch isn't taken.

@graidl (Author) commented Dec 18, 2025

Indeed Codecov shows the new branch is not taken, thanks for pointing this out. This surprised me, as the branch is definitely taken when running the code in a normal Julia session with default parameters.

A more detailed analysis revealed the following:
Simply calculating the sum of the following Float32 vector yields different results in a normal Julia session and in the test environment:

w = Float32[0.0437019, 0.04302464, 0.039748967,  0.040406376, 0.042578973, 
    0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936, 
    0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]
@info sum(w)

This yields 0.66241294 when run inside a test, while 0.662413 is obtained in the REPL, and this makes all the difference!

Therefore, I tried to find a seed value that triggers the issue in the test environment and thereby covers the new branch. However, despite trying a huge number of possible values, I failed.
Thus, I can only conclude that the issue does not appear in the test environment because the sum calculation is done differently there.

I also tried to find out why the sum calculation differs between the test environment and the REPL, but failed to find the reason. The parameters with which Julia was started, as reported by Base.JLOptions(), are pretty much the same; in particular no fast_math is used and the optimization level is the same.

I am stuck now. If it really is that important to cover the new branch, I would need specific guidance on how to reproduce, in the test environment, the behavior of a normal Julia session with respect to summing a Float32 vector.

@nalimilan (Member)

Ah that's weird. FWIW I also get 0.66241294 here. Have you tried running Julia with --check-bounds=yes? This may disable some SIMD instructions, and if your CPU is newer than mine that may explain the difference.

@graidl (Author) commented Dec 18, 2025

It seems you are perfectly right! Starting Julia with --check-bounds=yes also gives 0.66241294, and the CPU I'm working on is a relatively new AMD EPYC 9274F that supports SIMD.
I also tried on an old server without SIMD support and without --check-bounds=yes, and got 0.66241294 there as well.

On the other hand, another rather new PC with an AMD Ryzen 9 9950X supporting SIMD, started without --check-bounds=yes, gives 0.662413 again.

So, what shall we do? I guess it is hardly possible to enforce the use of SIMD when running tests, if it is supported at all on the test server?

@nalimilan (Member)

There are lots of different SIMD instruction sets. The best we can do is prepare tests for the different values you encountered, and maybe print a warning when the sum isn't equal to one of these values so that in the future we are reminded to adjust the test.

@graidl (Author) commented Dec 18, 2025

Ok, I added a corresponding check and warning as well as a comment explaining the situation.

@nalimilan (Member)

Thanks. But this still doesn't cover the lines on GitHub CI AFAICT. Would you be able to find values which reproduce the problem there (with 0.66241294 IIUC)?

test/sampling.jl Outdated
# Without SIMD support, sum(w) == 0.66241294f0 and this test cannot check the
# resolution of the issue.
if sum(w) ∉ (0.662413f0, 0.66241294f0)
@warn "So far unrecognized value for sum(w) encountered."
Member
Suggested change
@warn "So far unrecognized value for sum(w) encountered."
@warn "So far unrecognized value for sum(w) encountered. " *
"Please update test so that it continues to cover special code path."

test/runtests.jl Outdated
Comment on lines 7 to 21
# "weights",
# "moments",
# "scalarstats",
# "deviation",
# "cov",
# "counts",
# "ranking",
# "empirical",
# "hist",
# "rankcorr",
# "signalcorr",
# "misc",
# "pairwise",
# "reliability",
# "robust",
Member

Suggested change
# "weights",
# "moments",
# "scalarstats",
# "deviation",
# "cov",
# "counts",
# "ranking",
# "empirical",
# "hist",
# "rankcorr",
# "signalcorr",
# "misc",
# "pairwise",
# "reliability",
# "robust",
"weights",
"moments",
"scalarstats",
"deviation",
"cov",
"counts",
"ranking",
"empirical",
"hist",
"rankcorr",
"signalcorr",
"misc",
"pairwise",
"reliability",
"robust",

@graidl (Author) commented Dec 18, 2025

No, as mentioned I was not able to find any case without SIMD support where the issue is triggered and thus the branch handling it executed, even though I checked a huge number of cases (i.e. seed values). It seems that without SIMD, the issue does not manifest.

@nalimilan (Member)

Ah got it. That makes sense actually. Have you tried with larger inputs? sum uses pairwise summation with blocks of 16 values so below that it's just equivalent to a simple loop in the absence of SIMD.
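
For illustration, a small sketch of why the two can diverge for longer Float32 inputs (the values here are illustrative; any difference depends on hardware, vector length, and SIMD):

    using Random

    w = rand(MersenneTwister(1), Float32, 10_000)
    foldl(+, w), sum(w)   # plain left-to-right fold vs. Base's blocked/pairwise sum;
                          # the two results are not guaranteed to agree in the last bits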

@graidl (Author) commented Dec 18, 2025

Will try with larger inputs! And sorry for accidentally committing runtests.jl.

@devmotion (Member)

The current tests seem too complex and too brittle to me. Can't we just add a very simple artificial test for this branch with a weight vector with incorrectly large predefined sum (you can provide the sum when you construct the weight vector), such that after summing through all weights we're below rand() * very_large_incorrect_sum, so we hit the new branch?
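
A minimal sketch of that idea (the weights mirror the test that ends up in this PR; the seed is illustrative):

    using StatsBase, Random

    rng = MersenneTwister(0)
    # The declared total (10000) is much larger than the actual element sum (3), so
    # rand(rng) * 10000 almost surely exceeds the running sum of the weights and the
    # new fallback branch is hit; with the fix this should return 2, the last index
    # with a non-zero weight, rather than one of the zero-weight entries.
    sample(rng, Weights([1, 2, 0, 0], 10000))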

@nalimilan (Member)

Ah yes good idea!

@graidl (Author) commented Dec 19, 2025

Makes sense! I adapted the test in this direction. Still, I wanted to keep the scenario that actually triggers the issue when SIMD support is available, and additionally, there is now also a trivial test case with integer weights.

wsample(a::AbstractArray, w::AbstractVector{<:Real}, dims::Dims;
replace::Bool=true, ordered::Bool=false) =
wsample(default_rng(), a, w, dims; replace=replace, ordered=ordered)

Member

Suggested change

Comment on lines +305 to +307
w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973,
0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936,
0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]
Member
Suggested change
w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973,
0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936,
0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]
w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973,
0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936,
0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]

rng = StableRNG(889858990530)
s = sample(rng, Weights(w, 0.662413f0))
@test s == length(w) - 1
@test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2 # another more trivial test
Member
Suggested change
@test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2 # another more trivial test
# Artificial test with provided sum greater than actual sum
@test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2

@nalimilan requested a review from devmotion on January 5, 2026