Commit cfd6ac7

SunMarc, akx, and Titus-von-Koeller authored

add deepcopy and copy for Param4bit (#1060)

* fix deepcopy and copy
* add tests
* remove line
* ruff fix
* ruff
* Update tests/test_linear4bit.py
  Co-authored-by: Aarni Koskela <akx@iki.fi>
* add missing state
* ruff format
* ignore formatting commit for git blame
* Params4bit should be initialized as frozen by default
* add test for serialization round-tripping
* add comparison capability for QuantState
* add back accidentally removed line

Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>

1 parent b0730f4 commit cfd6ac7
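
In practical terms, the change lets quantized 4-bit weights be shallow-copied, deep-copied, and pickled like ordinary parameters. A minimal sketch of the intended behavior, assuming a CUDA device is available (the layer sizes are illustrative, not taken from the diff):

    import copy

    import torch
    import bitsandbytes as bnb

    # Quantization happens when the layer is moved to the GPU.
    layer = bnb.nn.Linear4bit(128, 64, quant_type="nf4")
    layer = layer.to("cuda")

    # Deep copies now get their own packed weights and quantization state
    # instead of aliasing (or failing on) the original.
    layer_copy = copy.deepcopy(layer)
    assert layer_copy.weight.quant_state is not layer.weight.quant_state

The parameter-level behavior is exercised directly by the new tests further down.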

File tree

4 files changed, +121 -17 lines changed


.git-blame-ignore-revs

Lines changed: 3 additions & 0 deletions

@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
 
 # Remove f-prefix from strings that don't use formatting
 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
+
+# format tests/linear_4bit.py
+34735ba89de8235ea9da6ef409f814dcea9e2038

bitsandbytes/functional.py

Lines changed: 15 additions & 0 deletions

@@ -706,6 +706,21 @@ def to(self, device):
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)
 
+    def __eq__(self, other):
+        if not isinstance(other, QuantState):
+            return False
+
+        return (
+            torch.allclose(self.absmax, other.absmax, atol=1e-6) and
+            self.shape == other.shape and
+            torch.allclose(self.code, other.code, atol=1e-6) and
+            self.dtype == other.dtype and
+            self.blocksize == other.blocksize and
+            self.quant_type == other.quant_type and
+            (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
+            (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
+        )
+
 
 def quantize_blockwise(
     A: Tensor,
bitsandbytes/nn/modules.py

Lines changed: 40 additions & 3 deletions

@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings
 
@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
     def __new__(
         cls,
         data: Optional[torch.Tensor] = None,
-        requires_grad=True,
+        requires_grad=False,  # quantized weights should be frozen by default
         quant_state: Optional[QuantState] = None,
         blocksize: int = 64,
         compress_statistics: bool = True,
@@ -214,6 +215,37 @@ def __new__(
         self.module = module
         return self
 
+    def __getstate__(self):
+        state = self.__dict__
+        state["data"] = self.data
+        state["requires_grad"] = self.requires_grad
+        return state
+
+    def __setstate__(self, state):
+        self.requires_grad = state["requires_grad"]
+        self.blocksize = state["blocksize"]
+        self.compress_statistics = state["compress_statistics"]
+        self.quant_type = state["quant_type"]
+        self.quant_state = state["quant_state"]
+        self.data = state["data"]
+        self.quant_storage = state["quant_storage"]
+        self.bnb_quantized = state["bnb_quantized"]
+        self.module = state["module"]
+
+    def __deepcopy__(self, memo):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        new_instance.quant_state = copy.deepcopy(state["quant_state"])
+        new_instance.data = copy.deepcopy(state["data"])
+        return new_instance
+
+    def __copy__(self):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance
+
     @classmethod
     def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -227,8 +259,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],
 
     def _quantize(self, device):
         w = self.data.contiguous().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(
+            w,
+            blocksize=self.blocksize,
+            compress_statistics=self.compress_statistics,
+            quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
+        )
         self.data = w_4bit
         self.quant_state = quant_state
         if self.module is not None:
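
Taken together, __copy__ shares the packed data and quant_state while __deepcopy__ duplicates both, which is what the new tests below assert. A small sketch of the parameter-level behavior, assuming a CUDA device (the weight shape is illustrative):

    import copy

    import torch
    import bitsandbytes as bnb

    # requires_grad now defaults to False; moving to CUDA triggers quantization.
    param = bnb.nn.Params4bit(data=torch.randn(64, 64), quant_type="nf4").cuda(0)

    # copy.copy goes through __copy__: packed data and quant_state are shared.
    shallow = copy.copy(param)
    assert shallow.quant_state is param.quant_state
    assert shallow.data.data_ptr() == param.data.data_ptr()

    # copy.deepcopy goes through __deepcopy__: both are duplicated.
    deep = copy.deepcopy(param)
    assert deep.quant_state is not param.quant_state
    assert deep.data.data_ptr() != param.data.data_ptr()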

tests/test_linear4bit.py

Lines changed: 63 additions & 14 deletions

@@ -1,4 +1,6 @@
+import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -8,13 +10,14 @@
 from tests.helpers import TRUE_FALSE
 
 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }
 
-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
        quant_storage=storage[quant_storage],
        device="meta",
    )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -91,15 +103,15 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
 
     q0 = a.quant_state
     q1 = b.quant_state
-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
         else:
             assert c == d, f"{c} != {d}"
 
     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         assert torch.equal(a, c)
 
     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         torch.save(linear.state_dict(), state_path)
         torch.save(linear_q.state_dict(), state_path_4bit)
 
-        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-            state_path_4bit
+        size_orig, size_4 = (
+            os.path.getsize(state_path),
+            os.path.getsize(state_path_4bit),
         )
         size_ratio = size_4 / size_orig
-        target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+        target_compression = (
+            0.143 if original_dtype == torch.float32 else 0.29
+        )  # these numbers get lower as weight shape increases
         ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
         assert size_ratio < target_compression, ratio_error_msg
+
+
+def test_copy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+
+    shallow_copy_param = copy.copy(param)
+    assert param.quant_state is shallow_copy_param.quant_state
+    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+
+
+def test_deepcopy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+    copy_param = copy.deepcopy(param)
+    assert param.quant_state is not copy_param.quant_state
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+
+def test_params4bit_real_serialization():
+    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state
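
Because the parameter now pickles cleanly, saving the object itself with torch.save and restoring it with torch.load should also round-trip the packed data and quantization state. This is an inference from the pickle support above, not something the diff tests directly; a hedged sketch (the file name is illustrative, and weights_only=False is needed on newer PyTorch releases to unpickle custom objects):

    import torch
    import bitsandbytes as bnb

    param = bnb.nn.Params4bit(data=torch.randn(256), quant_type="fp4").cuda(0)

    torch.save(param, "params4bit.pt")                          # pickles the object
    restored = torch.load("params4bit.pt", weights_only=False)  # unpickles it, state included

    assert torch.equal(restored.data, param.data)
    assert restored.quant_state == param.quant_state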
