Skip to content

Commit bb436e6

Browse files
committed
skip test for pytorch < 2.3
1 parent 1f8eff0 commit bb436e6

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

benchmarks/benchmark_adam_8bit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def evaluate_model(model, args):
197197

198198
step += 1
199199

200-
if args.profile and step == 50:
200+
if args.profile and step == 20:
201201
break
202202

203203
if args.profile:

test/prototype/test_optim_8bit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from torchao.prototype.optim_8bit import AdamDTQ8bit, AdamWDTQ8bit
1313
from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
14+
from torchao.utils import TORCH_VERSION_AFTER_2_3
1415

1516
try:
1617
import bitsandbytes as bnb
@@ -45,9 +46,10 @@ def test_quantize_8bit_with_qmap_compile(self, device):
4546
torch.testing.assert_close(actual_scale, expected_scale)
4647

4748

49+
@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
50+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
51+
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
4852
class TestOptim8bit(TestCase):
49-
@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
50-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
5153
@parametrize("optim_cls,bnb_optim_cls", [
5254
(AdamDTQ8bit, bnb.optim.Adam8bit),
5355
(AdamWDTQ8bit, bnb.optim.AdamW8bit),

torchao/prototype/optim_8bit/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ You can also change quantization block size (default 2048) by passing `block_siz
1717

1818
**Other optimizers**: AdamW is also available as `AdamWDTQ8bit`.
1919

20+
NOTE: this requires PyTorch >= 2.3
21+
2022
## Benchmarks
2123

2224
Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_adam_8bit.py](../../../benchmarks/benchmark_adam_8bit.py).

0 commit comments

Comments
 (0)