[reland][ROCm] use dataclass for fnuz type setting #1150

Merged 2 commits on Oct 23, 2024
Changes from all commits
28 changes: 14 additions & 14 deletions test/float8/test_base.py
@@ -24,8 +24,8 @@


from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
+    CastConfig,
+    Float8LinearConfig,
ScalingGranularity,
ScalingType,
Float8LinearRecipeName,
@@ -109,15 +109,15 @@ def test_split_cat(self):

def test_index_put(self):
a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
+        scale_a = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

index = torch.randint(0, 15, (16,), dtype=torch.long)

b = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
-        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
-        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
+        scale_b = tensor_to_scale(b, e4m3_dtype)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)

with pytest.raises(AssertionError):
b[index] = fp8_a
@@ -127,8 +127,8 @@ def test_index_put(self):

def test_copy_(self):
a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
+        scale_a = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
@@ -137,7 +137,7 @@ def test_copy_(self):
fp8_a.copy_(b) # Should fail

fp8_b = Float8Tensor(
-            torch.empty(16, dtype=torch.float8_e4m3fn),
+            torch.empty(16, dtype=e4m3_dtype),
scale_a,
torch.bfloat16,
fp8_a._linear_mm_config,
@@ -332,11 +332,11 @@ def _test_linear_impl(
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize(
-    "scaling_type_input",
+    "scaling_type_input",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
-    "scaling_type_weight",
+    "scaling_type_weight",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
@@ -377,7 +377,7 @@ def test_linear_from_config_params(
# to combine with the main testing function.
# TODO(future PR): make this cleaner.
@pytest.mark.parametrize(
-    "recipe_name",
+    "recipe_name",
[Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@@ -610,7 +610,7 @@ def test_different_configs_error(self):
@pytest.mark.parametrize("use_fast_accum", [True, False])
def test_pad_inner_dim(self, base_dtype, use_fast_accum):
torch.manual_seed(42)
-        input_dtype = torch.float8_e4m3fn
+        input_dtype = e4m3_dtype
compare_type = torch.float32

a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
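The test changes above consistently replace the hard-coded torch.float8_e4m3fn with the e4m3_dtype alias, so the same tests exercise the fnuz variants on ROCm without modification. A minimal sketch of what the alias resolves to, assuming it is imported from torchao.float8.float8_utils as updated at the end of this diff:

    # Illustrative only, not part of the diff: the aliases pick the fn or fnuz
    # float8 dtypes based on the platform detected at import time.
    from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype

    print(e4m3_dtype)  # torch.float8_e4m3fnuz on MI300-class ROCm, torch.float8_e4m3fn otherwise
    print(e5m2_dtype)  # torch.float8_e5m2fnuz on MI300-class ROCm, torch.float8_e5m2 otherwise
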
14 changes: 7 additions & 7 deletions test/float8/test_compile.py
@@ -20,9 +20,9 @@
import torch
import torch.nn as nn
from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
-    ScalingType,
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
Float8LinearRecipeName,
recipe_name_to_linear_config,
)
@@ -77,7 +77,7 @@ def _test_compile_base(
y_fp8.sum().backward()
y_ref = m_ref(x_ref)
y_ref.sum().backward()
-    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
+    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
# tolerance
torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2)
torch.testing.assert_close(
@@ -199,7 +199,7 @@ def test_inductor_from_config_params(
# to combine with the main testing function.
# TODO(future PR): make this cleaner.
@pytest.mark.parametrize(
-    "recipe_name",
+    "recipe_name",
[Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
)
@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
@@ -412,14 +412,14 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
)
float8_eager = hp_tensor_to_float8_dynamic(
hp_tensor1,
-        torch.float8_e4m3fn,
+        e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
torch._dynamo.reset()
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
hp_tensor2,
-        torch.float8_e4m3fn,
+        e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
43 changes: 30 additions & 13 deletions torchao/float8/config.py
@@ -96,6 +96,29 @@ def __post_init__(self):
), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm only supports fnuz variants.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip and torch.cuda.is_available():
+            prop = torch.cuda.get_device_properties(0)
+            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
+                self.e4m3_dtype = torch.float8_e4m3fnuz
+                self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
@dataclass(frozen=True)
class Float8GemmConfig:
"""
@@ -118,11 +141,11 @@ class Float8LinearConfig:
# Per-tensor configuration for casting of `input`, `weight`, `grad_output`
# for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`.
#
-    # Note:
-    # 1. if `cast_config_input_for_grad_weight` is None, then
+    # Note:
+    # 1. if `cast_config_input_for_grad_weight` is None, then
# `cast_config_input` is used for scaling `input` for both gemms that
-    #   use `input.
-    # 2. if `cast_config_input_for_grad_weight` is specified, then
+    #   use `input.
+    # 2. if `cast_config_input_for_grad_weight` is specified, then
# a. `cast_config_input` is used for scaling `input` for the gemm that calculates
# `output`
# b. `cast_config_input_for_grad_weight` is used for scaling `input` for
@@ -240,12 +263,6 @@ def __post_init__(self):
f"incompatible operand precision for {gemm_name}"


-# If True, use 'fnuz' float8 types for calculations.
-# Currently, ROCm only supports fnuz variants.
-# TODO(future PR): move this to Float8LinearConfig
-use_fnuz_dtype = False
-
-
# Pre-made recipes for common configurations
# TODO(future PR): go through a round of design on this, and eventually expose
# as a top level public API.
@@ -272,7 +289,7 @@ def recipe_name_to_linear_config(
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
# fast with `use_fast_accum=True`. Note that rowwise scaling is more
# accurate than tensorwise scaling, so the overall impact on accuracy
@@ -300,8 +317,8 @@ def recipe_name_to_linear_config(
#
# key characteristics:
# * increased accuracy for grad_weight
-    # * `input`, `weight` and `grad_output` now only need to be scaled
-    #   axiswise across a single dim compared to vanilla all-axiswise,
+    # * `input`, `weight` and `grad_output` now only need to be scaled
+    #   axiswise across a single dim compared to vanilla all-axiswise,
# which is more amenable to fast kernels

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
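The Float8TypeConfig dataclass added above replaces the module-level use_fnuz_dtype flag removed in this file. A minimal usage sketch, mirroring how float8_utils.py consumes it below; on MI300-class parts gcnArchName typically reads like "gfx942:sramecc+:xnack-", which is why the code compares split(":")[0] against the arch tuple:

    from torchao.float8.config import Float8TypeConfig

    # Instantiating the dataclass runs __post_init__, which switches to the
    # fnuz dtypes when a ROCm MI300-class GPU (gfx940/gfx941/gfx942) is detected.
    type_config = Float8TypeConfig()
    e4m3_dtype = type_config.e4m3_dtype  # torch.float8_e4m3fn or torch.float8_e4m3fnuz
    e5m2_dtype = type_config.e5m2_dtype  # torch.float8_e5m2 or torch.float8_e5m2fnuz
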
8 changes: 4 additions & 4 deletions torchao/float8/float8_utils.py
@@ -9,8 +9,7 @@
import torch
import torch.distributed as dist

-import torchao.float8.config as config
-from torchao.float8.config import ScalingGranularity
+from torchao.float8.config import Float8TypeConfig, ScalingGranularity

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -29,8 +28,8 @@


# User defined type for using the individual F8 type based on config
-e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
-e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype


@torch.no_grad()