-
Notifications
You must be signed in to change notification settings - Fork 19
bring back torch.autograd.Function #316
base: gh/vkuzo/29/base
Are you sure you want to change the base?
Conversation
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. ``` # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files # and modified to only support dynamic scaling # # Why do we want a torch.autograd.Function here? Vasiliy's opinion is that # as we add more scaling granularities, keeping the scaling code close to Float8Linear # will be really useful for readability and debuggability of numerics. # # For example, a future PR to add rowwise scaling could do # # # forward # x_bf16 = ... # if scaling_granularity == ScalingGranularity.PER_TENSOR: # # we can scale the same way for fwd/bwd # x_maybe_fp8 = to_fp8(...) # else: # assert scaling_granularity == ScalingGranularity.PER_ROW: # # defer scaling to float8_mm # x_maybe_fp8 = x_bf16 # # # repeat for w # # y_bf16 = float8_mm(x_maybe_fp8, w_maybe_fp8) # # Requirements for float8_mm # - composes with DTensor, compile, autograd # - readable/debuggable # # Option 1 (this PR): float8_mm is a torch.autograd.Function # - pros # - cons # Option 2 (current code without this PR): float8_mm is an override of torch.mm # - pros # - cons # ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. We should discuss whether we want Option 1 (keep overriding torch.mm) or Option 2 (torch.autograd.Function). Vasiliy: I think Option 2 is cleaner/more readable/more debuggable, modeling code is usually written in the module or similar torch.autograd.Function overrides. I would consider scaling tensors to float8 modeling code, and it's unintuitive IMO for this to happen deep inside op overrides. However, Option 1 is less risky technically as we avoid torch.autograd.Function which is less mature in interactions with torch.compile. While the current PR is all green, we are using `allow_in_graph` which is a bit unsafe. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. This PR implements Option 2 as IMO this is more readable/debuggable. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
@@ -71,6 +71,54 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( | |||
scale.copy_(new_scale) | |||
|
|||
|
|||
# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Does the structure work out to put this in float8 ops?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in how things look after this PR it would make sense, but might be good to see how the code looks after we add different granularities and the if/else branches on when to convert to lower precision. Maybe we can revisit then?
return res_bits | ||
|
||
@staticmethod | ||
def backward(ctx, go_fp8): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: align go_fp8 / other naming to the other PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can do that in separate PRs, since not user facing. Just keeping things small.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont know if that changes the size of the PR much but sure thats fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably just a style difference on how to sequence the renames, either is ok IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, TBH I think this is a good balance of both subclassing + autograd func
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. This PR implements Option 2 as IMO this is more readable/debuggable. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 6cb1588 Pull Request resolved: #336
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 42dd595 Pull Request resolved: #336
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068) [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068) [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Pull Request resolved: #344 This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Reviewed By: drisspg Differential Revision: D60291446 fbshipit-source-id: 472f392227bca1c7f83ea0c1234285bc576e58d2
Stack from ghstack (oldest at bottom):
Summary:
I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd:
a. keep the bf16 copy to be able to rescale across dim0 and dim1
b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision)
c. keep some of the gemms in bf16 to avoid the need to scale twice
The modeling logic in Float8Linear for a/b would look like:
And, there are at least two choices I see for
float8_mm_op
:To support future scaling granularities, whichever choice we go with will have to do something like below:
Furthermore, to keep things readable / debuggable, it would be good to:
To do the above, we'll need to pass around metadata such as module FQNs.
This PR implements Option 2 as IMO this is more readable/debuggable.
Test plan: