Initial prototype of differentiable _scaled_grouped_mm function #1969
Conversation
""" | ||
group_sizes = [] | ||
start_idx = 0 | ||
for end_idx in offs.tolist(): |
This is causing device-host synchronization, and we should avoid it, relying on upstream ops to create suitable inputs, or in the worst case, on the _scaled_grouped_mm implementation itself to throw an assert.
Good catch, removed this assertion for now and will rely on the kernel-side assert, to avoid the device-host sync.
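For context, here is a minimal sketch of the sync in question and one way to express the same computation on the device; the helper names and the use of torch.diff are illustrative, not taken from the PR:

```python
import torch

def group_sizes_with_sync(offs: torch.Tensor) -> list:
    # offs.tolist() copies the (possibly CUDA) tensor to the host, which
    # forces a device-host synchronization before the Python loop can run.
    sizes, start_idx = [], 0
    for end_idx in offs.tolist():
        sizes.append(end_idx - start_idx)
        start_idx = end_idx
    return sizes

def group_sizes_on_device(offs: torch.Tensor) -> torch.Tensor:
    # The same group sizes computed with tensor ops, staying on the device;
    # validating them is left to upstream ops or the kernel-side assert.
    zero = torch.zeros(1, dtype=offs.dtype, device=offs.device)
    return torch.diff(offs, prepend=zero)
```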
# Store what we need for backward.
ctx.save_for_backward(A, B)
ctx.float8_config = float8_config
ctx.offs = offs
offs is also a tensor, so it's better to use save_for_backward for it.
Done
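For illustration, a minimal autograd.Function sketch of the change; the class name and placeholder outputs are illustrative, not the PR's code:

```python
import torch

class _GroupedMMSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, offs):
        # Pass every tensor through save_for_backward (instead of stashing
        # offs as ctx.offs) so autograd tracks it like the other saved tensors.
        ctx.save_for_backward(A, B, offs)
        return A.new_empty(0)  # placeholder output for the sketch

    @staticmethod
    def backward(ctx, grad_output):
        A, B, offs = ctx.saved_tensors
        # Placeholder grads; the real backward uses offs for the grouped GEMMs.
        return torch.zeros_like(A), torch.zeros_like(B), None
```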
# Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
# needed for grad_A: grad_output @ B.
# Since B was transposed before entry to forward, we need to transpose it back here for this.
B_non_transposed_col_major = B.contiguous().transpose(-2, -1)
It might be better to scale and transpose B in the forward, and store only the quantized version (to minimize memory).
Good idea, done. Hopefully torch.compile can do some fusion here and read B once while writing both outputs simultaneously (float transposed column-major, float8 non-transposed column-major).
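A rough sketch of the idea using plain torch ops for the rowwise float8 cast; the helper below and the choice of reduction dim are illustrative, not torchao's API:

```python
import torch

def to_fp8_rowwise(t: torch.Tensor, dim: int = -1):
    # Absmax scaling along `dim`, followed by a cast to float8_e4m3fn.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12).float()
    scale = fp8_max / amax
    t_fp8 = (t.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return t_fp8, scale.reciprocal()  # fp8 data plus dequantization scales

# In forward, both fp8 views of B (the layout used by the forward GEMM and the
# one needed for grad_A in backward) could be produced here and saved in place
# of the high-precision tensor; under torch.compile the two casts can hopefully
# be fused so B is read only once.
```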
start_idx = 0
next_scale_idx = 0
for end_idx in offs.tolist():
Here it would also be better to have a Triton kernel computing scales that could read offs on the device, to avoid syncs.
Yeah, this implementation has more room for perf optimization; my first goal was to get accurate numerics. As a follow-up I can write a Triton kernel to avoid this device-host sync and for loop.
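Not the planned Triton kernel, but as a sketch of the general idea (reading offs on the device instead of iterating offs.tolist() on the host), a per-group reduction can also be expressed with pure torch ops; the helper name and the choice of an amax reduction here are assumptions for illustration:

```python
import torch

def per_group_rowwise_amax(x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # x: (total_rows, K); offs: cumulative end indices of each group.
    # Map each row to its group id without leaving the device, then reduce.
    row_idx = torch.arange(x.shape[0], device=x.device, dtype=offs.dtype)
    group_ids = torch.searchsorted(offs, row_idx, right=True)  # (total_rows,)
    out = torch.zeros(offs.numel(), x.shape[1], device=x.device, dtype=x.dtype)
    out.scatter_reduce_(
        0,
        group_ids.unsqueeze(-1).expand_as(x),
        x.abs(),
        reduce="amax",
        include_self=False,
    )
    return out  # one reduced row per group, with no host-device sync
```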
B_col_major = B

# Fetch float8 config from specified recipe name.
float8_config = Float8LinearConfig.from_recipe_name(
It's surprising to see a config created here, why not just inline the logic you need without worrying about configs? IMO dealing with configs would be for when this API is about to be productionized.
I agree. I used it here for this prototype because in the test code I need to use matmul_with_hp_or_float8_args (which requires a Float8 config) to compute the reference forward/backward, and getting the outputs/grads to match was already pretty tricky, so to start I wanted to reference the same Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) everywhere, to minimize room for accidental differences in how a particular tensor is quantized in the test code vs the implementation.
I've now updated the implementation to inline everything, and in the test code I compare against the float8 rowwise recipe to verify correctness.
A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K).
B (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N).
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group in the input tensor of shape.
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
remove?
Args:
A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K).
B (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N).
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group in the input tensor of shape.
can we clarify if this is for A, B or both? A, right?
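Assuming the answer is A (the offsets index into A's first dimension, per the (M * num_groups, K) shape above), the docstring line could be tightened to something like the following suggestion:

```python
offs (int32 torch.Tensor): Group offsets into A's first dimension: entry i is
    the cumulative end row (exclusive) of group i, so group i spans rows
    [offs[i-1], offs[i]).
```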
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group in the input tensor of shape.
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
use_fast_accum (bool): Whether to use fast accumulation or not. Default is False.
remove?
assert A.ndim == 2, "A must be 2D"
assert B.ndim == 3, "B must be 3D"

assert (
Are these assertions redundant with what is there in torch._scaled_grouped_mm? Ideally we only assert in one place for each condition.
Some are specific to this implementation (only supporting 2D A and 3D B for now, the primary use case), but some are redundant, yes.
I found the device-side assertions to be opaque sometimes; I often had to read the kernel code to see the exact condition that was failing when the error message was ambiguous. So my goal here was to make the requirements more transparent and easier to debug.
(I think since scaled grouped mm is a kernel executing on the GPU, if a check fails, the error message on the CPU side can't show the actual line of code with the condition that failed; it just points to the entrypoint torch._scaled_grouped_mm.)
The ndim assertions are not device-side, so they have a correct stack trace. Device-side assertions are put in when it's impossible to check the same thing on the host without introducing a host-device sync, so they happen for a reason and unfortunately can't be improved. To get a correct stack trace, run with CUDA_LAUNCH_BLOCKING=1.
), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm" | ||
|
||
# Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major. | ||
if not _is_column_major(B): |
nit: I'd prefer letting the caller do this instead, and this function can just assert that the layout is what is needed for the kernel
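A sketch of that split, with names assumed rather than taken from the PR:

```python
import torch

def to_column_major(x: torch.Tensor) -> torch.Tensor:
    # Caller-side conversion: same values, column-major strides on the last two dims.
    return x.transpose(-2, -1).contiguous().transpose(-2, -1)

def scaled_grouped_mm_entry(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor):
    # The wrapper only validates the layout instead of silently converting it.
    assert B.stride(-2) == 1 and B.stride(-1) > 1, (
        "B must be column-major; convert at the call site, e.g. with to_column_major"
    )
    ...
```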
# low precision B tensor instead of the high precision B tensor.
# In the backward this is needed for grad_A: grad_output @ B.
# Since B was transposed before entry to forward, we need to transpose it back here for this.
B_non_transposed_col_major = B.contiguous().transpose(-2, -1)
This naming is confusing: if the original variable is B, then this looks like B_transposed. IMO it would be cleanest to do something like:
- input is B
- B transposed is B_t

or

- input is B_t
- B_t transposed is B
Yeah, I agree; I considered this as well. The problem is I'm trying to make the naming consistent with torch._scaled_grouped_mm, which calls the tensors A and B, but it checks that B must be transposed - although what it's really enforcing is column-major format, so it's a bit confusing.
For now I changed the naming to option 2 above, which makes the Python code here clearer, with the trade-off that it will no longer be consistent with the kernel naming. I think that's fine though; I doubt many people will be diving into the kernel code.
Returns:
A boolean indicating whether the input tensor is column-major.
"""
return x.stride(-2) == 1 and x.stride(-1) > 1
Does this work for 4D/5D/etc tensors? If not, maybe assert that rank is 3?
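One possible tightening along the lines of the question above (a sketch, not the PR's code; the rank check here covers the 2D/3D cases this prototype handles):

```python
import torch

def _is_column_major(x: torch.Tensor) -> bool:
    # The stride check only inspects the last two dims, so restrict it to the
    # ranks this prototype actually supports rather than silently accepting 4D+.
    assert x.ndim in (2, 3), f"expected 2D or 3D tensor, got {x.ndim}D"
    return x.stride(-2) == 1 and x.stride(-1) > 1
```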
Summary
The _grouped_scaled_mm function in torchao is an initial prototype of a differentiable _scaled_grouped_mm with dynamic float8 rowwise quantization of the inputs. Note this prototype only handles A=2D, B=3D.
Test plan
Example usage
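A rough usage sketch; the import path, exact signature, and shapes below are assumptions based on the docstring quoted above, not verified against the final code:

```python
import torch

# Hypothetical import path for the prototype; adjust to wherever the function
# actually lands in torchao.
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

num_groups, M, K, N = 4, 64, 256, 128

# A is 2D of shape (M * num_groups, K); B is 3D of shape (num_groups, K, N).
A = torch.randn(M * num_groups, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(num_groups, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# offs holds the cumulative end row of each group in A (int32, on device).
offs = torch.arange(M, M * num_groups + 1, M, dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(A, B, offs, out_dtype=torch.bfloat16)
out.sum().backward()
```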