Skip to content

Commit e98c538

Browse files
committed
fix: broadcast errors using lazy n_samples and da.where in r2_score
- Add a `_safe_rechunk` helper for safe rechunking with error handling.
- Set `_n_samples` from `X.shape[0]` in `fit()` to avoid the eager evaluation triggered by `len(X)`.
- Use `_n_samples` to derive the rechunking block size so that test data aligns one-to-one with the training blocks, preventing broadcast mismatches.
- Update `_predict()` and `_collect_probas()` accordingly.
- Refactor `r2_score()` to use `da.where()` for correct broadcasting.
- Resolves "cannot broadcast shape (nan,) to shape (nan,)" errors.
1 parent 69d168b commit e98c538

File tree

4 files changed

+97
-12
lines changed

4 files changed

+97
-12
lines changed

dask_ml/ensemble/_blockwise.py

+32
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
from ..utils import check_array, is_frame_base
99

1010

11+
def _safe_rechunk(arr, rechunk_dict, error_context=""):
12+
"""Helper function to safely rechunk arrays with proper error handling."""
13+
try:
14+
return arr.rechunk(rechunk_dict)
15+
except Exception as e:
16+
msg = (
17+
"Failed to rechunk array"
18+
f"{': ' + error_context if error_context else ''}: {e}"
19+
)
20+
raise ValueError(msg) from e
21+
22+
1123
class BlockwiseBase(sklearn.base.BaseEstimator):
1224
def __init__(self, estimator):
1325
self.estimator = estimator
@@ -22,6 +34,11 @@ def _check_array(self, X):
2234

2335
def fit(self, X, y, **kwargs):
2436
X = self._check_array(X)
37+
try:
38+
self._n_samples = X.shape[0]
39+
except Exception:
40+
self._n_samples = None
41+
2542
estimatord = dask.delayed(self.estimator)
2643

2744
Xs = X.to_delayed()
@@ -45,6 +62,7 @@ def fit(self, X, y, **kwargs):
4562
]
4663
results = list(dask.compute(*results))
4764
self.estimators_ = results
65+
return self
4866

4967
def _predict(self, X):
5068
"""Collect results from many predict calls"""
@@ -54,6 +72,13 @@ def _predict(self, X):
5472
dtype = "float64"
5573

