Make addmm meta kernel consistent with mm #84960
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84960
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 03b8f4f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
No, I don't understand this change. Here's the schema for `addmm`; `self` is the first arg. Reordering it in the macro doesn't make sense?
Oh, I only reordered it to match the `mm` meta kernel.
OK, better. Can we get a test?
@@ -55,7 +55,7 @@ namespace meta {
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); \
                                                                                                    \
  auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self); \
  set_output_raw_strided(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
Should we also check that `mat1`, `mat2`, and `self` dtypes are the same? The actual implementation will error out on mismatched dtypes.
Seeing as this is an implementation-specific detail, and not related to shape inference, is the meta kernel the best place for this check? It already errors out in practice, which means the check already exists in the right places.
The point of the meta kernel is not just shape inference; it is also used in other contexts, such as running through a sequence of operations and producing the same error messages you'd see in eager, without actually running the ops in eager. So the practice is to handle all of the shape/dtype inference, as well as all of the error checking, in the meta kernel in a way that is consistent with eager.
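For illustration (not from this thread), here is a minimal sketch of that behaviour using meta tensors: output shapes and dtypes are inferred and the eager-style error checks fire, but no real computation runs.

```python
import torch

# Meta tensors execute only the meta kernels: metadata is computed, data is not.
a = torch.empty(2, 3, device="meta")
b = torch.empty(3, 4, device="meta")
print(torch.mm(a, b).shape)  # torch.Size([2, 4]), still on the meta device

# A shape mismatch raises the same error message you would see in eager.
try:
    torch.mm(a, torch.empty(5, 6, device="meta"))
except RuntimeError as e:
    print(e)
```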
Ah, I see. Okay, I can add that check then.
Oh wait, I can add a check to see if the dtypes of `self` and `mat2` match, as they are the bias and weight respectively. However, checking if the `input` dtype matches the parameters will fail in the case that I originally made the issue for.
CUDA addmm behavior is the same: all the dtypes have to match.
The case I'm referring to isn't in any execution mode. It's for custom mixed precision on custom vendor hardware that is only traced via lazy tensors. Hence, the case `torch.float32, torch.float16, torch.float32` should be allowed here and should produce output type `torch.float16` to match the input. The actual casting is done further down our stack, not at the top level.
Ah, then I think I would probably claim your custom AMP is wrong. The two ways I can think of to solve this: (1) explicitly generate the casts and fuse them away in your backend compiler, or (2) produce an alternative matmul that is type promoting and have the backend target that.
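As a rough sketch of option (1), with illustrative names that are not from this PR: the casts become explicit ops in the traced graph, so every `addmm` sees matching dtypes, and the backend compiler is then free to fuse the casts away.

```python
import torch

def mixed_precision_linear(x_fp16, weight_fp32, bias_fp32):
    # Explicit down-casts show up as ops in the trace instead of relying on
    # the addmm meta kernel to accept mismatched dtypes.
    w = weight_fp32.to(torch.float16)
    b = bias_fp32.to(torch.float16)
    # All operands now share a dtype, so the "must have the same dtype"
    # checks pass, and the output is float16 to match the input.
    return torch.addmm(b, x_fp16, w.t())
```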
Okay, we've decided to go with solution (1). In that case, the changes in this PR aren't strictly necessary anymore. Do we still want to merge this in (including the added dtype checks)?
Dtype checks would be appreciated!
c0d85b2 to 4104516 (compare)
@@ -48,14 +48,16 @@ namespace detail {
namespace meta {

#define ADDMM_META() \
  TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype"); \
  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype"); \
It would be nice to add an `ErrorInput` to the `addmm` OpInfo test for this case.
I'm not sure what you mean by `ErrorInput` and I can't find any examples in the `test/` directory. Can you please provide an example of what you mean?
test/lazy/test_meta_kernel.py (outdated)

        try:
            out_nobias = fc_nobias(input)
            self.assertTrue(False)  # Should never reach here as the above line should throw an exception
More idiomatic is `assertRaisesRegex`/`assertRaises`.
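For example, applied to the snippet above (this mirrors what the updated test below ends up doing; `assertRaisesRegex` with the expected message would pin it down further):

```python
# Let unittest verify that the call raises, instead of asserting False afterwards.
with self.assertRaises(Exception):
    fc_nobias(input)
```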
    def test_addmm_invalid_dtype(self):
        """Tests that the addmm meta kernel returns the correct output type"""
        input = torch.ones(2, 2, dtype=torch.float16).to("lazy")
        self.assertTrue(input.dtype == torch.float16)
More idiomatic: `assertEqual`.
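For example, the dtype assertion above would become `self.assertEqual(input.dtype, torch.float16)`.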
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @antoniojkim.
        fc_nobias = torch.nn.Linear(2, 2, bias=False, dtype=torch.float32).to("lazy")

        with self.assertRaises(Exception):
Should this be `RuntimeError`?
Change the names of the parameters in the `addmm` meta kernel to be more consistent with `mm`. Functionally, the only difference in behaviour should be that the `addmm` meta kernel gets its options from the `input` tensor instead of from the `bias` parameter.

Fixes #84930

CC: @ezyang @ngimel @wconstab @ke1337 @glebk-cerebras

Pull Request resolved: #84960
Approved by: https://github.com/ezyang