This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ada78ad

Merge pull request #21 from pytorch-labs/python_api
add python api to wrap aten api
2 parents 5bfb6f9 + 8b6cc98

File tree

2 files changed: 49 additions & 18 deletions

float8_playground/float8_linear.py

Lines changed: 15 additions & 18 deletions
@@ -9,7 +9,10 @@
 
 import torch
 
-import float8_aten_api
+from float8_python_api import (
+    mm_float8,
+    addmm_float8,
+)
 
 from float8_utils import (
     tensor_to_amax,
@@ -52,11 +55,9 @@ def forward(
                 float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
-            res_bits = torch.ops.aten.addmm_float8(
-                b_fp8._data, b_fp8._scale,
-                x_fp8_reshaped._data, x_fp8._scale,
-                w_fp8.t()._data, w_fp8._scale,
-                float8_amax_out, y_scale, torch.float8_e4m3fn)
+            res_bits = addmm_float8(
+                b_fp8, x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                torch.float8_e4m3fn)
         else:
             if not is_fw_amax_initialized:
                 # calculate reference amax of output
@@ -65,10 +66,9 @@ def forward(
                 float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
-            res_bits = torch.ops.aten.mm_float8(
-                x_fp8_reshaped._data, x_fp8._scale,
-                w_fp8.t()._data, w_fp8._scale,
-                float8_amax_out, y_scale, torch.float8_e4m3fn)
+            res_bits = mm_float8(
+                x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                torch.float8_e4m3fn)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
 
         res = Float8Tensor(res_bits, y_scale, x_fp8._orig_dtype)
@@ -105,10 +105,8 @@ def backward(ctx, go):
                float8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))
 
        dL_dX_scale = amax_to_scale(float8_amax_dL_dX, torch.float8_e5m2)
-        dL_dX_bits = torch.ops.aten.mm_float8(
-            go_fp8_reshaped._data, go_fp8._scale,
-            w_fp8._data, w_fp8._scale,
-            float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
+        dL_dX_bits = mm_float8(
+            go_fp8_reshaped, w_fp8, float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
        dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
        dL_dX_fp8 = Float8Tensor(dL_dX_bits, dL_dX_scale, go_fp8._orig_dtype)
 
@@ -122,10 +120,9 @@ def backward(ctx, go):
                float8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))
 
        dL_dW_scale = amax_to_scale(float8_amax_dL_dW, torch.float8_e5m2)
-        dL_dW_bits = torch.ops.aten.mm_float8(
-            x_fp8_reshaped.t()._data, x_fp8._scale,
-            go_fp8_reshaped._data, go_fp8._scale,
-            float8_amax_dL_dW, dL_dW_scale, torch.float8_e5m2).t()
+        dL_dW_bits = mm_float8(
+            x_fp8_reshaped.t(), go_fp8_reshaped, float8_amax_dL_dW,
+            dL_dW_scale, torch.float8_e5m2).t()
        dL_dW_fp8 = Float8Tensor(dL_dW_bits, dL_dW_scale, go_fp8._orig_dtype)
 
        if not is_bw_amax_initialized:
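
A note on the scaling pattern visible in these hunks: each matmul takes a precomputed scale (`y_scale`, `dL_dX_scale`, `dL_dW_scale`) derived from a running amax buffer via `amax_to_scale` from `float8_utils`. As a rough sketch of what such a conversion does (this is an illustration, not the repo's implementation; it assumes the standard largest finite values of the two float8 formats, 448 for `e4m3fn` and 57344 for `e5m2`):

import torch

def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Choose a scale so that amax * scale lands at the top of the float8 range.
    if float8_dtype == torch.float8_e4m3fn:
        max_pos = 448.0    # largest finite e4m3fn value
    else:                  # torch.float8_e5m2
        max_pos = 57344.0  # largest finite e5m2 value
    # The clamp avoids division by zero before any amax has been observed.
    return max_pos / torch.clamp(amax, min=1e-12)

Note that forward outputs use `e4m3fn` (more mantissa bits) while gradients use `e5m2` (more exponent range), which matches the dtype arguments in the diff above.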
float8_playground/float8_python_api.py (new file)

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+"""
+This file defines the Python functions for float8 which expect inputs
+of class `Float8Tensor`. This is a thin wrapper on top of the aten API
+to simplify the product code.
+"""
+
+import torch
+import float8_aten_api
+
+def mm_float8(
+    x1,  # input 1
+    x2,  # input 2
+    amax3,  # output amax, updated inplace in this function
+    s3,  # output scale, precomputed
+    dtype3,  # output dtype
+):
+    return torch.ops.aten.mm_float8(
+        x1._data, x1._scale,
+        x2._data, x2._scale,
+        amax3, s3, dtype3)
+
+def addmm_float8(
+    inp1,  # addition term
+    x1,  # first mm term
+    x2,  # second mm term
+    amax3,  # output amax, updated inplace in this function
+    s3,  # output scale, precomputed
+    dtype3,  # output dtype
+):
+    return torch.ops.aten.addmm_float8(
+        inp1._data, inp1._scale,
+        x1._data, x1._scale,
+        x2._data, x2._scale,
+        amax3, s3, dtype3)
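
For orientation, a minimal sketch of a call site for the new wrapper, mirroring the forward pass of float8_linear.py above. The inputs `x_fp8` and `w_fp8` are assumed to be existing `Float8Tensor` instances (their construction is outside this commit), and `float8_aten_api` must be imported first so that the `torch.ops.aten` float8 ops are registered:

import torch

import float8_aten_api  # side effect: registers the aten float8 ops
from float8_python_api import mm_float8
from float8_utils import amax_to_scale

# x_fp8, w_fp8: Float8Tensor activations/weights, assumed to exist already
float8_amax_out = torch.tensor(1.0)  # running output amax; updated in place by the op
y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)

# One call instead of hand-unpacking two (_data, _scale) pairs.
res_bits = mm_float8(x_fp8, w_fp8.t(), float8_amax_out, y_scale, torch.float8_e4m3fn)
res = Float8Tensor(res_bits, y_scale, x_fp8._orig_dtype)  # Float8Tensor as defined in this repo

The split between `amax3` (updated in place by the kernel) and `s3` (precomputed from a previously observed amax) is what lets the scale be ready before the matmul runs instead of being derived from the current output.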
