
Initial prototype of differentiable _scaled_grouped_mm function #1969

Merged
64 commits merged into main on Apr 2, 2025

Conversation

danielvegamyhre (Contributor) commented Mar 26, 2025

Summary

The _scaled_grouped_mm function in torchao will:

  • Perform dynamic float8 rowwise quantization on the inputs
  • Use these float8 inputs with the grouped scaled mm kernel in pytorch core and return the result
  • Do this in a differentiable way

Note: this prototype only handles the A=2D, B=3D case.

Test plan

  • Added unit tests verifying the correctness of the forward pass (outputs) and backward pass (gradients)
  • Verified torch.compile works with no graph breaks

Example usage

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

...

out = _scaled_grouped_mm(
    x,             # 2D high precision input tensor
    params,        # 3D high precision weights
    offs=offs,     # 1D int32 group offsets
    out_dtype=out_dtype,
)
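
For added context, here is a fuller end-to-end sketch (hypothetical shapes and variable names, not taken from the PR's tests) of how the inputs above might be constructed, including the offs tensor, a backward pass, and a graph-break check. The exact memory layout expected for the weight operand is discussed in the review threads below.

import torch
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

# Hypothetical sizes: 3 groups with 16, 32, and 16 rows, K=64, N=128.
group_sizes = [16, 32, 16]
num_groups, K, N = len(group_sizes), 64, 128
device = "cuda"

x = torch.randn(sum(group_sizes), K, device=device, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(num_groups, K, N, device=device, dtype=torch.bfloat16, requires_grad=True)

# offs holds the cumulative end index of each group along dim 0 of x, e.g. [16, 48, 64].
offs = torch.cumsum(torch.tensor(group_sizes, device=device), dim=0).to(torch.int32)

out = _scaled_grouped_mm(x, w, offs=offs, out_dtype=torch.bfloat16)
out.sum().backward()  # differentiable: gradients flow to both x and w

# One way to check the "no graph breaks" claim from the test plan:
# fullgraph=True makes torch.compile error out on any graph break.
compiled = torch.compile(_scaled_grouped_mm, fullgraph=True)
out_compiled = compiled(x, w, offs=offs, out_dtype=torch.bfloat16)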

danielvegamyhre added the "topic: not user facing" label on Mar 26, 2025

pytorch-bot (bot) commented Mar 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1969

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 300db8b with merge base 923242e, one new job failure was reported.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the "CLA Signed" label on Mar 26, 2025
danielvegamyhre changed the title from "[WIP] Initial prototype of grouped_mm API for torchao" to "[GroupedMM] Initial prototype of grouped_mm API for torchao" on Mar 27, 2025
danielvegamyhre requested a review from vkuzo on Mar 27, 2025 at 15:27
danielvegamyhre changed the title to "[GroupedMM] Initial prototype of grouped_mm API for torchao (forward pass only)" on Mar 27, 2025
danielvegamyhre changed the title to "[GroupedMM] Initial prototype of _grouped_scaled_mm prototype function for torchao (forward pass only)" on Mar 27, 2025
danielvegamyhre changed the title to "[GroupedMM] Initial prototype of _grouped_scaled_mm function for torchao (forward pass only)" on Mar 27, 2025
"""
group_sizes = []
start_idx = 0
for end_idx in offs.tolist():

Reviewer:
This is causing device-host synchronization, and we should avoid it, relying on upstream ops to create suitable inputs or, in the worst case, on the _scaled_grouped_mm implementation itself to throw an assert.

danielvegamyhre (Contributor Author):
Good catch; removed this assertion for now and will rely on the kernel-side assert, to avoid the device-host sync.
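
For illustration, a small sketch (not the PR's code) of why the original loop forces a sync, and what staying on-device looks like:

import torch

offs = torch.tensor([16, 48, 64], device="cuda", dtype=torch.int32)

# Host-side iteration: .tolist() copies offs to the CPU and synchronizes the stream.
start = 0
for end in offs.tolist():
    # ... per-group work with python ints would go here ...
    start = end

# On-device alternative: derive per-group sizes with tensor ops, and leave input
# validation to the kernel-side assert, as done in the PR.
group_sizes = torch.diff(offs, prepend=torch.zeros(1, device=offs.device, dtype=offs.dtype))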

# Store what we need for backward.
ctx.save_for_backward(A, B)
ctx.float8_config = float8_config
ctx.offs = offs

Reviewer:
offs is also a tensor, so better to use save_for_backward.

danielvegamyhre (Contributor Author):
Done
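
A minimal sketch (hypothetical class name, placeholder outputs) of the pattern agreed on here: all tensors, including offs, go through save_for_backward, while non-tensor state stays on ctx.

import torch

class _ScaledGroupedMMSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, offs):
        # Tensors must go through save_for_backward so autograd tracks them correctly.
        ctx.save_for_backward(A, B, offs)
        # Non-tensor state (dtypes, flags) can live directly on ctx.
        ctx.out_dtype = torch.bfloat16
        return torch.empty(0, device=A.device)  # placeholder output for the sketch

    @staticmethod
    def backward(ctx, grad_output):
        A, B, offs = ctx.saved_tensors
        return None, None, None  # placeholder gradients for the sketch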

# Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
# needed for grad_A: grad_output @ B.
# Since B was transposed before entry to forward, we need to transpose it back here for this.
B_non_transposed_col_major = B.contiguous().transpose(-2, -1)

Reviewer:
It might be better to scale and transpose B in forward, and store only the quantized version (to minimize memory).

danielvegamyhre (Contributor Author):
Good idea, done. Hopefully torch.compile can do some fusion here and read B once while writing both outputs simultaneously (float transposed column-major, float8 non-transposed column-major).


start_idx = 0
next_scale_idx = 0
for end_idx in offs.tolist():

Reviewer:
Here it would also be better to have a triton kernel computing scales that could read offs on the device, to avoid syncs.

danielvegamyhre (Contributor Author), Apr 1, 2025:
Yeah, this implementation has more room for perf optimization; my first goal was to get accurate numerics. As a follow-up I can write a triton kernel to avoid this device-host sync and for loop.

B_col_major = B

# Fetch float8 config from specified recipe name.
float8_config = Float8LinearConfig.from_recipe_name(

Contributor:
It's surprising to see a config created here; why not just inline the logic you need without worrying about configs? IMO dealing with configs would be for when this API is about to be productionized.

danielvegamyhre (Contributor Author), Apr 2, 2025:
I agree. I used it here for this prototype because in the test code I need to use matmul_with_hp_or_float8_args (which requires a Float8 config) to compute the reference forward/backward, and trying to get the outputs/grads to match was already pretty tricky. So to start I wanted to reference the same Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) everywhere, to minimize room for accidental differences in how a particular tensor is quantized in the test code vs the implementation.

I've now updated the implementation to inline everything, and in the test code I compare against the float8 rowwise recipe to verify correctness.
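
A minimal sketch (helper name hypothetical, e4m3 dtype assumed) of what "inlining the logic" for dynamic rowwise float8 quantization can look like without going through Float8LinearConfig; note that whether the downstream kernel expects this scale or its reciprocal depends on the kernel's convention.

import torch

def to_float8_rowwise(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # Per-row absolute max, computed in float32 for a stable scale.
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    fp8_max = torch.finfo(float8_dtype).max
    # Map each row's amax onto the float8 representable range.
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return x_fp8, scale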

A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K).
B (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N).
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group in the input tensor of shape.
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.

Contributor:
remove?

Args:
A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K).
B (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N).
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group in the input tensor of shape.

Contributor:
can we clarify if this is for A, B or both? A, right?

offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group in the input tensor of shape.
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
use_fast_accum (bool): Whether to use fast accumulation or not. Default is False.

Contributor:
remove?

assert A.ndim == 2, "A must be 2D"
assert B.ndim == 3, "B must be 3D"

assert (

Contributor:
are these assertions redundant with what is there in torch._scaled_grouped_mm? Ideally we only assert in one place for each condition.

danielvegamyhre (Contributor Author), Apr 2, 2025:
Some are specific to this implementation (only 2D A and 3D B are supported for now, the primary use case), but some are redundant, yes.

I found the device-side assertions to be opaque sometimes; I often had to read the kernel code to see the exact condition that was failing when the error was a bit ambiguous. So my goal here was to make the requirements more transparent and easier to debug.

(I think since scaled grouped mm is a kernel executing on the GPU, if a check fails, then on the CPU side the error message can't show the actual line of code with the condition that failed; it just points to the entrypoint torch._scaled_grouped_mm.)

Reviewer:
ndim assertions are not device-side, so they have a correct stack trace. Device-side assertions are put in when it's impossible to check the same thing on the host without introducing a host-device sync, so they exist for a reason and unfortunately can't be improved. To get a correct stack trace, run with CUDA_LAUNCH_BLOCKING=1.
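
For anyone hitting the same debugging problem, a small sketch of the suggestion above: CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so the device-side assert surfaces at the offending launch; it must be set before CUDA is initialized.

import os

# Must be set before the first CUDA call in the process (or exported in the shell).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # with blocking launches, a failing device-side assert points at its launch site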

), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"

# Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
if not _is_column_major(B):

Contributor:
nit: I'd prefer letting the caller do this instead, and this function can just assert that the layout is what is needed for the kernel
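
A small sketch of the suggested split (hypothetical names): the caller converts the right operand to column-major up front, and the function only asserts the layout expected by the kernel.

import torch

def _assert_column_major(x: torch.Tensor) -> None:
    # Column-major in the last two dims: unit stride along the second-to-last dim.
    assert x.stride(-2) == 1 and x.stride(-1) > 1, "right operand must be column-major"

B = torch.randn(8, 64, 128, dtype=torch.bfloat16)                 # row-major weights
B_col_major = B.transpose(-2, -1).contiguous().transpose(-2, -1)  # caller-side conversion
_assert_column_major(B_col_major)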

# low precision B tensor instead of the high precision B tensor.
# In the backward this is needed for grad_A: grad_output @ B.
# Since B was transposed before entry to forward, we need to transpose it back here for this.
B_non_transposed_col_major = B.contiguous().transpose(-2, -1)

Contributor:
This naming is confusing: if the original variable is B, then this looks like B transposed. IMO it would be cleanest to do something like:

  • input is B
  • B transposed is B_t

or

  • input is B_t
  • B_t transposed is B

danielvegamyhre (Contributor Author):
Yeah, I agree; I considered this as well. The problem is I'm trying to make the naming consistent with torch._scaled_grouped_mm, which calls the tensors A and B, but it checks that B must be transposed - although really what it's enforcing is column-major format, so it's a bit confusing.

For now I changed the naming to option 2 above, which will make the python code here clearer, with the trade-off being it will no longer be consistent with the kernel naming. I think that's fine though, I doubt too many people will be diving into the kernel code.

Returns:
A boolean indicating whether the input tensor is column-major.
"""
return x.stride(-2) == 1 and x.stride(-1) > 1

Contributor:
does this work for 4d/5d/etc tensors? if not, maybe assert that rank is 3?
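
A sketch of the reviewer's suggestion, mirroring the PR's helper: guard the rank so the stride check on the last two dims is only applied to the 3D weights this prototype actually supports.

import torch

def _is_column_major(x: torch.Tensor) -> bool:
    # This prototype only calls the check on 3D weights, so assert the rank
    # rather than reasoning about 4D/5D stride patterns.
    assert x.ndim == 3, f"expected a 3D tensor, got ndim={x.ndim}"
    return x.stride(-2) == 1 and x.stride(-1) > 1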

danielvegamyhre changed the title from "Initial prototype of differentiable grouped_scaled_mm function for torchao" to "Initial prototype of differentiable _scaled_grouped_mm function for torchao" on Apr 2, 2025
danielvegamyhre changed the title to "Initial prototype of differentiable _scaled_grouped_mm function" on Apr 2, 2025
danielvegamyhre merged commit 620356d into main on Apr 2, 2025
18 checks passed