WIP PR: add support for HPU in float8 base and compile test for torchao #2326

Draft: wants to merge 2 commits into base: main

87 changes: 45 additions & 42 deletions test/float8/test_base.py
@@ -16,12 +16,13 @@

from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_7,
is_sm_at_least_89,
is_sm_at_least_90,
get_device,
)

if not TORCH_VERSION_AT_LEAST_2_5:
if not TORCH_VERSION_AT_LEAST_2_7:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
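
For reference, here is a minimal sketch of what the newly imported `get_device()` helper could look like. This is an assumption inferred from how it is used below (as a `device=` argument and as the `torch.autocast` device type), not necessarily torchao's actual implementation, and it assumes PyTorch >= 2.6 so that `torch.accelerator` is available.

```python
import torch


def get_device() -> str:
    """Return the device-type string of the current accelerator ("cuda", "hpu", "xpu", ...), or "cpu"."""
    if torch.accelerator.is_available():
        # current_accelerator() returns a torch.device; .type gives the bare string
        return torch.accelerator.current_accelerator().type
    return "cpu"
```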


@@ -239,11 +240,10 @@ def test_axiswise_reshape(self):
(ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
@unittest.skipIf(not torch.accelerator.is_available() or (torch.cuda.is_available() and not is_sm_at_least_90()), "Accelerator not available, or CUDA capability < 9.0")
def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
a = torch.randn(*a_shape, dtype=torch.bfloat16, device=get_device())
b = torch.randn(64, 32, dtype=torch.bfloat16, device=get_device())

linear_mm_config = LinearMMConfig()

@@ -272,7 +272,7 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_fp8_dtype(
self,
):
@@ -337,7 +337,7 @@ def _test_linear_impl(
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
@pytest.mark.parametrize("use_ac", [False, True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_linear_from_config_params(
self,
x_shape,
@@ -349,8 +349,8 @@ def test_linear_from_config_params(
linear_bias: bool,
use_ac: bool,
):
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
x = torch.randn(*x_shape, device=get_device(), dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device=get_device(), dtype=linear_dtype)

config = get_test_float8_linear_config(
scaling_type_input,
@@ -379,23 +379,23 @@
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_bias", [True, False])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
@skip_if_rocm("ROCm enablement in progress")
def test_linear_from_recipe(
self,
recipe_name,
x_shape,
linear_bias: bool,
):
if torch.cuda.get_device_capability() < (9, 0):
if torch.cuda.is_available() and not is_sm_at_least_90():
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()

linear_dtype = torch.bfloat16
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
x = torch.randn(*x_shape, device=get_device(), dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device=get_device(), dtype=linear_dtype)
config = Float8LinearConfig.from_recipe_name(recipe_name)
self._test_linear_impl(
x,
@@ -409,32 +409,32 @@ def test_linear_from_recipe(
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_autocast_outputs(
self,
emulate: bool,
linear_dtype: torch.dtype,
):
m_ref = nn.Sequential(
nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
nn.Linear(32, 32, device=get_device(), dtype=linear_dtype),
nn.Linear(32, 32, device=get_device(), dtype=linear_dtype),
)
config = Float8LinearConfig(
emulate=emulate,
)
m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
x = torch.randn(16, 32, device=get_device(), dtype=linear_dtype)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
with torch.autocast(get_device(), dtype=torch.half):
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
with torch.autocast(get_device(), dtype=torch.bfloat16):
y = m(x)
assert y.dtype == torch.bfloat16, (
f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
@@ -446,26 +446,26 @@ def test_autocast_outputs(
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = nn.Linear(32, 16, device=get_device(), dtype=linear_dtype)
config = Float8LinearConfig(emulate=emulate)
m = Float8Linear.from_float(copy.deepcopy(m), config)

# Cast the module to dtype
m = m.to(dtype=linear_dtype)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
x = torch.randn(16, 32, device=get_device(), dtype=linear_dtype)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
with torch.autocast(get_device(), dtype=torch.half):
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
with torch.autocast(get_device(), dtype=torch.bfloat16):
y = m(x)
assert y.dtype == torch.bfloat16, (
f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
@@ -483,18 +483,18 @@ def test_repr(self):
s = m.__repr__()
assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s

@unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
@unittest.skipIf(not torch.accelerator.is_available() or (torch.cuda.is_available() and not is_sm_at_least_89()), "Accelerator not available, or CUDA capability < 8.9")
def test_inference_mode(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
x = torch.randn(32, 32, device=get_device())
m = nn.Sequential(nn.Linear(32, 32)).to(device=get_device())
m = convert_to_float8_training(m)
with torch.inference_mode(mode=True):
m(x)

@unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
@unittest.skipIf(not torch.accelerator.is_available() or (torch.cuda.is_available() and not is_sm_at_least_89()), "Accelerator not available, or CUDA capability < 8.9")
def test_quantize(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
x = torch.randn(32, 32, device=get_device())
m = nn.Sequential(nn.Linear(32, 32)).to(device=get_device())
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
from torchao.quantization.quant_api import float8_weight_only, quantize_
@@ -509,8 +509,9 @@ def test_quantize(self):

class TestScaledMM:
@unittest.skipIf(
not torch.accelerator.is_available()
or (torch.cuda.is_available() and not is_sm_at_least_89()),
"CUDA not available",
"Accelerator not available or If CUDA, it requires CUDA capability >= 8.9",
)
@pytest.mark.parametrize(
"base_dtype", [torch.float16, torch.bfloat16, torch.float32]
@@ -522,8 +523,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
output_dtype = base_dtype
compare_type = torch.float32

a = torch.randn(16, 16, device="cuda", dtype=base_dtype)
b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
a = torch.randn(16, 16, device=get_device(), dtype=base_dtype)
b = torch.randn(32, 16, device=get_device(), dtype=base_dtype).t()

a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()
@@ -554,10 +555,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available() or (torch.cuda.is_available() and not is_sm_at_least_89()), "Accelerator not available, or CUDA capability < 8.9")
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
x_fp32 = torch.randn(16, 16, device=get_device())
x_scale = torch.tensor(1.0, device=get_device())
fp8_dtype = e4m3_dtype
linear_config_a = LinearMMConfig(
ScaledMMConfig(False, True, False, False),
@@ -590,8 +591,9 @@ def test_different_configs_error(self):
a @ b

@unittest.skipIf(
not torch.accelerator.is_available()
or (torch.cuda.is_available() and not is_sm_at_least_89()),
"CUDA not available",
"Accelerator not available or If CUDA, it requires CUDA capability >= 8.9",
)
@pytest.mark.parametrize(
"base_dtype", [torch.float16, torch.bfloat16, torch.float32]
@@ -602,8 +604,8 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
input_dtype = e4m3_dtype
compare_type = torch.float32

a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
a = torch.randn(16, 41, device=get_device(), dtype=base_dtype)
b = torch.randn(41, 128, device=get_device(), dtype=base_dtype)

a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()
@@ -681,7 +683,7 @@ class TestNumerics:
torch.float8_e5m2fnuz,
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_small_amax_float16(self, float8_dtype):
# If we calculate scale naively with FP8_MAX_POS / amax,
# the result may not be representable in fp16. Verify that
@@ -700,7 +702,7 @@ def test_small_amax_float16(self, float8_dtype):
FP16_MAX_POS = torch.finfo(torch.float16).max

target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
x = torch.tensor([target_amax], dtype=torch.float16, device=get_device())
scale = tensor_to_scale(x, float8_dtype)
assert not torch.any(torch.isinf(scale))

@@ -834,3 +836,4 @@ def test_fp8_tensor_statistics(self):

if __name__ == "__main__":
pytest.main([__file__])
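
Below is a minimal, self-contained sketch of the device-agnostic pattern this diff applies throughout test_base.py: an accelerator-aware skip guard, device selection via a `get_device()`-style helper, and autocast with an explicit dtype. It assumes PyTorch >= 2.6 (for `torch.accelerator`); the helper here is illustrative and may differ from torchao's actual `get_device()`.

```python
import unittest

import pytest
import torch
import torch.nn as nn


def get_device() -> str:
    # Device-type string of the current accelerator, or "cpu" as a fallback.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    return "cpu"


class TestDeviceAgnosticPattern:
    @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
    def test_autocast_bf16(self):
        m = nn.Linear(32, 32, device=get_device(), dtype=torch.float32)
        x = torch.randn(16, 32, device=get_device(), dtype=torch.float32)
        # torch.autocast takes a device-type string ("cuda", "hpu", "xpu", ...),
        # so the helper must return the bare type rather than an indexed device.
        with torch.autocast(get_device(), dtype=torch.bfloat16):
            y = m(x)
        assert y.dtype == torch.bfloat16


if __name__ == "__main__":
    pytest.main([__file__])
```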
