This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit bc4438e

Update base for Update on "bring back torch.autograd.Function for float8 matmul"
Summary:

This is a redo of #316.

With upcoming support for scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in the forward and backward passes. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068)

[ghstack-poisoned]
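For context, here is a minimal, hypothetical sketch of the `torch.autograd.Function` pattern the summary refers to: routing the matmul through a custom Function lets the forward and backward passes each decide how to scale/cast their inputs and which gemm to call. This is not the float8_experimental implementation; `_ScaledMM` and `scaled_linear` are illustrative names, and plain matmuls stand in for the float8 gemms.

```python
import torch

class _ScaledMM(torch.autograd.Function):
    """Hypothetical sketch of the override pattern; not the library's code."""

    @staticmethod
    def forward(ctx, x, w):
        # A real float8 implementation would scale/cast x and w here and pick
        # the forward gemm kernel; a plain matmul stands in for that.
        ctx.save_for_backward(x, w)
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # grad_input and grad_weight can each use their own scaling and gemm
        # choice, which is the point of routing through autograd.Function.
        grad_x = grad_out @ w
        grad_w = grad_out.t() @ x
        return grad_x, grad_w


def scaled_linear(x, w):
    return _ScaledMM.apply(x, w)


if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    # Custom autograd.Functions compose with torch.compile in recent PyTorch,
    # as the summary notes.
    out = torch.compile(scaled_linear)(x, w)
    out.sum().backward()
```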
2 parents 224cfdf + 0aca10a commit bc4438e

File tree

1 file changed: +4, -6 lines


float8_experimental/fsdp_utils.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -56,18 +56,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    scales = torch.split(scale_tensor, 1)  # Replicate
-    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = (
-            scale._local_tensor.squeeze()
-        )
+    local_scale_tensor = scale_tensor.to_local()
+    for i, float8_linear in enumerate(float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
```
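As context for the `torch.vstack` -> `torch.stack` change: `torch._foreach_norm` returns 0-d tensors, and `stack` keeps the result one-dimensional, so each per-weight scale can be read back with plain indexing after a single `to_local()` call instead of `split` + `zip` + `squeeze`. A standalone sketch with scalar stand-in values (no DTensor/FSDP setup; assumes a PyTorch build with the float8 dtypes):

```python
import torch

# Stand-ins for the per-weight inf-norms returned by torch._foreach_norm.
max_weights = [torch.tensor(3.0), torch.tensor(5.0)]

# vstack promotes each 0-d tensor to shape (1, 1), giving an (N, 1) result
# that previously had to be split and squeezed per weight.
print(torch.vstack(max_weights).shape)  # torch.Size([2, 1])

# stack keeps the result 1-D, so scale i is just scale_tensor[i].
amax_tensor = torch.stack(max_weights)
print(amax_tensor.shape)  # torch.Size([2])

scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
for i in range(len(max_weights)):
    print(scale_tensor[i])  # one scalar scale per weight
```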
