Skip to content

Commit 9c984a0

Browse files
committed
fix: make r2_score() use da.where() instead of boolean-mask assignment
Resolves "cannot broadcast shape (nan,) to shape (nan,)" errors.
1 parent 69d168b commit 9c984a0

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

dask_ml/metrics/regression.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,13 @@ def r2_score(
162162
numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")
163163
denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")
164164

165-
nonzero_denominator = denominator != 0
166-
nonzero_numerator = numerator != 0
167-
valid_score = nonzero_denominator & nonzero_numerator
168-
output_chunks = getattr(y_true, "chunks", [None, None])[1]
169-
output_scores = da.ones([y_true.shape[1]], chunks=output_chunks)
170-
with np.errstate(all="ignore"):
171-
output_scores[valid_score] = 1 - (
172-
numerator[valid_score] / denominator[valid_score]
173-
)
174-
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
175-
176-
result = output_scores.mean(axis=0)
165+
score = da.where(
166+
numerator == 0,
167+
1.0,
168+
da.where(denominator != 0, 1 - numerator / denominator, 0.0),
169+
)
170+
171+
result = score.mean(axis=0)
177172
if compute:
178173
result = result.compute()
179174
return result

tests/metrics/test_regression.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,19 @@ def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs):
116116

117117
with pytest.raises((NotImplementedError, ValueError), match=error_msg):
118118
_ = m1(a, b, multioutput=weights)
119+
120+
121+
def test_r2_score_with_different_chunk_patterns():
122+
"""Test r2_score with different chunking configurations."""
123+
# Create arrays with compatible but different chunk patterns
124+
a = da.random.uniform(size=(100,), chunks=25) # 4 chunks
125+
b = da.random.uniform(size=(100,), chunks=20) # 5 chunks
126+
result = dask_ml.metrics.r2_score(a, b)
127+
assert isinstance(result, float)
128+
# Create arrays with different chunk patterns
129+
a_multi = da.random.uniform(size=(100, 3), chunks=(25, 3)) # 4 chunks
130+
b_multi = da.random.uniform(size=(100, 3), chunks=(20, 3)) # 5 chunks
131+
result_multi = dask_ml.metrics.r2_score(
132+
a_multi, b_multi, multioutput="uniform_average"
133+
)
134+
assert isinstance(result_multi, float)

0 commit comments

Comments
 (0)