
Commit af60830

Merge pull request #14 from pytorch-labs/delete_custom_add

make grad addition happen in original precision

2 parents: 0b4561e + 08cbdac

2 files changed: 15 additions, 41 deletions

float8_playground/float8_tensor.py (12 additions, 36 deletions)
@@ -38,22 +38,21 @@ def backward(ctx, g):
 
 class Float8Tensor(torch.Tensor):
     """
-    A Python-only FP8 tensor. Contains:
+    A Python-only Float8 tensor subclass. Contains:
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
       by scale to go from fp32 range to fp8 range, and divide by scale to go
       from fp8 range to fp32 range.
     * `_orig_dtype`: the original dtype of the tensor used to create this
       tensor.
 
-    The current purpose of this object is 99% to bundle raw data + fp8 metadata
-    together for easy passing through PyTorch systems, and 1% to implement
-    gradient addition (since that has to happen outside of user code).
-
-    The addition operation is defined inline and uses a naive
-    version of stateless scaling. This allows e5m2 gradients to be added.
-    TODO(future): verify this is numericaly accurate, optionally replace
-    with something better.
+    Intended usage of this abstraction:
+    1. to bundle raw data + fp8 metadata together for easy passing through
+       Python PyTorch systems.
+    2. Float8-aware user code can use the private fields on these tensors
+       to call into float8 operations.
+    3. Float8-agnostic user code can use these tensors as is - they will
+       convert to original precision in `__torch_dispatch__`.
     """
 
     def __new__(cls, data, scale, orig_dtype):
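The new docstring above describes the three private fields. As an illustrative sketch only (not part of this commit), they fit together as shown below; the `tensor_to_scale` helper is the one used in tests/test.py, and both import paths are assumptions that may not match the repo layout.

    import torch
    from float8_playground.float8_tensor import Float8Tensor
    # tensor_to_scale lives elsewhere in the repo; this import path is a guess
    from float8_playground.float8_utils import tensor_to_scale

    x_fp32 = torch.randn(4, 4, device='cuda')
    # multiplying by the scale maps the fp32 values into the e5m2 range
    x_scale = tensor_to_scale(x_fp32, torch.float8_e5m2)
    x_fp8 = Float8Tensor.to_float8(x_fp32, x_scale, torch.float8_e5m2)

    x_fp8._data        # raw float8_e5m2 bits
    x_fp8._scale       # scale used for the fp32 -> fp8 conversion
    x_fp8._orig_dtype  # torch.float32

    # dividing by the scale takes the data back to the fp32 range
    x_fp32_roundtrip = x_fp8.to_original_precision()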
@@ -89,38 +88,15 @@ def to_float8(cls, tensor, scale, dtype):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        # Note: unlike many other subclasses, this subclass's only propagates
-        # itself for addition (for gradient addition in backward). For all
-        # other ops, it self-converts to original precision.
-
-        # override addition so we can add e5m2 gradients
-        if (
-            func is aten.add.Tensor
-            and isinstance(args[0], Float8Tensor)
-            and isinstance(args[1], Float8Tensor)
-        ):
-            x1_fp8, x2_fp8 = args[0], args[1]
-            assert x1_fp8._data.dtype == torch.float8_e5m2 and x2_fp8._data.dtype == torch.float8_e5m2
-            # scale will be filled in by the kernel, not using delayed scaling
-            x3_scale = torch.empty(1, device=x1_fp8.device)
-            res_bits = torch.ops.aten.add_float8_e5m2(
-                x1_fp8._data, x1_fp8._scale,
-                x2_fp8._data, x2_fp8._scale,
-                x3_scale)
-            # TODO(future): handle type promotion if orig dtypes do not match
-            # for now, just take the first one
-            res = Float8Tensor(res_bits, x3_scale, x1_fp8._orig_dtype)
-            return res
-
-        # for all other ops, fall back to original precision
-        def maybe_unwrap(t):
+        # for all ops that get here, fall back to original precision
+        def unwrap(t):
             if isinstance(t, Float8Tensor):
                 return t.to_original_precision()
             return t
 
-        args = tree_map(maybe_unwrap, args)
+        args = tree_map(unwrap, args)
         if kwargs is not None:
-            kwargs = tree_map(maybe_unwrap, kwargs)
+            kwargs = tree_map(unwrap, kwargs)
         out = super().__torch_dispatch__(func, types, args, kwargs)
         return out
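With the custom add path deleted, every op that reaches `__torch_dispatch__`, including gradient addition in the backward pass, unwraps its Float8Tensor arguments and runs in the original precision, which is what the commit title refers to. A minimal sketch of the resulting behavior, continuing the names from the sketch above (not part of this commit):

    y_fp32 = torch.randn(4, 4, device='cuda')
    y_fp8 = Float8Tensor.to_float8(
        y_fp32, tensor_to_scale(y_fp32, torch.float8_e5m2), torch.float8_e5m2)

    # addition no longer routes through torch.ops.aten.add_float8_e5m2:
    # both operands are unwrapped to fp32 and aten.add runs on plain tensors
    z = x_fp8 + y_fp8
    assert not isinstance(z, Float8Tensor)
    assert z.dtype == torch.float32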

tests/test.py (3 additions, 5 deletions)
@@ -20,19 +20,17 @@
 torch.manual_seed(0)
 
 class Float8TensorUnitTest(unittest.TestCase):
-    def test_add(self):
+    def test_grad_add(self):
         x1_fp32 = torch.randn(4, 4, device='cuda')
         x1_s = tensor_to_scale(x1_fp32, torch.float8_e5m2)
         x2_fp32 = torch.randn(4, 4, device='cuda')
         x2_s = tensor_to_scale(x2_fp32, torch.float8_e5m2)
         x1_fp8 = Float8Tensor.to_float8(x1_fp32, x1_s, torch.float8_e5m2)
         x2_fp8 = Float8Tensor.to_float8(x2_fp32, x2_s, torch.float8_e5m2)
-        x3_fp8 = x1_fp8 + x2_fp8
-        x3_fp32 = x3_fp8.to_original_precision()
+        x3_fp32 = x1_fp8 + x2_fp8
         x3_fp32_ref = x1_fp32 + x2_fp32
         sqnr = compute_error(x3_fp32_ref, x3_fp32)
-        # TODO(future): make this more accurate, accuracy is pretty low
-        self.assertTrue(sqnr >= 10.0)
+        self.assertTrue(sqnr >= 20.0)
 
     def test_preserves_dtype(self):
         # hp means high precision, lp means low precision
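`compute_error` is defined elsewhere in the repo and is not part of this diff; the raised threshold (20.0 instead of 10.0) reads naturally if it reports SQNR in decibels, consistent with the sum no longer being re-quantized to e5m2 by a custom kernel. A hedged sketch of such a metric, under that assumption only:

    import torch

    def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
        # signal-to-quantization-noise ratio in decibels:
        # 20 * log10(||ref|| / ||ref - actual||)
        return 20 * torch.log10(
            torch.linalg.norm(ref) / torch.linalg.norm(ref - actual))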
