Commit 3a7bcb2

[Float8] Fix serialization of dynamic activation fp8
1 parent aac19a1 · commit 3a7bcb2

File tree

2 files changed: +88 -17 lines changed

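Background on the failure this commit addresses: torch.save pickles the model's state dict, and pickle cannot serialize a function that is defined inside another function. Before this change, the dynamic-activation path stored such a nested input_quant_func on the quantized weight tensor, so saving a dynamically quantized model failed; hoisting that function to module level and binding its arguments with functools.partial keeps it picklable. Below is a minimal sketch of the underlying pickle behavior, independent of torchao (the names scale, make_closure, and scale_local are illustrative, not from the codebase):

import functools
import pickle


def scale(x, factor):
    # A module-level function: pickle stores a reference to it by qualified name.
    return x * factor


def make_closure(factor):
    # A nested function has no importable qualified name, so pickle rejects it.
    def scale_local(x):
        return x * factor

    return scale_local


# A partial over a module-level function round-trips fine.
bound = functools.partial(scale, factor=2.0)
restored = pickle.loads(pickle.dumps(bound))
print(restored(3.0))  # 6.0

# The closure fails to pickle.
try:
    pickle.dumps(make_closure(2.0))
except (pickle.PicklingError, AttributeError) as exc:
    print(f"closure is not picklable: {exc}")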

test/dtypes/test_affine_quantized_float.py

Lines changed: 62 additions & 0 deletions
@@ -24,6 +24,7 @@
 from functools import partial
 from typing import Tuple
 from contextlib import nullcontext
+import io
 
 
 random.seed(0)
@@ -142,6 +143,67 @@ def test_per_row_with_float32(self):
             model, float8_dynamic_activation_float8_weight(granularity=PerRow())
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    def test_serialization(self, mode: str):
+        # Create and quantize the model
+        model = ToyLinearModel(16, 32).to(device="cuda")
+        if mode == "dynamic":
+            factory = float8_dynamic_activation_float8_weight()
+        else:
+            factory = float8_weight_only()
+        quantize_(model, factory)
+
+        # Save the state dict to an in-memory buffer
+        buffer = io.BytesIO()
+        torch.save(model.state_dict(), buffer)
+
+        # Reset the buffer position
+        buffer.seek(0)
+
+        # Load the state dict from the buffer
+        loaded_state_dict = torch.load(buffer)
+
+        # Create a new model and load the state dict
+        with torch.device("meta"):
+            new_model = ToyLinearModel(16, 32)
+        new_model.load_state_dict(loaded_state_dict, assign=True)
+
+        # Compare the original and loaded models
+        if mode == "weight-only":
+            model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+            model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+        else:
+            model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+            model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+        assert torch.allclose(model_weight_1, new_model_weight_1)
+        assert torch.allclose(model_weight_2, new_model_weight_2)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

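Two details of the test worth noting: building new_model under torch.device("meta") allocates no real storage, and load_state_dict(..., assign=True) swaps the deserialized tensors in directly rather than copying into the existing parameters, which is what lets the quantized tensor subclasses survive the round trip intact. On newer PyTorch releases, where torch.load defaults to weights_only=True, loading these tensor subclasses would additionally require weights_only=False or allowlisting them via torch.serialization.add_safe_globals.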
torchao/quantization/quant_api.py

Lines changed: 26 additions & 17 deletions
@@ -671,6 +671,27 @@ def _validate_granularity(
         raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.")
 
 
+def _input_quant_func_dyanmic_fp8(
+    x: torch.Tensor,
+    activation_granularity: _fp8_granularities,
+    activation_dtype: torch.dtype,
+):
+    if isinstance(activation_granularity, PerRow):
+        assert (
+            x.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input activation"
+
+    block_size = get_block_size(x, activation_granularity)
+    activation = to_affine_quantized_floatx(
+        input_float=x,
+        block_size=block_size,
+        target_dtype=activation_dtype,
+        scale_dtype=torch.float32,
+        layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
+    )
+    return activation
+
+
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
     weight_dtype: torch.dtype = torch.float8_e4m3fn,
@@ -723,23 +744,11 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
             layout_type=Float8LayoutType(mm_config=mm_config),
         )
 
-        def input_quant_func(x: torch.Tensor):
-            if isinstance(activation_granularity, PerRow):
-                assert (
-                    x.dtype == torch.bfloat16
-                ), "PerRow quantization only works for bfloat16 precision input activation"
-
-            block_size = get_block_size(x, activation_granularity)
-            activation = to_affine_quantized_floatx(
-                input_float=x,
-                block_size=block_size,
-                target_dtype=activation_dtype,
-                scale_dtype=torch.float32,
-                layout_type=Float8LayoutType(
-                    mm_config=None
-                ),  # Config is stored on weight
-            )
-            return activation
+        input_quant_func = partial(
+            _input_quant_func_dyanmic_fp8,
+            activation_granularity=activation_granularity,
+            activation_dtype=activation_dtype,
+        )
 
         quantized_weight = to_linear_activation_quantized(
             quantized_weight, input_quant_func
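Design note on the change above: pickle serializes a functools.partial by reference to its module-level target plus the bound keyword arguments, while the nested input_quant_func it replaces could not be pickled at all. Swapping one for the other is the whole fix; the quantization math performed at call time is unchanged.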
