@@ -38,22 +38,21 @@ def backward(ctx, g):
 
 class Float8Tensor(torch.Tensor):
     """
-    A Python-only FP8 tensor. Contains:
+    A Python-only Float8 tensor subclass. Contains:
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
       by scale to go from fp32 range to fp8 range, and divide by scale to go
       from fp8 range to fp32 range.
     * `_orig_dtype`: the original dtype of the tensor used to create this
       tensor.
 
-    The current purpose of this object is 99% to bundle raw data + fp8 metadata
-    together for easy passing through PyTorch systems, and 1% to implement
-    gradient addition (since that has to happen outside of user code).
-
-    The addition operation is defined inline and uses a naive
-    version of stateless scaling. This allows e5m2 gradients to be added.
-    TODO(future): verify this is numericaly accurate, optionally replace
-    with something better.
+    Intended usage of this abstraction:
+    1. to bundle raw data + fp8 metadata together for easy passing through
+       Python PyTorch systems.
+    2. Float8-aware user code can use the private fields on these tensors
+       to call into float8 operations.
+    3. Float8-agnostic user code can use these tensors as is - they will
+       convert to original precision in `__torch_dispatch__`.
 
     """
 
     def __new__(cls, data, scale, orig_dtype):
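Note on the scaling convention documented above (multiply by `_scale` to move from the original range into the fp8 range, divide by `_scale` to come back): the round trip amounts to the minimal sketch below. This is illustrative only, not code from this patch; it assumes a PyTorch build that exposes `torch.float8_e4m3fn`, and the helper name is made up.

import torch

def fp8_round_trip_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # multiply by scale to go from the original (e.g. fp32) range to fp8 range
    data = (x * scale).to(torch.float8_e4m3fn)  # plays the role of `_data`
    # divide by scale to go from fp8 range back to the original range
    return data.to(x.dtype) / scale

x = torch.randn(8)
x_roundtripped = fp8_round_trip_sketch(x, scale=torch.tensor(16.0))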
@@ -89,38 +88,15 @@ def to_float8(cls, tensor, scale, dtype):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        # Note: unlike many other subclasses, this subclass's only propagates
-        # itself for addition (for gradient addition in backward). For all
-        # other ops, it self-converts to original precision.
-
-        # override addition so we can add e5m2 gradients
-        if (
-            func is aten.add.Tensor
-            and isinstance(args[0], Float8Tensor)
-            and isinstance(args[1], Float8Tensor)
-        ):
-            x1_fp8, x2_fp8 = args[0], args[1]
-            assert x1_fp8._data.dtype == torch.float8_e5m2 and x2_fp8._data.dtype == torch.float8_e5m2
-            # scale will be filled in by the kernel, not using delayed scaling
-            x3_scale = torch.empty(1, device=x1_fp8.device)
-            res_bits = torch.ops.aten.add_float8_e5m2(
-                x1_fp8._data, x1_fp8._scale,
-                x2_fp8._data, x2_fp8._scale,
-                x3_scale)
-            # TODO(future): handle type promotion if orig dtypes do not match
-            # for now, just take the first one
-            res = Float8Tensor(res_bits, x3_scale, x1_fp8._orig_dtype)
-            return res
-
-        # for all other ops, fall back to original precision
-        def maybe_unwrap(t):
+        # for all ops that get here, fall back to original precision
+        def unwrap(t):
             if isinstance(t, Float8Tensor):
                 return t.to_original_precision()
             return t
 
-        args = tree_map(maybe_unwrap, args)
+        args = tree_map(unwrap, args)
         if kwargs is not None:
-            kwargs = tree_map(maybe_unwrap, kwargs)
+            kwargs = tree_map(unwrap, kwargs)
         out = super().__torch_dispatch__(func, types, args, kwargs)
         return out
 
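To make the new dispatch behavior concrete: after this change, every op that reaches `__torch_dispatch__` sees plain tensors in the original precision, so Float8-agnostic user code composes with the subclass without special handling. The snippet below is a simplified, self-contained stand-in for that unwrap-on-dispatch pattern, not the class from this repo; it assumes `torch.Tensor._make_wrapper_subclass`, `torch.utils._pytree.tree_map`, and `torch.float8_e5m2` are available.

import torch
from torch.utils._pytree import tree_map

class MiniFloat8Tensor(torch.Tensor):
    """Simplified stand-in illustrating the unwrap-on-dispatch pattern."""

    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )
        self._data = data            # raw fp8 payload
        self._scale = scale          # per-tensor scale
        self._orig_dtype = orig_dtype
        return self

    def to_original_precision(self):
        # divide by scale to go from fp8 range back to the original range
        return self._data.to(self._orig_dtype) / self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # every op falls back to original precision, mirroring the patched code
        def unwrap(t):
            return t.to_original_precision() if isinstance(t, MiniFloat8Tensor) else t
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs) if kwargs is not None else {}
        return func(*args, **kwargs)

# Float8-agnostic usage: ops dispatch, unwrap, and run in the original dtype.
x = torch.randn(4)
scale = torch.tensor(16.0)
x_fp8 = MiniFloat8Tensor((x * scale).to(torch.float8_e5m2), scale, torch.float32)
y = x_fp8 + torch.ones(4)   # the add runs in float32 after unwrapping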