     parametrize,
     run_tests,
 )
-from torchao.prototype.optim_8bit import AdamDTQ8bit, AdamWDTQ8bit
+from torchao.prototype import optim_8bit
 from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
 from torchao.utils import TORCH_VERSION_AFTER_2_3
 
@@ -50,17 +50,14 @@ def test_quantize_8bit_with_qmap_compile(self, device):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
 @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
 class TestOptim8bit(TestCase):
-    @parametrize("optim_cls,bnb_optim_cls", [
-        (AdamDTQ8bit, bnb.optim.Adam8bit),
-        (AdamWDTQ8bit, bnb.optim.AdamW8bit),
-    ])
-    def test_adam_8bit_correctness(self, optim_cls, bnb_optim_cls):
+    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
+    def test_adam_8bit_correctness(self, optim_name):
         device = "cuda"
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
         model2 = copy.deepcopy(model1)
 
-        optim1 = bnb_optim_cls(model1.parameters())
-        optim2 = optim_cls(model2.parameters())
+        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
+        optim2 = getattr(optim_8bit, optim_name)(model2.parameters())
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
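For reference, a minimal sketch (not part of the diff) of the name-based dispatch the new parametrization relies on: one string key selects the matching optimizer class from both bitsandbytes and torchao, so a single parametrize list covers the Adam and AdamW variants. It assumes torchao.prototype.optim_8bit exposes Adam8bit and AdamW8bit as attributes (as the getattr call in the diff implies) and that bitsandbytes and a CUDA device are available.

import torch
import torch.nn as nn
import bitsandbytes as bnb
from torchao.prototype import optim_8bit  # assumed to expose Adam8bit / AdamW8bit

model = nn.Linear(32, 128).cuda()
for optim_name in ["Adam8bit", "AdamW8bit"]:
    # The same string resolves the bitsandbytes reference optimizer and the
    # torchao implementation under test.
    ref_optim = getattr(bnb.optim, optim_name)(model.parameters())
    ao_optim = getattr(optim_8bit, optim_name)(model.parameters())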