
add deepcopy and copy for Param4bit #1060

Merged
ruff format
Titus-von-Koeller committed Feb 21, 2024
commit 34735ba89de8235ea9da6ef409f814dcea9e2038
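
Note: the parent PR (#1060) adds support for copy.copy and copy.deepcopy on bnb.nn.Params4bit; this particular commit only applies ruff format to the test file below. The actual Params4bit changes are not shown in this diff, so the following is only a minimal sketch, assuming a torch.nn.Parameter subclass that carries a quant_state attribute (the Params4bitSketch name and its constructor are invented for illustration, not the library's API), of how a __deepcopy__ hook can duplicate both the packed tensor and the quantization state. A shallow __copy__ would be analogous, minus the recursive copies.

    import copy

    import torch


    class Params4bitSketch(torch.nn.Parameter):
        """Hypothetical stand-in for bnb.nn.Params4bit; illustration only."""

        def __new__(cls, data, quant_state=None, requires_grad=False):
            self = torch.Tensor._make_subclass(cls, data, requires_grad)
            self.quant_state = quant_state
            return self

        def __deepcopy__(self, memo):
            # Duplicate both the packed tensor and the quantization state so
            # the copy shares no storage or state objects with the original.
            new = type(self)(
                copy.deepcopy(self.data, memo),
                quant_state=copy.deepcopy(self.quant_state, memo),
                requires_grad=self.requires_grad,
            )
            memo[id(self)] = new
            return new
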
tests/test_linear4bit.py (46 changes: 30 additions & 16 deletions)
@@ -9,13 +9,14 @@
 from tests.helpers import TRUE_FALSE
 
 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }
 
-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -25,7 +26,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -37,7 +40,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -81,7 +86,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_storage=storage[quant_storage],
         device="meta",
     )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -92,15 +102,15 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
 
     q0 = a.quant_state
     q1 = b.quant_state
-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
         else:
             assert c == d, f"{c} != {d}"
 
     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -126,7 +136,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert torch.equal(a, c)
 
     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -140,14 +150,18 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         torch.save(linear.state_dict(), state_path)
         torch.save(linear_q.state_dict(), state_path_4bit)
 
-        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-            state_path_4bit
+        size_orig, size_4 = (
+            os.path.getsize(state_path),
+            os.path.getsize(state_path_4bit),
         )
         size_ratio = size_4 / size_orig
-        target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+        target_compression = (
+            0.143 if original_dtype == torch.float32 else 0.29
+        )  # these numbers get lower as weight shape increases
         ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
         assert size_ratio < target_compression, ratio_error_msg
 
+
 def test_copy_param():
     tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
     param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
@@ -162,4 +176,4 @@ def test_deepcopy_param():
     param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
     copy_param = copy.deepcopy(param)
     assert param.quant_state is not copy_param.quant_state
-    assert param.data.data_ptr() != copy_param.data.data_ptr()
\ No newline at end of file
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
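
As test_copy_param and test_deepcopy_param above exercise, the user-visible contract of the PR is that a quantized Params4bit can be duplicated with the standard copy module, and the duplicate neither aliases the packed data nor shares the quant_state object. A short usage sketch, assuming a CUDA device is available as in the tests:

    import copy

    import torch

    import bitsandbytes as bnb

    # Moving a Params4bit to a CUDA device quantizes the data to 4 bit.
    param = bnb.nn.Params4bit(
        data=torch.tensor([1.0, 2.0, 3.0, 4.0]), requires_grad=False
    ).cuda(0)

    clone = copy.deepcopy(param)

    # The deep copy owns separate storage and a separate quantization state.
    assert clone.data.data_ptr() != param.data.data_ptr()
    assert clone.quant_state is not param.quant_state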