@@ -1,34 +1,30 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    unwrap_tensor_subclass,
 )
 import pytest
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
-from numpy import full
-from torch.testing._internal.common_utils import (
-    run_tests,
-)
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torch._dynamo.testing import CompileCounterWithBackend
 
 from torchao.quantization import (
     quantize_,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.observer import PerTensor, PerRow
 from torchao.float8.float8_utils import compute_error
 import torch
 import unittest
 import pytest
-import tempfile
 import copy
 import random
-
-from unittest.mock import patch
+from functools import partial
+from typing import Tuple
+from contextlib import nullcontext
+import io
 
 
 random.seed(0)
@@ -56,6 +52,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
+    )
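+    # PerRow is only exercised on H100 hardware; other GPUs test PerTensor only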
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
@@ -68,33 +67,142 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         ],
     )
     def test_fp8_linear_variants(
-        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
+        self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
-        M, N, K = sizes
-        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
-        mode_map = {
-            "dynamic": float8_dynamic_activation_float8_weight,
-            "weight-only": float8_weight_only,
-        }
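+        # Dynamic quantization with PerRow granularity requires bfloat16;
+        # any other dtype is expected to raise the assertion matched below.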
+        raises = (
+            isinstance(granularity, PerRow)
+            and mode == "dynamic"
+            and dtype != torch.bfloat16
+        )
+        context = (
+            nullcontext()
+            if not raises
+            else pytest.raises(
+                AssertionError,
+                match="PerRow quantization only works for bfloat16 precision",
+            )
+        )
+        with context:
+            M, N, K = sizes
+            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+
+            mode_map = {
+                "dynamic": partial(
+                    float8_dynamic_activation_float8_weight, granularity=granularity
+                ),
+                "weight-only": float8_weight_only,
+            }
+
+            # Create a linear layer with the test dtype
+            model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+
+            quantized_model = copy.deepcopy(model)
+            factory = mode_map[mode]()
+            quantize_(quantized_model, factory)
+
+            if compile:
+                quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+            output_original = model(input_tensor)
+            output_quantized = quantized_model(input_tensor)
+
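+            # compute_error returns an SQNR in dB; above 20 dB the quantized
+            # output closely tracks the float reference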
+            error = compute_error(output_original, output_quantized)
+            assert (
+                error > 20
+            ), f"Quantization error is too high, got an SQNR of {error}"
+
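+    # granularity accepts a single spec or an (activation, weight) pair;
+    # the next three tests cover invalid specifications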
+    def test_invalid_granularity(self):
+        with pytest.raises(ValueError, match="Invalid granularity specification"):
+            float8_dynamic_activation_float8_weight(granularity="invalid")
+
+    def test_mismatched_granularity(self):
+        with pytest.raises(
+            ValueError,
+            match="Different granularities for activation and weight are not supported",
+        ):
+            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+
+    def test_unsupported_granularity(self):
+        class UnsupportedGranularity:
+            pass
+
+        with pytest.raises(ValueError, match="Invalid granularity types"):
+            float8_dynamic_activation_float8_weight(
+                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+            )
 
-        # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    def test_per_row_with_float32(self):
+        with pytest.raises(
+            AssertionError,
+            match="PerRow quantization only works for bfloat16 precision",
+        ):
+            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            quantize_(
+                model, float8_dynamic_activation_float8_weight(granularity=PerRow())
+            )
 
-        quantized_model = copy.deepcopy(model)
-        factory = mode_map[mode]()
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    def test_serialization(self, mode: str):
+        # Create and quantize the model
+        model = ToyLinearModel(16, 32).to(device="cuda")
+        if mode == "dynamic":
+            factory = float8_dynamic_activation_float8_weight()
+        else:
+            factory = float8_weight_only()
         quantize_(model, factory)
 
-        if compile:
-            quantized_model = torch.compile(quantized_model, fullgraph=True)
-
-        output_original = model(input_tensor)
-        output_quantized = quantized_model(input_tensor)
-
-        error = compute_error(output_original, output_quantized)
-        assert (
-            compute_error(output_original, output_quantized) > 20
-        ), f"Quantization error is too high got a SQNR of {error}"
+        # Save the state dict to an in-memory buffer
+        buffer = io.BytesIO()
+        torch.save(model.state_dict(), buffer)
+
+        # Reset the buffer position
+        buffer.seek(0)
+
+        # Load the state dict from the buffer
+        loaded_state_dict = torch.load(buffer)
+
+        # Create a new model and load the state dict
+        with torch.device("meta"):
+            new_model = ToyLinearModel(16, 32)
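+        # assign=True swaps the meta-device parameters for the deserialized
+        # tensors instead of copying into them (meta tensors have no storage)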
+        new_model.load_state_dict(loaded_state_dict, assign=True)
+
+        # Compare the original and loaded models
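+        # (weight-only stores the fp8 tensor directly on the weight, while
+        # dynamic wraps it, so its fp8 data sits under original_weight_tensor)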
+        if mode == "weight-only":
+            model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+            model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+        else:
+            model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+            model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+            new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                torch.float32
+            )
+
+        assert torch.allclose(model_weight_1, new_model_weight_1)
+        assert torch.allclose(model_weight_2, new_model_weight_2)
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)