Commit b8c16fb

[Float8] Fix serialization of dynamic activation fp8
stack-info: PR: #838, branch: drisspg/stack/12
1 parent aac19a1 commit b8c16fb
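
Why this fixes serialization: `torch.save` pickles the model's state dict, and with dynamic activation quantization each weight is wrapped in a tensor subclass that carries its `input_quant_func` with it. Before this commit that function was a closure defined inside `float8_dynamic_activation_float8_weight`, and pickle can only serialize functions it can look up by module-level name. Moving the body to the module-level `_input_quant_func_dyanmic_fp8` and binding its arguments with `functools.partial` makes the callable picklable. A standalone sketch of the failure mode (illustrative names, not torchao code):

```python
import io
import pickle
from functools import partial


def quant_func(x, granularity, dtype):
    # Module-level: pickle stores a reference (module + qualified name).
    return (x, granularity, dtype)


def make_closure(granularity, dtype):
    def input_quant_func(x):
        # Locally defined: no importable qualified name, so pickle fails.
        return (x, granularity, dtype)

    return input_quant_func


try:
    pickle.dumps(make_closure("per_tensor", "e4m3"))
except (pickle.PicklingError, AttributeError) as e:
    print(f"closure does not pickle: {e}")

# The same callable expressed as a partial over a module-level function
# round-trips through pickle without issue.
f = partial(quant_func, granularity="per_tensor", dtype="e4m3")
buf = io.BytesIO()
pickle.dump(f, buf)
buf.seek(0)
assert pickle.load(buf)(1) == f(1)
```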

2 files changed: 98 additions, 28 deletions
test/dtypes/test_affine_quantized_float.py

Lines changed: 62 additions & 0 deletions
@@ -24,6 +24,7 @@
 from functools import partial
 from typing import Tuple
 from contextlib import nullcontext
+import io
 
 
 random.seed(0)
@@ -142,6 +143,67 @@ def test_per_row_with_float32(self):
             model, float8_dynamic_activation_float8_weight(granularity=PerRow())
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    def test_serialization(self, mode: str):
+        # Create and quantize the model
+        model = ToyLinearModel(16, 32).to(device="cuda")
+        if mode == "dynamic":
+            factory = float8_dynamic_activation_float8_weight()
+        else:
+            factory = float8_weight_only()
+        quantize_(model, factory)
+
+        # Save the state dict to an in-memory buffer
+        buffer = io.BytesIO()
+        torch.save(model.state_dict(), buffer)
+
+        # Reset the buffer position
+        buffer.seek(0)
+
+        # Load the state dict from the buffer
+        loaded_state_dict = torch.load(buffer)
+
+        # Create a new model and load the state dict
+        with torch.device("meta"):
+            new_model = ToyLinearModel(16, 32)
+        new_model.load_state_dict(loaded_state_dict, assign=True)
+
+        # Compare the original and loaded models
+        if mode == "weight-only":
+            model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+            model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+        else:
+            model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+            model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+        assert torch.allclose(model_weight_1, new_model_weight_1)
+        assert torch.allclose(model_weight_2, new_model_weight_2)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
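Two details of the test worth noting: the new model is materialized on the meta device (so no real storage is allocated), and `load_state_dict(..., assign=True)` adopts the deserialized tensors outright instead of copying into existing parameters, which preserves the quantized tensor subclass. A sketch of the same round trip in user code, assuming the APIs touched by this PR (`MyModel` is a stand-in for any module):

```python
import io

import torch
import torch.nn as nn
from torchao.quantization.quant_api import (
    float8_dynamic_activation_float8_weight,
    quantize_,
)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 32, bias=False)

    def forward(self, x):
        return self.linear(x)


model = MyModel().to("cuda")
quantize_(model, float8_dynamic_activation_float8_weight())

# Serialize to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Empty skeleton on the meta device, then adopt the loaded tensors
# wholesale; an in-place copy would lose the subclass wrappers.
with torch.device("meta"):
    restored = MyModel()
restored.load_state_dict(torch.load(buffer), assign=True)
```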
torchao/quantization/quant_api.py

Lines changed: 36 additions & 28 deletions
@@ -25,6 +25,7 @@
 from torchao.dtypes.uintx.Uintx import UintxLayoutType
 from torchao.dtypes import (
     to_affine_quantized_intx,
+    to_affine_quantized_floatx,
     TensorCoreTiledLayoutType,
     PlainLayoutType,
     AffineQuantizedTensor,
@@ -670,6 +671,35 @@ def _validate_granularity(
     else:
         raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.")
 
+def _get_block_size(x: torch.Tensor, granularity: _fp8_granularities):
+    if isinstance(granularity, PerTensor):
+        return x.shape
+    elif isinstance(granularity, PerRow):
+        return (1,) * (x.dim() - 1) + (x.shape[-1],)
+    else:
+        raise ValueError(f"Unsupported granularity: {granularity}")
+
+
+def _input_quant_func_dyanmic_fp8(
+    x: torch.Tensor,
+    activation_granularity: _fp8_granularities,
+    activation_dtype: torch.dtype,
+):
+    if isinstance(activation_granularity, PerRow):
+        assert (
+            x.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input activation"
+
+    block_size = _get_block_size(x, activation_granularity)
+    activation = to_affine_quantized_floatx(
+        input_float=x,
+        block_size=block_size,
+        target_dtype=activation_dtype,
+        scale_dtype=torch.float32,
+        layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
+    )
+    return activation
+
 
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
@@ -693,28 +723,18 @@
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
-    from torchao.dtypes import to_affine_quantized_floatx
-
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
 
     activation_granularity, weight_granularity = _validate_granularity(granularity)
 
-    def get_block_size(x: torch.Tensor, granularity: _fp8_granularities):
-        if isinstance(granularity, PerTensor):
-            return x.shape
-        elif isinstance(granularity, PerRow):
-            return (1,) * (x.dim() - 1) + (x.shape[-1],)
-        else:
-            raise ValueError(f"Unsupported granularity: {granularity}")
-
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         if isinstance(weight_granularity, PerRow):
             assert (
                 weight.dtype == torch.bfloat16
             ), "PerRow quantization only works for bfloat16 precision input weight"
 
-        block_size = get_block_size(weight, weight_granularity)
+        block_size = _get_block_size(weight, weight_granularity)
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
             block_size=block_size,
@@ -723,23 +743,11 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
             layout_type=Float8LayoutType(mm_config=mm_config),
         )
 
-        def input_quant_func(x: torch.Tensor):
-            if isinstance(activation_granularity, PerRow):
-                assert (
-                    x.dtype == torch.bfloat16
-                ), "PerRow quantization only works for bfloat16 precision input activation"
-
-            block_size = get_block_size(x, activation_granularity)
-            activation = to_affine_quantized_floatx(
-                input_float=x,
-                block_size=block_size,
-                target_dtype=activation_dtype,
-                scale_dtype=torch.float32,
-                layout_type=Float8LayoutType(
-                    mm_config=None
-                ),  # Config is stored on weight
-            )
-            return activation
+        input_quant_func = partial(
+            _input_quant_func_dyanmic_fp8,
+            activation_granularity=activation_granularity,
+            activation_dtype=activation_dtype,
+        )
 
         quantized_weight = to_linear_activation_quantized(
             quantized_weight, input_quant_func
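
A note on the shared helper: the block size returned by `_get_block_size` is the shape over which a single scale is computed. `PerTensor` returns the full tensor shape, i.e. one scale for the whole tensor, while `PerRow` collapses every leading dimension to 1 and keeps the last, i.e. one scale per row. A standalone sketch of that logic, with strings standing in for the granularity types:

```python
import torch


def get_block_size(x: torch.Tensor, granularity: str):
    if granularity == "per_tensor":
        # One block covering the whole tensor -> a single scale.
        return tuple(x.shape)
    elif granularity == "per_row":
        # Block spans only the last dim -> one scale per row.
        return (1,) * (x.dim() - 1) + (x.shape[-1],)
    raise ValueError(f"Unsupported granularity: {granularity}")


x = torch.empty(4, 16)
print(get_block_size(x, "per_tensor"))  # (4, 16)
print(get_block_size(x, "per_row"))     # (1, 16)
```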
