This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9e9581d

[wip] make all 3 gemms in Float8Linear configurable
Summary: not ready for review yet

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 302e7c4
Pull Request resolved: #315
1 parent de93990 commit 9e9581d

4 files changed, +112 -23 lines changed


float8_experimental/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,11 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig, GemmInputRole
 
 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals
 
-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole])
 
 __all__ = ["Float8Tensor", "Float8Linear"]
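The add_safe_globals registration above exists so that checkpoints containing Float8Tensor metadata can still be loaded with weights_only=True, which only unpickles allowlisted classes; GemmInputRole now has to be on that list as well. A minimal sketch of the load path this enables (the checkpoint file name is hypothetical):

import torch
import float8_experimental  # importing the package runs the add_safe_globals([...]) call above

# weights_only=True restricts unpickling to allowlisted classes, so Float8Tensor,
# ScaledMMConfig and GemmInputRole must all be registered before loading.
state = torch.load("float8_checkpoint.pt", weights_only=True)  # hypothetical file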

float8_experimental/float8_ops.py

Lines changed: 29 additions & 14 deletions
@@ -12,6 +12,7 @@
     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    choose_scaled_mm_config,
 )
 from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
 
@@ -125,10 +126,16 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
-    if a._mm_config.pad_inner_dim:
-        assert (
-            b._mm_config.pad_inner_dim
-        ), "Both mm configs must have pad_inner_dim set to True"
+    scaled_mm_config = choose_scaled_mm_config(
+        a._gemm_input_role, a._mm_config,
+        b._gemm_input_role, b._mm_config,
+    )
+
+    if scaled_mm_config.pad_inner_dim:
+        # TODO(before land): assert this when choosing config
+        # assert (
+        #     b._mm_config.pad_inner_dim
+        # ), "Both mm configs must have pad_inner_dim set to True"
         assert a._data.size(1) == b._data.size(
             0
         ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
@@ -155,10 +162,14 @@ def float8_mm(aten_op, args, kwargs=None):
     )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
-    a_mm_config: ScaledMMConfig = a._mm_config
-    b_mm_config: ScaledMMConfig = b._mm_config
-    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
-    if mm_config.emulate:
+    # a_mm_config: ScaledMMConfig = a._mm_config
+    # b_mm_config: ScaledMMConfig = b._mm_config
+    # mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    scaled_mm_config = choose_scaled_mm_config(
+        a._gemm_input_role, a._mm_config,
+        b._gemm_input_role, b._mm_config,
+    )
+    if scaled_mm_config.emulate:
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )
@@ -170,7 +181,7 @@ def float8_mm(aten_op, args, kwargs=None):
         output_dtype,
         output_scale=None,
         bias=None,
-        use_fast_accum=mm_config.use_fast_accum,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
     )
     return tensor_out
 
@@ -188,10 +199,14 @@ def float8_addmm(aten_op, args, kwargs=None):
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
-    a_mm_config: ScaledMMConfig = a._mm_config
-    b_mm_config: ScaledMMConfig = b._mm_config
-    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
-    if mm_config.emulate:
+    # a_mm_config: ScaledMMConfig = a._mm_config
+    # b_mm_config: ScaledMMConfig = b._mm_config
+    # mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    scaled_mm_config = choose_scaled_mm_config(
+        a._gemm_input_role, a._mm_config,
+        b._gemm_input_role, b._mm_config,
+    )
+    if scaled_mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )
@@ -204,7 +219,7 @@ def float8_addmm(aten_op, args, kwargs=None):
         output_dtype,
         output_scale=None,
         bias=bias,
-        use_fast_accum=mm_config.use_fast_accum,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
     )
     return tensor_out
float8_experimental/float8_tensor.py

Lines changed: 78 additions & 5 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 from collections import namedtuple
+import enum
 from typing import Dict, Optional
 
 import torch
@@ -18,6 +19,31 @@
 
 aten = torch.ops.aten
 
+#
+# A note on configuration of float8 logic in a linear
+# TODO(future): move all the configs to separate file
+#
+# There are three gemms in a forward + backward of a Linear layer:
+#
+# 1. x @ w_t = y (forward pass)
+# 2. dL_dY @ w = dL_dX (backward pass)
+# 3. x_t @ dL_dY = dL_dW (backward pass)
+#
+# In the formulas above, there are:
+# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
+#    - Note that dL_dY_t is implied because of memory format requirements
+#      of float8 gemms
+# B. three output tensors (y, dL_dX, dL_dW)
+#
+# We want each input tensor, gemm, and output tensor to be configurable.
+# The state of this configuration today is:
+#
+# i. pairs of input tensors (non-t and t variants) have their scaling
+#    configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
+# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
+# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
+#      to configure all three gemms, also not user facing
+
 
 # ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass.
 # emulate: whether to emulate the matmuls in fp32
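The three formulas in the note above can be sanity-checked with a small plain-precision sketch (shapes are made up for illustration; no float8 is involved here):

import torch

M, K, N = 4, 8, 16             # tokens, in_features, out_features (example values)
x = torch.randn(M, K)          # input
w = torch.randn(N, K)          # weight, stored the way nn.Linear stores it
dL_dY = torch.randn(M, N)      # gradient of the loss w.r.t. the output

y = x @ w.t()                  # 1. x @ w_t = y        -> (M, N), forward
dL_dX = dL_dY @ w              # 2. dL_dY @ w = dL_dX  -> (M, K), backward
dL_dW = x.t() @ dL_dY          # 3. x_t @ dL_dY = dL_dW -> (K, N), backward

assert y.shape == (M, N) and dL_dX.shape == (M, K) and dL_dW.shape == (K, N)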
@@ -30,6 +56,48 @@
     defaults=[False, False, False, False],
 )
 
+# The object below exists for convenience, to allow Float8Tensor to use
+# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
+# being called.
+LinearMMConfig = namedtuple(
+    "LinearMMConfig",
+    ["y", "dL_dX", "dL_dW"],
+    defaults=[
+        ScaledMMConfig(False, True, False, False),
+        ScaledMMConfig(False, False, False, False),
+        ScaledMMConfig(False, False, False, False),
+    ]
+)
+
+# Given a Float8Tensor, the enum below describes the expected role of this
+# tensor in the three gemms present in the fw + bw pass of a Linear layer.
+# This is used to choose the right config for a float8 gemm when the
+# gemm is performed.
+class GemmInputRole(enum.Enum):
+    X = "x"
+    W = "w"
+    DL_DY = "dL_dY"
+
+# choose which scaled_mm_config to use based on gemm inputs
+def choose_scaled_mm_config(
+    a_role: GemmInputRole,
+    a_linear_mm_config: LinearMMConfig,
+    b_role: GemmInputRole,
+    b_linear_mm_config: LinearMMConfig,
+):
+    if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
+        assert a_linear_mm_config.y == b_linear_mm_config.y
+        return a_linear_mm_config.y
+    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
+        assert a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
+        return a_linear_mm_config.dL_dX
+    else:
+        assert a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY, \
+            f"unexpected a_role {a_role} and b_role {b_role}"
+        assert a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
+        return a_linear_mm_config.dL_dW
+
 
 def merge_mm_configs(
     a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig
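A short usage sketch of the selection logic added above (the LinearMMConfig values simply repeat the namedtuple defaults; all names are taken from the diff):

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    choose_scaled_mm_config,
)

cfg = LinearMMConfig(
    y=ScaledMMConfig(False, True, False, False),       # config for the forward gemm
    dL_dX=ScaledMMConfig(False, False, False, False),  # config for the grad-input gemm
    dL_dW=ScaledMMConfig(False, False, False, False),  # config for the grad-weight gemm
)

# The role pair of the two gemm inputs determines which per-gemm config is used.
assert choose_scaled_mm_config(GemmInputRole.X, cfg, GemmInputRole.W, cfg) is cfg.y
assert choose_scaled_mm_config(GemmInputRole.DL_DY, cfg, GemmInputRole.W, cfg) is cfg.dL_dX
assert choose_scaled_mm_config(GemmInputRole.X, cfg, GemmInputRole.DL_DY, cfg) is cfg.dL_dW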
@@ -194,15 +262,18 @@ class Float8Tensor(torch.Tensor):
     _data: torch.Tensor
     _scale: torch.Tensor
     _orig_dtype: torch.dtype
-    _mm_config: ScaledMMConfig
+    # TODO(before land): change this to _linear_mm_config, wanted to do that after
+    # initial review
+    _mm_config: LinearMMConfig
     __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"]
 
     def __new__(
         cls,
         data: torch.Tensor,
         scale: torch.Tensor,
         orig_dtype: torch.dtype,
-        mm_config: Optional[ScaledMMConfig],
+        mm_config: Optional[LinearMMConfig],
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
     ):
         assert (
             scale.numel() == 1
@@ -223,7 +294,8 @@ def __new__(
         self._data = data
         self._scale = scale
         self._orig_dtype = orig_dtype
-        self._mm_config = mm_config if mm_config is not None else ScaledMMConfig()
+        self._mm_config = mm_config if mm_config is not None else LinearMMConfig()
+        self._gemm_input_role = gemm_input_role
 
         return self
 
@@ -257,7 +329,8 @@ def to_float8(
         scale: torch.Tensor,
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
-        mm_config: Optional[ScaledMMConfig] = None,
+        mm_config: Optional[LinearMMConfig] = None,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
     ):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
@@ -272,7 +345,7 @@ def to_float8(
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor, scale, float8_dtype, amax_buffer, mm_config
+            tensor, scale, float8_dtype, amax_buffer, mm_config, gemm_input_role,
        )
 
     @classmethod
test/test_base.py

Lines changed: 3 additions & 2 deletions
@@ -28,6 +28,7 @@
     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    GemmInputRole,
 )
 from float8_experimental.float8_utils import (
     compute_error,
@@ -438,9 +439,9 @@ def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
         fp8_dtype = e4m3_dtype
-        a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
+        a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X)
         b = Float8Tensor.to_float8(
-            x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
+            x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True), gemm_input_role=GemmInputRole.W
         )
         with pytest.raises(
             AssertionError,
