This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c7093c1

Merge pull request #22 from pytorch-labs/better_names
better buffer names
2 parents ada78ad + f39de39
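This merge renames the amax buffers on the float8 linear module from the float8_amax_* prefix to the shorter fp8_amax_*, and names the forward-pass buffers after the tensors they track (x, w, b, y) rather than their roles (in, weight, bias, out). The full mapping, as visible in the diff below:

    float8_amax_in     -> fp8_amax_x
    float8_amax_weight -> fp8_amax_w
    float8_amax_bias   -> fp8_amax_b
    float8_amax_out    -> fp8_amax_y
    float8_amax_dL_dX  -> fp8_amax_dL_dX
    float8_amax_dL_dW  -> fp8_amax_dL_dW
    float8_amax_dL_dY  -> fp8_amax_dL_dY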

File tree

float8_playground/float8_linear.py
tests/test.py

2 files changed: +47 −47 lines changed


float8_playground/float8_linear.py

Lines changed: 40 additions & 40 deletions
@@ -33,15 +33,15 @@ def forward(
         x_fp8,
         w_fp8,
         b_fp8,
-        float8_amax_out,
-        float8_amax_dL_dX,
-        float8_amax_dL_dW,
-        float8_amax_dL_dY,
+        fp8_amax_y,
+        fp8_amax_dL_dX,
+        fp8_amax_dL_dW,
+        fp8_amax_dL_dY,
         fw_amax_initialized,
         bw_amax_initialized,
     ):
         ctx.save_for_backward(
-            x_fp8, w_fp8, b_fp8, float8_amax_dL_dX, float8_amax_dL_dW, float8_amax_dL_dY,
+            x_fp8, w_fp8, b_fp8, fp8_amax_dL_dX, fp8_amax_dL_dW, fp8_amax_dL_dY,
             bw_amax_initialized)
         orig_shape = x_fp8._data.shape
         x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
@@ -52,22 +52,22 @@ def forward(
                 # calculate reference amax of output
                 with torch.no_grad():
                     ref_result = torch.addmm(b_fp8, x_fp8_reshaped, w_fp8.t())
-                    float8_amax_out.fill_(tensor_to_amax(ref_result))
+                    fp8_amax_y.fill_(tensor_to_amax(ref_result))
 
-            y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
+            y_scale = amax_to_scale(fp8_amax_y, torch.float8_e4m3fn)
             res_bits = addmm_float8(
-                b_fp8, x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                b_fp8, x_fp8_reshaped, w_fp8.t(), fp8_amax_y, y_scale,
                 torch.float8_e4m3fn)
         else:
             if not is_fw_amax_initialized:
                 # calculate reference amax of output
                 with torch.no_grad():
                     ref_result = torch.mm(x_fp8_reshaped, w_fp8.t())
-                    float8_amax_out.fill_(tensor_to_amax(ref_result))
+                    fp8_amax_y.fill_(tensor_to_amax(ref_result))
 
-            y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
+            y_scale = amax_to_scale(fp8_amax_y, torch.float8_e4m3fn)
             res_bits = mm_float8(
-                x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                x_fp8_reshaped, w_fp8.t(), fp8_amax_y, y_scale,
                 torch.float8_e4m3fn)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
 
@@ -77,18 +77,18 @@ def forward(
 
     @staticmethod
     def backward(ctx, go):
-        x_fp8, w_fp8, b_fp8, float8_amax_dL_dX, float8_amax_dL_dW, \
-            float8_amax_dL_dY, bw_amax_initialized = \
+        x_fp8, w_fp8, b_fp8, fp8_amax_dL_dX, fp8_amax_dL_dW, \
+            fp8_amax_dL_dY, bw_amax_initialized = \
             ctx.saved_tensors
 
         is_bw_amax_initialized = torch.any(bw_amax_initialized)
 
         if not isinstance(go, Float8Tensor):
             # TODO(future): switch to windowed delayed scaling
             if not is_bw_amax_initialized:
-                float8_amax_dL_dY.fill_(tensor_to_amax(go))
-            dL_dY_scale = amax_to_scale(float8_amax_dL_dY, torch.float8_e5m2)
-            float8_amax_dL_dY.fill_(tensor_to_amax(go))
+                fp8_amax_dL_dY.fill_(tensor_to_amax(go))
+            dL_dY_scale = amax_to_scale(fp8_amax_dL_dY, torch.float8_e5m2)
+            fp8_amax_dL_dY.fill_(tensor_to_amax(go))
             go_fp8 = Float8Tensor(
                 (go * dL_dY_scale).to(torch.float8_e5m2),
                 dL_dY_scale, go.dtype)
@@ -102,11 +102,11 @@ def backward(ctx, go):
             # calculate reference amax of output
             with torch.no_grad():
                 dL_dX_ref = torch.mm(go_fp8_reshaped, w_fp8)
-                float8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))
+                fp8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))
 
-        dL_dX_scale = amax_to_scale(float8_amax_dL_dX, torch.float8_e5m2)
+        dL_dX_scale = amax_to_scale(fp8_amax_dL_dX, torch.float8_e5m2)
         dL_dX_bits = mm_float8(
-            go_fp8_reshaped, w_fp8, float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
+            go_fp8_reshaped, w_fp8, fp8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
         dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
         dL_dX_fp8 = Float8Tensor(dL_dX_bits, dL_dX_scale, go_fp8._orig_dtype)
 
@@ -117,11 +117,11 @@ def backward(ctx, go):
             # calculate reference amax of output
             with torch.no_grad():
                 dL_dW_ref = torch.mm(x_fp8_reshaped.t(), go_fp8_reshaped).t()
-                float8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))
+                fp8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))
 
-        dL_dW_scale = amax_to_scale(float8_amax_dL_dW, torch.float8_e5m2)
+        dL_dW_scale = amax_to_scale(fp8_amax_dL_dW, torch.float8_e5m2)
         dL_dW_bits = mm_float8(
-            x_fp8_reshaped.t(), go_fp8_reshaped, float8_amax_dL_dW,
+            x_fp8_reshaped.t(), go_fp8_reshaped, fp8_amax_dL_dW,
             dL_dW_scale, torch.float8_e5m2).t()
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, dL_dW_scale, go_fp8._orig_dtype)
 
@@ -147,13 +147,13 @@ def __init__(self, *args, **kwargs):
         # scaling such as the mechanism described in
         # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8,
         # or PTQ calibration.
-        self.register_buffer('float8_amax_in', torch.tensor(E4M3_MAX_POS))
-        self.register_buffer('float8_amax_weight', torch.tensor(E4M3_MAX_POS))
-        self.register_buffer('float8_amax_bias', torch.tensor(E4M3_MAX_POS))
-        self.register_buffer('float8_amax_out', torch.tensor(E4M3_MAX_POS))
-        self.register_buffer('float8_amax_dL_dX', torch.tensor(E5M2_MAX_POS))
-        self.register_buffer('float8_amax_dL_dW', torch.tensor(E5M2_MAX_POS))
-        self.register_buffer('float8_amax_dL_dY', torch.tensor(E5M2_MAX_POS))
+        self.register_buffer('fp8_amax_x', torch.tensor(E4M3_MAX_POS))
+        self.register_buffer('fp8_amax_w', torch.tensor(E4M3_MAX_POS))
+        self.register_buffer('fp8_amax_b', torch.tensor(E4M3_MAX_POS))
+        self.register_buffer('fp8_amax_y', torch.tensor(E4M3_MAX_POS))
+        self.register_buffer('fp8_amax_dL_dX', torch.tensor(E5M2_MAX_POS))
+        self.register_buffer('fp8_amax_dL_dW', torch.tensor(E5M2_MAX_POS))
+        self.register_buffer('fp8_amax_dL_dY', torch.tensor(E5M2_MAX_POS))
         self.register_buffer('fw_amax_initialized', torch.tensor([0], dtype=torch.uint8))
         self.register_buffer('bw_amax_initialized', torch.tensor([0], dtype=torch.uint8))
 
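The helpers used throughout this diff, tensor_to_amax and amax_to_scale, live in the repository's util code and are not shown here. A minimal sketch of the usual delayed-scaling definitions, assuming amax is the max of absolute values and the scale maps that amax onto the target dtype's max representable value (448.0 for e4m3fn, 57344.0 for e5m2); the repository's actual implementations may differ in details such as clamping:

    import torch

    E4M3_MAX_POS = 448.0    # max representable value of torch.float8_e4m3fn
    E5M2_MAX_POS = 57344.0  # max representable value of torch.float8_e5m2

    def tensor_to_amax(x: torch.Tensor) -> torch.Tensor:
        # amax = largest absolute value observed in the tensor
        return torch.max(torch.abs(x))

    def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
        # scale chosen so that multiplying by it maps the observed range
        # onto the full range of the target float8 dtype
        max_pos = E4M3_MAX_POS if float8_dtype == torch.float8_e4m3fn else E5M2_MAX_POS
        return max_pos / torch.clamp(amax, min=1e-12)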
@@ -169,34 +169,34 @@ def forward(self, x):
 
             # TODO(future): switch to windowed delayed scaling
             if not is_fw_amax_initialized:
-                self.float8_amax_in.fill_(tensor_to_amax(x))
-            x_scale = amax_to_scale(self.float8_amax_in, torch.float8_e4m3fn)
-            self.float8_amax_in.fill_(tensor_to_amax(x))
+                self.fp8_amax_x.fill_(tensor_to_amax(x))
+            x_scale = amax_to_scale(self.fp8_amax_x, torch.float8_e4m3fn)
+            self.fp8_amax_x.fill_(tensor_to_amax(x))
 
             x_fp8 = Float8Tensor.to_float8(x, x_scale, torch.float8_e4m3fn)
         else:
             x_fp8 = x
 
         # TODO(future): switch to windowed delayed scaling
         if not is_fw_amax_initialized:
-            self.float8_amax_weight.fill_(tensor_to_amax(self.weight))
-        w_scale = amax_to_scale(self.float8_amax_weight, torch.float8_e4m3fn)
-        self.float8_amax_weight.fill_(tensor_to_amax(self.weight))
+            self.fp8_amax_w.fill_(tensor_to_amax(self.weight))
+        w_scale = amax_to_scale(self.fp8_amax_w, torch.float8_e4m3fn)
+        self.fp8_amax_w.fill_(tensor_to_amax(self.weight))
 
         w_fp8 = Float8Tensor.to_float8(self.weight, w_scale, torch.float8_e4m3fn)
         maybe_b_fp8 = None
         if self.bias is not None:
             # TODO(future): switch to windowed delayed scaling
             if not is_fw_amax_initialized:
-                self.float8_amax_bias.fill_(tensor_to_amax(self.bias))
-            b_scale = amax_to_scale(self.float8_amax_bias, torch.float8_e4m3fn)
-            self.float8_amax_bias.fill_(tensor_to_amax(self.bias))
+                self.fp8_amax_b.fill_(tensor_to_amax(self.bias))
+            b_scale = amax_to_scale(self.fp8_amax_b, torch.float8_e4m3fn)
+            self.fp8_amax_b.fill_(tensor_to_amax(self.bias))
 
             maybe_b_fp8 = Float8Tensor.to_float8(self.bias, b_scale, torch.float8_e4m3fn)
 
         y_fp8 = float8_linear.apply(
-            x_fp8, w_fp8, maybe_b_fp8, self.float8_amax_out, self.float8_amax_dL_dX,
-            self.float8_amax_dL_dW, self.float8_amax_dL_dY, self.fw_amax_initialized,
+            x_fp8, w_fp8, maybe_b_fp8, self.fp8_amax_y, self.fp8_amax_dL_dX,
+            self.fp8_amax_dL_dW, self.fp8_amax_dL_dY, self.fw_amax_initialized,
             self.bw_amax_initialized)
 
         if not is_fw_amax_initialized:
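Each hunk above repeats the same delayed-scaling pattern: on the first iteration the buffer is seeded from the live tensor, the scale is always computed from the buffered amax (the previous iteration's value), and the buffer is then refreshed for the next iteration. A condensed sketch of that control flow, using a hypothetical helper name to stand in for the inlined code and relying on the helpers sketched earlier:

    # Hypothetical condensation of the pattern above; `amax_buf` is one of
    # the fp8_amax_* buffers and `t` is the tensor it tracks.
    def delayed_scale(t, amax_buf, float8_dtype, is_initialized):
        if not is_initialized:
            # first iteration: no history yet, seed the buffer from the live tensor
            amax_buf.fill_(tensor_to_amax(t))
        # the scale comes from the stored (previous-iteration) amax
        scale = amax_to_scale(amax_buf, float8_dtype)
        # record this iteration's amax for use on the next iteration
        amax_buf.fill_(tensor_to_amax(t))
        return scale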

tests/test.py

Lines changed: 7 additions & 7 deletions
@@ -88,15 +88,15 @@ def _test_linear_impl(self, x, m_ref):
 
         # verify all of the amax buffers got updated
         buffer_names = [
-            'float8_amax_in',
-            'float8_amax_weight',
-            'float8_amax_out',
-            'float8_amax_dL_dX',
-            'float8_amax_dL_dW',
-            'float8_amax_dL_dY',
+            'fp8_amax_x',
+            'fp8_amax_w',
+            'fp8_amax_y',
+            'fp8_amax_dL_dX',
+            'fp8_amax_dL_dW',
+            'fp8_amax_dL_dY',
         ]
         if m_ref.bias is not None:
-            buffer_names.append('float8_amax_bias')
+            buffer_names.append('fp8_amax_b')
         for buffer_name in buffer_names:
             buffer_value = getattr(m_fp8, buffer_name)
             for init_val in (E4M3_MAX_POS, E5M2_MAX_POS):
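The hunk is truncated mid-loop by this view, so the assertion body under the for init_val line is not shown. Purely as an illustration of what such a check verifies, a hypothetical body might look like:

    # Hypothetical illustration only; the actual assertion body is not part
    # of this diff view. After a forward + backward pass, each amax buffer
    # should have moved off whichever max-pos value it was initialized with.
    for buffer_name in buffer_names:
        buffer_value = getattr(m_fp8, buffer_name)
        for init_val in (E4M3_MAX_POS, E5M2_MAX_POS):
            assert buffer_value.item() != init_val, \
                f'{buffer_name} still equals its init value'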
