Skip to content

Commit bf11c2e

Browse files
committed
select block size based on bnb version
1 parent ba083ea commit bf11c2e

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

test/prototype/test_low_bit_optim.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
import torch
6+
from packaging.version import Version
67
from torch import nn
78
from torch.testing._internal.common_utils import (
89
TestCase,
@@ -105,8 +106,11 @@ def test_optim_8bit_correctness(self, optim_name):
105106
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
106107
model2 = copy.deepcopy(model1)
107108

109+
# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
110+
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
111+
108112
optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
109-
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
113+
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
110114

111115
for _ in range(2):
112116
x = torch.randn(4, 32, device=device)

0 commit comments

Comments
 (0)