Array: correct number of outputs in apply_gufunc
#7669
Conversation
ret = apply_gufunc(foo, "()->()", 1.0, output_dtypes=float, bar=2)
assert_eq(ret, np.array(1.0, dtype=float))
This was failing because the array chunks were float64, while meta expected float32 ("f"). I'm not sure why it ever passed before?
If the original was specifying output_dtypes="f" and the assertion fails for that, it probably means somewhere the value is being cast to float64. Seems like a sneaky regression to me.
Maybe it would be worth checking if specifying other output_dtypes would also cast the return value to float64.
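A quick way to probe that (a hypothetical snippet, not part of the PR; foo mirrors the test's helper and the dtype strings are only examples):

```python
from dask.array import apply_gufunc


def foo(x, bar):
    assert bar == 2
    return x


# Compare the declared dtype against the dtype of the computed result to
# spot any silent casting to float64.
for dt in ["f4", "f8", "i8"]:
    ret = apply_gufunc(foo, "()->()", 1.0, output_dtypes=dt, bar=2)
    print(dt, ret.dtype, ret.compute().dtype)
```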
@pentschev I dug into this and found a number of concerning things:
- On main, despite saying output_dtypes="f" (float32), ret.dtype is float64! As I mentioned in the description, output_dtypes is currently actually ignored unless inferring the dtype from func fails. Since calling the foo function with a float64 succeeds (the 1.0 input becomes NumPy float64), apply_gufunc returns a float64 array, even though you told it to use float32.
- Then, assert_eq doesn't actually check dtypes?! It turns out there is nothing in the assert_eq logic to check whether the dtypes of the two computed arrays match; there's just lots of logic to check that each one's meta matches its result. So the fact that ret and the expected np.array(1.0, dtype="f") have different dtypes is ignored, and the test passes. @jrbourbeau is this well-known? Should I open an issue about it? Seems like misleading behavior.
Since my changes now respect output_dtypes, the resulting array's meta (float32) does not match its computed dtype (float64, from the 1.0 input value), and assert_eq complained about that meta mismatch until I made the change now in this PR.
So basically, the old test was wrong: the output dtype in this case is in fact float64, so passing output_dtypes="f" should have resulted in an error from assert_eq about a meta mismatch, except that output_dtypes was ignored.
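On the assert_eq point, here is a minimal illustration of the gap described above (a hypothetical snippet; if assert_eq indeed never compares the two dtypes, this passes despite the mismatch):

```python
import numpy as np
import dask.array as da
from dask.array.utils import assert_eq

a = da.ones(3, chunks=3, dtype="f8")
b = np.ones(3, dtype="f4")

# Values are equal, dtypes are not; per the observation above this still passes.
assert_eq(a, b)
```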
> On main, despite saying output_dtypes="f" (float32), ret.dtype is float64! As I mentioned in the description, output_dtypes is currently actually ignored unless inferring the dtype from func fails. Since calling the foo function with a float64 succeeds (the 1.0 input becomes NumPy float64), apply_gufunc returns a float64 array, even though you told it to use float32.

To be honest, I'm not certain what exact behavior we want / users expect here. But if output_dtypes is specified, shouldn't apply_gufunc cast to the output_dtypes? Note that we do that in other places, such as partial_reduce. That would of course only apply to the case where meta is computed automatically but the user wants to enforce the output_dtypes nevertheless.
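A rough sketch of that idea (hypothetical, not what this PR implements): cast the result whenever the user-supplied output_dtypes disagrees with the computed dtype, in the spirit of partial_reduce.

```python
import numpy as np
from dask.array import apply_gufunc


def enforce_output_dtype(result, output_dtype):
    # Hypothetical helper: cast only when the dtypes disagree.
    if output_dtype is not None and result.dtype != np.dtype(output_dtype):
        result = result.astype(output_dtype)
    return result


ret = apply_gufunc(lambda x: x, "()->()", 1.0, output_dtypes="f4")
ret = enforce_output_dtype(ret, "f4")
assert ret.dtype == np.dtype("f4")
```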
> Then, assert_eq doesn't actually check dtypes?! It turns out there is nothing in the assert_eq logic to check whether the dtypes of the two computed arrays match; there's just lots of logic to check that each one's meta matches its result. So the fact that ret and the expected np.array(1.0, dtype="f") have different dtypes is ignored, and the test passes. @jrbourbeau is this well-known? Should I open an issue about it? Seems like misleading behavior.

Interesting catch, and I would say we should indeed open an issue. I don't recall whether there's a reason we don't do it, but it's worth at least discussing, in case it is in fact a bug.
I agree with your assessment that assert_eq doesn't check computed dtype against meta dtype. That'd be a nice thing to add :)
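A sketch of what such an addition could look like (a hypothetical helper, not an existing assert_eq option):

```python
import numpy as np
import dask.array as da


def check_meta_dtype(x):
    # Hypothetical check: the dtype of a dask array's meta should match the
    # dtype of its computed result.
    computed = np.asarray(x.compute())
    assert x._meta.dtype == computed.dtype, (x._meta.dtype, computed.dtype)


check_meta_dtype(da.ones(3, chunks=3, dtype="f8"))
```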
):
    return x

if isinstance(x, list) or isinstance(x, tuple):
I moved this before the not hasattr(x, "shape") or not hasattr(x, "dtype") check, because otherwise it's a no-op: lists and tuples don't have shape/dtype, so we would always have returned from the function before reaching this check.
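A minimal illustration of the ordering issue (a hypothetical, simplified version of the two checks):

```python
def meta_from_array_sketch(x):
    # With the hasattr check first, lists and tuples return here, so the
    # isinstance branch below could never run for them.
    if not hasattr(x, "shape") or not hasattr(x, "dtype"):
        return x
    if isinstance(x, (list, tuple)):
        return "handled as a list/tuple"  # dead code in this ordering
    return x


assert meta_from_array_sketch([1, 2, 3]) == [1, 2, 3]
```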
# min/max functions have no identity, just use the same input type when there's only one
if len(
    args_meta
) == 1 and "zero-size array to reduction operation" in str(e):
As noted in #7668. @pentschev, curious about your thoughts here. Personally, I don't think we should have this case at all; I don't think it's a safe assumption to make in general for any user-defined function. But I do see the convenience.
I agree with your assessment: it's indeed not a safe assumption for user-defined functions, and I too would prefer not to need this. However, without this condition, attempting to compute meta for a 0-D downstream array (e.g., CuPy) will fall into the general except Exception case and return None, which will cause meta to be a NumPy array and thus break those downstream use cases. The problem is that we have no known alternative; this is why we check the specific exception message. For any other corner cases we would have to extend this function to handle them properly, but user-defined functions should guarantee proper handling of 0-D arrays to avoid such issues.
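A paraphrase of that fallback as a standalone sketch (not the exact code in this PR):

```python
import numpy as np


def compute_meta_sketch(func, *args_meta):
    try:
        return func(*args_meta)
    except ValueError as e:
        # min/max have no identity on zero-size arrays; with a single input,
        # reuse its meta so downstream array types (e.g. CuPy) are preserved.
        if len(args_meta) == 1 and "zero-size array to reduction operation" in str(e):
            return args_meta[0]
        raise


meta = compute_meta_sketch(np.min, np.empty((0,), dtype="f8"))
assert isinstance(meta, np.ndarray) and meta.size == 0
```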
pentschev left a comment
Thanks for pinging me here @gjoseph92. Overall this looks good to me; it seems like a more feature-complete way of handling apply_gufunc. I've left a few comments, and assuming I haven't missed anything, I think test_apply_gufunc_pass_additional_kwargs is a potentially concerning test, so it may be worth checking whether some sneaky dtype casting is happening silently. Such cases can happen and can be difficult to find, but perhaps it's just me being overly cautious.
I've also verified that this PR doesn't break the CuPy tests we have (which aren't covered by CI yet), so this feels like a good solution. Thanks for working on it!
Is this ready to be merged? Also I am planning on merging #6863 today as well - I think the two PRs address different issues.

I'm fine with that @jsignell, I think the only unanswered portion here was the first part of #7669 (comment). It's only a matter of how we want/should treat output_dtypes.

I think the question is whether […]. It's basically a question of whether […]. However, I have very little preference, this is working as is, and based on the docstring for […].

Unless I am missing something you currently don't have a test for […].
@mathause good catch, there's no test for raising that error. |
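If the error in question is the new output_dtypes-plus-meta error from the description below, a test could look roughly like this (a hypothetical sketch; the exact exception type is an assumption):

```python
import numpy as np
import pytest
from dask.array import apply_gufunc


def test_apply_gufunc_output_dtypes_and_meta_raises():
    # Passing both output_dtypes and meta is described as an explicit error;
    # ValueError is assumed here.
    with pytest.raises(ValueError):
        apply_gufunc(
            lambda x: x,
            "()->()",
            1.0,
            output_dtypes="f8",
            meta=np.empty((), dtype="f8"),
        )
```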
Refactor apply_gufunc to compute meta itself, rather than relying on blockwise to do it. Having just one code path and consolidating everything on meta makes things easier to reason about, and happens to fix #7668. Additionally, before, func would be called even if output_dtypes was given (contrary to the docstring); now it no longer is.

This also makes it an explicit error to pass both output_dtypes and meta: what you want in that case is rather ambiguous.

Also, use a modern function definition instead of popping things out of kwargs to support py2.

- apply_gufunc bugs with meta inference #7668
- black dask / flake8 dask / isort dask
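For the kwargs point in the description, a sketch of the style change (illustrative signatures only, not the actual ones):

```python
# Before: py2-compatible keyword handling by popping from **kwargs
def apply_gufunc_old(func, signature, *args, **kwargs):
    output_dtypes = kwargs.pop("output_dtypes", None)
    vectorize = kwargs.pop("vectorize", False)
    ...


# After: keyword-only arguments declared directly in the signature
def apply_gufunc_new(func, signature, *args, output_dtypes=None, vectorize=False, **kwargs):
    ...
```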