Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def bmm_fp8(
>>> import flashinfer
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
... finfo = torch.finfo(dtype)
... abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
... scale = finfo.max / abs_max
... min_val, max_val = x.aminmax()
... amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
... scale = finfo.max / amax
... x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
... return x_scl_sat.to(dtype), scale.float().reciprocal()
>>>
Expand Down
9 changes: 5 additions & 4 deletions python/tests/test_bmm_fp8.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest
import torch
import torch.nn.functional as F

from flashinfer import bmm_fp8


def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
scale = finfo.max / abs_max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()

Expand All @@ -32,9 +34,8 @@ def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)

reference = torch.bmm(input, mat2)

cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
assert cos_sim > 0.98
assert cos_sim > 0.99


if __name__ == "__main__":
Expand Down