5674
if isinstance(X, da.Array):
75+
if hasattr(self, "_n_samples") and self._n_samples is not None:
76+
desired = len(self.estimators_)
77+
if X.numblocks[0] != desired:
78+
block_size = max(1, self._n_samples // desired)
79+
X = _safe_rechunk(
80+
X, {0: block_size}, "to match estimator partitioning"
81+
)
5782
chunks = (X.chunks[0], len(self.estimators_))
5883
combined = X.map_blocks(
5984
_predict_stack,
@@ -174,6 +199,13 @@ def _predict_proba(self, X):
174199

175200
def _collect_probas(self, X):
176201
if isinstance(X, da.Array):
202+
if hasattr(self, "_n_samples") and self._n_samples is not None:
203+
desired = len(self.estimators_)
204+
if X.numblocks[0] != desired:
205+
block_size = max(1, self._n_samples // desired)
206+
X = _safe_rechunk(
207+
X, {0: block_size}, "to match estimator partitioning"
208+
)
177209
chunks = (len(self.estimators_), X.chunks[0], len(self.classes_))
178210
meta = np.array([], dtype="float64")
179211
# (n_estimators, len(X), n_classes)

dask_ml/metrics/regression.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,13 @@ def r2_score(
162162
numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")
163163
denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")
164164

165-
nonzero_denominator = denominator != 0
166-
nonzero_numerator = numerator != 0
167-
valid_score = nonzero_denominator & nonzero_numerator
168-
output_chunks = getattr(y_true, "chunks", [None, None])[1]
169-
output_scores = da.ones([y_true.shape[1]], chunks=output_chunks)
170-
with np.errstate(all="ignore"):
171-
output_scores[valid_score] = 1 - (
172-
numerator[valid_score] / denominator[valid_score]
173-
)
174-
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
175-
176-
result = output_scores.mean(axis=0)
165+
score = da.where(
166+
numerator == 0,
167+
1.0,
168+
da.where(denominator != 0, 1 - numerator / denominator, 0.0),
169+
)
170+
171+
result = score.mean(axis=0)
177172
if compute:
178173
result = result.compute()
179174
return result

tests/ensemble/test_blockwise.py

+42
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,27 @@ def test_no_classes_raises(self):
186186

187187

188188
class TestBlockwiseVotingRegressor:
189+
def test_no_unnecessary_computation_in_fit(self, monkeypatch):
    """``fit`` must not call ``X.compute`` and must record ``_n_samples``."""
    X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
    real_compute = X.compute
    recorded_calls = []

    def recording_compute(*args, **kwargs):
        # Note the call, then defer to the real method.
        recorded_calls.append((args, kwargs))
        return real_compute(*args, **kwargs)

    monkeypatch.setattr(X, "compute", recording_compute)

    model = dask_ml.ensemble.BlockwiseVotingRegressor(
        sklearn.linear_model.LinearRegression(),
    )
    model.fit(X, y)

    # X.compute() must never have been invoked during fitting.
    assert not recorded_calls
    # _n_samples must equal the first dimension of X's shape.
    assert model._n_samples == X.shape[0]
209+
189210
def test_fit_array(self):
190211
X, y = dask_ml.datasets.make_regression(n_features=20, chunks=25)
191212
est = dask_ml.ensemble.BlockwiseVotingRegressor(
@@ -240,3 +261,24 @@ def test_fit_frame(self):
240261
# TODO: r2_score raising for ndarray
241262
# score2 = est.score(X3, y3)
242263
# assert score == score2
264+
265+
def test_predict_with_different_chunks(self):
    """Predict and score on test data chunked differently from training data."""
    X_train, y_train = dask_ml.datasets.make_regression(n_features=20, chunks=25)
    model = dask_ml.ensemble.BlockwiseVotingRegressor(
        sklearn.linear_model.LinearRegression(),
    )
    model.fit(X_train, y_train)

    # Test data uses chunks of 20 rows; training data used chunks of 25.
    X_test, y_test = dask_ml.datasets.make_regression(n_features=20, chunks=20)
    predicted = model.predict(X_test)
    assert predicted.dtype == np.dtype("float64")
    assert predicted.shape == y_test.shape
    # The prediction is expected to have one block per fitted estimator.
    assert predicted.numblocks[0] == len(model.estimators_)

    assert isinstance(model.score(X_test, y_test), float)

    # Predicting on the materialized test data must give the same values.
    X_test_np, _ = dask.compute(X_test, y_test)
    predicted_np = model.predict(X_test_np)
    da.utils.assert_eq(predicted, predicted_np)

tests/metrics/test_regression.py

+16
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,19 @@ def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs):
116116

117117
with pytest.raises((NotImplementedError, ValueError), match=error_msg):
118118
_ = m1(a, b, multioutput=weights)
119+
120+
121+
def test_r2_score_with_different_chunk_patterns():
    """r2_score returns a float when its two inputs are chunked differently."""
    # 1-D inputs of equal length: 4 chunks of 25 versus 5 chunks of 20.
    coarse = da.random.uniform(size=(100,), chunks=25)
    fine = da.random.uniform(size=(100,), chunks=20)
    assert isinstance(dask_ml.metrics.r2_score(coarse, fine), float)

    # 2-D inputs: same row chunking mismatch, three columns in a single chunk.
    coarse_2d = da.random.uniform(size=(100, 3), chunks=(25, 3))
    fine_2d = da.random.uniform(size=(100, 3), chunks=(20, 3))
    averaged = dask_ml.metrics.r2_score(
        coarse_2d, fine_2d, multioutput="uniform_average"
    )
    assert isinstance(averaged, float)

0 commit comments

Comments
 (0)