This repository was archived by the owner on Aug 7, 2024. It is now read-only.

remove fp8 bias #23

Merged
merged 1 commit on Aug 9, 2023

8 changes: 3 additions & 5 deletions float8_playground/float8_aten_api.py
@@ -34,8 +34,7 @@ def mm_float8(
 
 # TODO naming of these vars is weird
 def addmm_float8(
-    inp1, # bias data
-    inp_s1, # bias scale
+    inp1, # bias (in fp32/fp16/bf16, no fp8 support)
     m1, # input 1 data
     s1, # input 1 scale
     m2, # input 2 data
@@ -46,10 +45,9 @@ def addmm_float8(
 ):
     # naive implementation: dq -> op -> q
     # TODO(future): hook up to real kernel
-    inp1_fp32 = inp1.float() / inp_s1
     m1_fp32 = m1.float() / s1
     m2_fp32 = m2.float() / s2
-    m3_fp32 = torch.addmm(inp1_fp32, m1_fp32, m2_fp32)
+    m3_fp32 = torch.addmm(inp1, m1_fp32, m2_fp32)
 
     # TODO(future): switch to delayed scaling
     amax3.fill_(tensor_to_amax(m3_fp32))
@@ -70,6 +68,6 @@ def addmm_float8(
 lib.impl("mm_float8", mm_float8, "CPU")
 lib.impl("mm_float8", mm_float8, "CUDA")
 
-lib.define("addmm_float8(Tensor inp1, Tensor inp_s1, Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor amax3, Tensor s3, ScalarType dtype3) -> Tensor")
+lib.define("addmm_float8(Tensor inp1, Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor amax3, Tensor s3, ScalarType dtype3) -> Tensor")
 lib.impl("addmm_float8", addmm_float8, "CPU")
 lib.impl("addmm_float8", addmm_float8, "CUDA")
22 changes: 6 additions & 16 deletions float8_playground/float8_linear.py
@@ -32,7 +32,7 @@ def forward(
         ctx,
         x_fp8,
         w_fp8,
-        b_fp8,
+        b,
         fp8_amax_y,
         fp8_amax_dL_dX,
         fp8_amax_dL_dW,
@@ -41,22 +41,22 @@ def forward(
         bw_amax_initialized,
     ):
         ctx.save_for_backward(
-            x_fp8, w_fp8, b_fp8, fp8_amax_dL_dX, fp8_amax_dL_dW, fp8_amax_dL_dY,
+            x_fp8, w_fp8, b, fp8_amax_dL_dX, fp8_amax_dL_dW, fp8_amax_dL_dY,
             bw_amax_initialized)
         orig_shape = x_fp8._data.shape
         x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
         is_fw_amax_initialized = torch.any(fw_amax_initialized)
 
-        if b_fp8 is not None:
+        if b is not None:
             if not is_fw_amax_initialized:
                 # calculate reference amax of output
                 with torch.no_grad():
-                    ref_result = torch.addmm(b_fp8, x_fp8_reshaped, w_fp8.t())
+                    ref_result = torch.addmm(b, x_fp8_reshaped, w_fp8.t())
                     fp8_amax_y.fill_(tensor_to_amax(ref_result))
 
             y_scale = amax_to_scale(fp8_amax_y, torch.float8_e4m3fn)
             res_bits = addmm_float8(
-                b_fp8, x_fp8_reshaped, w_fp8.t(), fp8_amax_y, y_scale,
+                b, x_fp8_reshaped, w_fp8.t(), fp8_amax_y, y_scale,
                 torch.float8_e4m3fn)
         else:
             if not is_fw_amax_initialized:
@@ -149,7 +149,6 @@ def __init__(self, *args, **kwargs):
         # or PTQ calibration.
         self.register_buffer('fp8_amax_x', torch.tensor(E4M3_MAX_POS))
         self.register_buffer('fp8_amax_w', torch.tensor(E4M3_MAX_POS))
-        self.register_buffer('fp8_amax_b', torch.tensor(E4M3_MAX_POS))
         self.register_buffer('fp8_amax_y', torch.tensor(E4M3_MAX_POS))
         self.register_buffer('fp8_amax_dL_dX', torch.tensor(E5M2_MAX_POS))
         self.register_buffer('fp8_amax_dL_dW', torch.tensor(E5M2_MAX_POS))
@@ -184,18 +183,9 @@ def forward(self, x):
         self.fp8_amax_w.fill_(tensor_to_amax(self.weight))
 
         w_fp8 = Float8Tensor.to_float8(self.weight, w_scale, torch.float8_e4m3fn)
-        maybe_b_fp8 = None
-        if self.bias is not None:
-            # TODO(future): switch to windowed delayed scaling
-            if not is_fw_amax_initialized:
-                self.fp8_amax_b.fill_(tensor_to_amax(self.bias))
-            b_scale = amax_to_scale(self.fp8_amax_b, torch.float8_e4m3fn)
-            self.fp8_amax_b.fill_(tensor_to_amax(self.bias))
-
-            maybe_b_fp8 = Float8Tensor.to_float8(self.bias, b_scale, torch.float8_e4m3fn)
 
         y_fp8 = float8_linear.apply(
-            x_fp8, w_fp8, maybe_b_fp8, self.fp8_amax_y, self.fp8_amax_dL_dX,
+            x_fp8, w_fp8, self.bias, self.fp8_amax_y, self.fp8_amax_dL_dX,
             self.fp8_amax_dL_dW, self.fp8_amax_dL_dY, self.fw_amax_initialized,
             self.bw_amax_initialized)
 
4 changes: 2 additions & 2 deletions float8_playground/float8_python_api.py
@@ -20,15 +20,15 @@ def mm_float8(
         amax3, s3, dtype3)
 
 def addmm_float8(
-    inp1, # addition term
+    inp1, # addition term (in fp32/fp16/bf16, no fp8 support)
     x1, # first mm term
     x2, # second mm term
     amax3, # output aax, updated inplace in this function
     s3, # output scale, precomputed
     dtype3, # output dtype
 ):
     return torch.ops.aten.addmm_float8(
-        inp1._data, inp1._scale,
+        inp1,
         x1._data, x1._scale,
         x2._data, x2._scale,
         amax3, s3, dtype3)
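
For reference, a rough end-to-end sketch of the updated Python wrapper, seeding the output amax and scale from a reference addmm in the same spirit as float8_linear.py. The import paths and the simplified scale helper are assumptions; only addmm_float8 and Float8Tensor.to_float8 come from this PR.

import torch
from float8_playground.float8_python_api import addmm_float8  # assumed import path
from float8_playground.float8_tensor import Float8Tensor      # assumed import path

def naive_amax_to_scale(amax, dtype=torch.float8_e4m3fn):
    # simplified stand-in for the repo's amax_to_scale helper
    return torch.finfo(dtype).max / torch.clamp(amax, min=1e-12)

x, w, b = torch.randn(16, 32), torch.randn(8, 32), torch.randn(8)

x_fp8 = Float8Tensor.to_float8(x, naive_amax_to_scale(x.abs().max()), torch.float8_e4m3fn)
w_fp8 = Float8Tensor.to_float8(w, naive_amax_to_scale(w.abs().max()), torch.float8_e4m3fn)

# seed the output amax/scale from a reference result in high precision
y_ref = torch.addmm(b, x, w.t())
y_amax = y_ref.abs().max()
y_scale = naive_amax_to_scale(y_amax)

# the bias is passed through in high precision; the fp8 arguments are unchanged
y_bits = addmm_float8(b, x_fp8, w_fp8.t(), y_amax, y_scale, torch.float8_e4m3fn)
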
2 changes: 0 additions & 2 deletions tests/test.py
@@ -95,8 +95,6 @@ def _test_linear_impl(self, x, m_ref):
             'fp8_amax_dL_dW',
             'fp8_amax_dL_dY',
         ]
-        if m_ref.bias is not None:
-            buffer_names.append('fp8_amax_b')
         for buffer_name in buffer_names:
             buffer_value = getattr(m_fp8, buffer_name)
             for init_val in (E4M3_MAX_POS, E5M2_MAX_POS):