# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
+ import enum
from typing import Dict, Optional

import torch
...

aten = torch.ops.aten

+ #
+ # A note on configuration of float8 logic in a linear
+ # TODO(future): move all the configs to a separate file
+ #
+ # There are three gemms in a forward + backward of a Linear layer:
+ #
+ # 1. x @ w_t = y (forward pass)
+ # 2. dL_dY @ w = dL_dX (backward pass)
+ # 3. x_t @ dL_dY = dL_dW (backward pass)
+ #
+ # In the formulas above, there are:
+ # A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
+ #    - Note that dL_dY_t is implied because of the memory format
+ #      requirements of float8 gemms
+ # B. three output tensors (y, dL_dX, dL_dW)
+ #
+ # We want each input tensor, gemm, and output tensor to be configurable.
+ # The state of this configuration today is:
+ #
+ # i. pairs of input tensors (the non-t and t variants) have their scaling
+ #    configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
+ # ii. each gemm + output is configurable via ScaledMMConfig, which is not
+ #     user facing
+ # iii. LinearMMConfig is a container for the three ScaledMMConfig objects
+ #      needed to configure all three gemms, also not user facing
+

# ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass.
# emulate: whether to emulate the matmuls in fp32
...
    defaults=[False, False, False, False],
)

+ # The object below exists for convenience, to allow Float8Tensor to use
+ # the right config based on which of the gemms `y`, `dL_dX`, `dL_dW` is
+ # being called.
+ LinearMMConfig = namedtuple(
+     "LinearMMConfig",
+     ["y", "dL_dX", "dL_dW"],
+     defaults=[
+         ScaledMMConfig(False, True, False, False),
+         ScaledMMConfig(False, False, False, False),
+         ScaledMMConfig(False, False, False, False),
+     ],
+ )
+
+ # Given a Float8Tensor, the enum below describes the expected role of this
+ # tensor in the three gemms present in the fw + bw pass of a Linear layer.
+ # This is used to choose the right config for a float8 gemm when the
+ # gemm is performed.
+ class GemmInputRole(enum.Enum):
+     X = "x"
+     W = "w"
+     DL_DY = "dL_dY"
+
+ # Choose which ScaledMMConfig to use based on the roles of the two gemm inputs.
+ def choose_scaled_mm_config(
+     a_role: GemmInputRole,
+     a_linear_mm_config: LinearMMConfig,
+     b_role: GemmInputRole,
+     b_linear_mm_config: LinearMMConfig,
+ ):
+     if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
+         assert a_linear_mm_config.y == b_linear_mm_config.y
+         return a_linear_mm_config.y
+     elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
+         assert a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
+         return a_linear_mm_config.dL_dX
+     else:
+         assert a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY, \
+             f"unexpected a_role {a_role} and b_role {b_role}"
+         assert a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
+         return a_linear_mm_config.dL_dW
+
+

def merge_mm_configs(
    a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig
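To ground the note above: the three gemms have the shapes sketched below, and the default `LinearMMConfig` added in this hunk treats the forward gemm differently from the two backward gemms. This is a minimal sketch, not part of the diff; it uses plain torch tensors (no float8), and the meaning of the second `ScaledMMConfig` field (the one set to `True` for `y`) is assumed to be fast accumulation, which this excerpt does not spell out.

    import torch

    x = torch.randn(16, 32)      # input, (batch, in_features)
    w = torch.randn(64, 32)      # weight, (out_features, in_features)
    dL_dY = torch.randn(16, 64)  # incoming gradient, same shape as y

    y = x @ w.t()          # 1. forward:  (16, 64)
    dL_dX = dL_dY @ w      # 2. backward: (16, 32), gradient w.r.t. x
    dL_dW = x.t() @ dL_dY  # 3. backward: (32, 64), gradient w.r.t. w,
                           #    produced in the transposed layout that the
                           #    note attributes to float8 memory format needs

    # Default config: only the forward gemm's entry differs; its second field
    # is True (assumed here to toggle fast accumulation, not shown above).
    cfg = LinearMMConfig()
    assert cfg.y == ScaledMMConfig(False, True, False, False)
    assert cfg.dL_dX == cfg.dL_dW == ScaledMMConfig(False, False, False, False)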
@@ -194,15 +262,18 @@ class Float8Tensor(torch.Tensor):
    _data: torch.Tensor
    _scale: torch.Tensor
    _orig_dtype: torch.dtype
-     _mm_config: ScaledMMConfig
+     # TODO(before land): change this to _linear_mm_config, wanted to do that
+     # after initial review
+     _mm_config: LinearMMConfig
    __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"]

    def __new__(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        orig_dtype: torch.dtype,
-         mm_config: Optional[ScaledMMConfig],
+         mm_config: Optional[LinearMMConfig],
+         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
    ):
        assert (
            scale.numel() == 1
@@ -223,7 +294,8 @@ def __new__(
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype
-         self._mm_config = mm_config if mm_config is not None else ScaledMMConfig()
+         self._mm_config = mm_config if mm_config is not None else LinearMMConfig()
+         self._gemm_input_role = gemm_input_role

        return self
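Since each Float8Tensor now carries both a LinearMMConfig and a GemmInputRole, a float8 matmul override can recover the right per-gemm config from its two operands at call time. Below is a standalone sketch of that dispatch, not part of the diff; `FakeF8` is a hypothetical stand-in that carries the same two attributes, since constructing a real Float8Tensor needs more machinery.

    # Hypothetical stand-in for Float8Tensor, holding only the two
    # attributes added by this diff.
    class FakeF8:
        def __init__(self, role, cfg):
            self._gemm_input_role = role
            self._mm_config = cfg

    cfg = LinearMMConfig()
    a = FakeF8(GemmInputRole.X, cfg)  # activation operand
    b = FakeF8(GemmInputRole.W, cfg)  # weight operand

    # x @ w_t is the forward gemm, so the `y` config is selected
    picked = choose_scaled_mm_config(
        a._gemm_input_role, a._mm_config,
        b._gemm_input_role, b._mm_config,
    )
    assert picked == cfg.y

    # dL_dY @ w is the grad-input gemm, so the `dL_dX` config is selected
    g = FakeF8(GemmInputRole.DL_DY, cfg)
    assert choose_scaled_mm_config(
        g._gemm_input_role, g._mm_config,
        b._gemm_input_role, b._mm_config,
    ) == cfg.dL_dX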
@@ -257,7 +329,8 @@ def to_float8(
        scale: torch.Tensor,
        float8_dtype: torch.dtype,
        amax_buffer: Optional[torch.Tensor] = None,
-         mm_config: Optional[ScaledMMConfig] = None,
+         mm_config: Optional[LinearMMConfig] = None,
+         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
    ):
        """Converts a higher precision tensor to float8 in a differentiable way.
@@ -272,7 +345,7 @@ def to_float8(
            Float8Tensor: a float8 tensor
        """
        return ToFloat8ConstrFunc.apply(
-             tensor, scale, float8_dtype, amax_buffer, mm_config
+             tensor, scale, float8_dtype, amax_buffer, mm_config, gemm_input_role,
        )

    @classmethod