+import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory
 
 import pytest
 import torch
 
 import bitsandbytes as bnb
 from tests.helpers import TRUE_FALSE
 
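+# maps the string ids used by the quant_storage parametrization to torch dtypes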
 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }
 
-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
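+# one test case per combination of 4-bit data type, double quantization, bias, and storage dtype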
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_storage=storage[quant_storage],
         device="meta",
     )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -91,15 +103,15 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
 
     q0 = a.quant_state
     q1 = b.quant_state
-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
         else:
             assert c == d, f"{c} != {d}"
 
     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert torch.equal(a, c)
 
     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         torch.save(linear.state_dict(), state_path)
         torch.save(linear_q.state_dict(), state_path_4bit)
 
-        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-            state_path_4bit
+        size_orig, size_4 = (
+            os.path.getsize(state_path),
+            os.path.getsize(state_path_4bit),
         )
         size_ratio = size_4 / size_orig
-        target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+        target_compression = (
+            0.143 if original_dtype == torch.float32 else 0.29
+        )  # these numbers get lower as weight shape increases
         ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
         assert size_ratio < target_compression, ratio_error_msg
+
+
+def test_copy_param():
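+    # A shallow copy is expected to alias the original: same quant_state object,
+    # same underlying storage.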
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+
+    shallow_copy_param = copy.copy(param)
+    assert param.quant_state is shallow_copy_param.quant_state
+    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+
+
+def test_deepcopy_param():
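+    # A deep copy must duplicate everything: a distinct quant_state and freshly
+    # allocated storage.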
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+    copy_param = copy.deepcopy(param)
+    assert param.quant_state is not copy_param.quant_state
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+
+def test_params4bit_real_serialization():
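+    # Round-trip a quantized parameter through pickle and verify that the packed
+    # data and every piece of quantization metadata survive.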
+    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
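+    # the deserialized copy should match the original field for field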
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state