Skip to content

fix: prevent broadcasting errors in r2_score using da.where() #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 10, 2025

Conversation

wietzesuijker
Copy link
Contributor

@wietzesuijker wietzesuijker commented Mar 3, 2025

Closes #1012

First PR here. Curious to hear your feedback.

Problem
After updating to Dask 2025.2.0, tests fail with a ValueError due to changes in chunk size handling.

Solution
Refactor r2_score() to use da.where() for correct broadcasting.

Testing
Test added to ensure r2_score() works correctly with arrays that have different chunk configurations.

@wietzesuijker wietzesuijker force-pushed the fix/broadcast-shape-nan branch from d7f7b86 to 17d3a02 Compare March 23, 2025 15:40
@wietzesuijker wietzesuijker changed the title fix(ensemble, metrics): compute chunk sizes and refactor r2_score wit… prevent broadcasting errors with unknown chunk sizes Mar 23, 2025
@TomAugspurger
Copy link
Member

Thanks.

I'm not entirely sure what the best action is, but I think we ought to avoid anything that triggers computation unnecessarily, including len.

Can you say a bit more about getting n_samples is needed in blockwise?

@wietzesuijker wietzesuijker force-pushed the fix/broadcast-shape-nan branch from 17d3a02 to e98c538 Compare March 29, 2025 21:25
@wietzesuijker wietzesuijker changed the title prevent broadcasting errors with unknown chunk sizes fix: broadcast errors using lazy n_samples and da.where in r2_score Mar 29, 2025
@wietzesuijker
Copy link
Contributor Author

Thanks @TomAugspurger. n_samples (now obtained via X.shape[0]) lets us determine the rechunking size without forcing computation. It splits the test data into one block per trained estimator, ensuring alignment and preventing broadcast errors. Combined with the da.where() update in r2_score, these changes maintain laziness and correct behavior with mismatched chunks.

@TomAugspurger
Copy link
Member

I'm probably missing something, but why do we care that the size of the test dataset matches the size of the training dataset (_n_samples)? I'd expect us to just care that the number of samples in X_train and y_train to match, and separately that the number of samples in X_test, y_test match.

@wietzesuijker
Copy link
Contributor Author

why do we care that the size of the test dataset matches the size of the training dataset (_n_samples)?

The goal is not for the test dataset to match the training dataset's overall size. The focus is ensuring each estimator, trained on a specific data block, receives a matching block from the test set. X.shape[0] is used (as n_samples) to compute the optimal test data block size, dividing the test set into blocks equal to the number of estimators. This aligns predictions and prevents broadcast errors, regardless of training and test dataset sizes.

Resolves "cannot broadcast shape (nan,) to shape (nan,)" errors.
@wietzesuijker wietzesuijker force-pushed the fix/broadcast-shape-nan branch from e98c538 to 9c984a0 Compare May 5, 2025 14:42
@wietzesuijker wietzesuijker changed the title fix: broadcast errors using lazy n_samples and da.where in r2_score fix: prevent broadcasting errors in r2_score using da.where() May 5, 2025
@wietzesuijker
Copy link
Contributor Author

wietzesuijker commented May 5, 2025

@TomAugspurger I limited the PR to changes in dask_ml/metrics/regression.py. Merging this would unblock my use case and allow me to update dask. Thanks! :). (previous state)

@wietzesuijker
Copy link
Contributor Author

Thanks for triggering the tests, Tom. Is there anything I can/should do to fix the failing runs? The errors seem unrelated and similar to other recent runs e.g. https://github.com/dask/dask-ml/actions/runs/14849752234/job/41691028477.

@TomAugspurger
Copy link
Member

I'll take a look in the next few days.

@TomAugspurger
Copy link
Member

Thanks @wietzesuijker. There's still one error in tests/preprocessing/test_data.py::TestQuantileTransformer::test_fit_transform_frame but the rest from #1012 are fixed. I'll fix that in another PR.

@TomAugspurger TomAugspurger merged commit 6fdd1f4 into dask:main May 10, 2025
4 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Tests failing with ValueError: cannot broadcast shape (nan,) to shape (nan,)
2 participants