
add python api to wrap aten api #21

Merged: 1 commit, Aug 7, 2023
float8_playground/float8_linear.py (33 changes: 15 additions & 18 deletions)
@@ -9,7 +9,10 @@
 
 import torch
 
-import float8_aten_api
+from float8_python_api import (
+    mm_float8,
+    addmm_float8,
+)
 
 from float8_utils import (
     tensor_to_amax,
@@ -52,11 +55,9 @@ def forward(
                     float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
-            res_bits = torch.ops.aten.addmm_float8(
-                b_fp8._data, b_fp8._scale,
-                x_fp8_reshaped._data, x_fp8._scale,
-                w_fp8.t()._data, w_fp8._scale,
-                float8_amax_out, y_scale, torch.float8_e4m3fn)
+            res_bits = addmm_float8(
+                b_fp8, x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                torch.float8_e4m3fn)
         else:
             if not is_fw_amax_initialized:
                 # calculate reference amax of output
@@ -65,10 +66,9 @@
                     float8_amax_out.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(float8_amax_out, torch.float8_e4m3fn)
-            res_bits = torch.ops.aten.mm_float8(
-                x_fp8_reshaped._data, x_fp8._scale,
-                w_fp8.t()._data, w_fp8._scale,
-                float8_amax_out, y_scale, torch.float8_e4m3fn)
+            res_bits = mm_float8(
+                x_fp8_reshaped, w_fp8.t(), float8_amax_out, y_scale,
+                torch.float8_e4m3fn)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
 
         res = Float8Tensor(res_bits, y_scale, x_fp8._orig_dtype)
@@ -105,10 +105,8 @@ def backward(ctx, go):
                 float8_amax_dL_dX.fill_(tensor_to_amax(dL_dX_ref))
 
         dL_dX_scale = amax_to_scale(float8_amax_dL_dX, torch.float8_e5m2)
-        dL_dX_bits = torch.ops.aten.mm_float8(
-            go_fp8_reshaped._data, go_fp8._scale,
-            w_fp8._data, w_fp8._scale,
-            float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
+        dL_dX_bits = mm_float8(
+            go_fp8_reshaped, w_fp8, float8_amax_dL_dX, dL_dX_scale, torch.float8_e5m2)
         dL_dX_bits = dL_dX_bits.reshape(*go_fp8_orig_shape[:-1], dL_dX_bits.shape[-1])
         dL_dX_fp8 = Float8Tensor(dL_dX_bits, dL_dX_scale, go_fp8._orig_dtype)
 
@@ -122,10 +120,9 @@ def backward(ctx, go):
                 float8_amax_dL_dW.fill_(tensor_to_amax(dL_dW_ref))
 
         dL_dW_scale = amax_to_scale(float8_amax_dL_dW, torch.float8_e5m2)
-        dL_dW_bits = torch.ops.aten.mm_float8(
-            x_fp8_reshaped.t()._data, x_fp8._scale,
-            go_fp8_reshaped._data, go_fp8._scale,
-            float8_amax_dL_dW, dL_dW_scale, torch.float8_e5m2).t()
+        dL_dW_bits = mm_float8(
+            x_fp8_reshaped.t(), go_fp8_reshaped, float8_amax_dL_dW,
+            dL_dW_scale, torch.float8_e5m2).t()
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, dL_dW_scale, go_fp8._orig_dtype)
 
         if not is_bw_amax_initialized:
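
For reference, the high-precision operations these float8 kernels mirror are plain torch.mm / torch.addmm, which is also how the ref_result calibration values above are computed before the amax buffers are filled. A minimal illustration in plain PyTorch (shapes made up for the example, no float8 involved):

import torch

x = torch.randn(16, 32)
w = torch.randn(64, 32)
b = torch.randn(64)

# torch.addmm(b, x, w.t()) computes x @ w.t() + b, i.e. a linear layer forward
assert torch.allclose(torch.addmm(b, x, w.t()), x @ w.t() + b)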
float8_playground/float8_python_api.py (34 changes: 34 additions & 0 deletions)
@@ -0,0 +1,34 @@
"""
This file defines the Python functions for float8 which expect inputs
of class `Float8Tensor`. This is a thin wrapper on top of the aten API
to simplify the product code.
"""

import torch
import float8_aten_api

def mm_float8(
x1, # input 1
x2, # input 2
amax3, # output amax, updated inplace in this function
s3, # output scale, precomputed
dtype3, # output dtype
):
return torch.ops.aten.mm_float8(
x1._data, x1._scale,
x2._data, x2._scale,
amax3, s3, dtype3)

def addmm_float8(
inp1, # addition term
x1, # first mm term
x2, # second mm term
amax3, # output aax, updated inplace in this function
s3, # output scale, precomputed
dtype3, # output dtype
):
return torch.ops.aten.addmm_float8(
inp1._data, inp1._scale,
x1._data, x1._scale,
x2._data, x2._scale,
amax3, s3, dtype3)
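
For context, a sketch of how the new wrappers might be called from user code. It assumes, as elsewhere in float8_playground, that a Float8Tensor exposes its raw bits as _data and its scale as _scale, and that importing float8_aten_api registers the aten.mm_float8 op as a side effect; the input tensors below are placeholders, not part of this PR:

import torch

import float8_aten_api  # side effect: registers the float8 aten ops
from float8_python_api import mm_float8

# x_fp8 and w_fp8 are assumed to be Float8Tensor instances created elsewhere;
# their construction is outside the scope of this sketch.
amax_out = torch.tensor(0.0)  # output amax, filled in place by the kernel
y_scale = torch.tensor(1.0)   # precomputed output scale (placeholder value)
res_bits = mm_float8(x_fp8, w_fp8.t(), amax_out, y_scale, torch.float8_e4m3fn)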