Commit d86ec5e

rename to more user-friendly names
1 parent bb436e6 commit d86ec5e

6 files changed, +16 -19 lines

benchmarks/benchmark_adam_8bit.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 # --amp bf16 \
 # --optim Adam
 #
-# To use bnb 8-bit optimizer, set --optim AdamBnb8bit. To use 8-bit optimizer implemented in torchao, set --optim AdamDTQ8bit
+# To use bnb 8-bit optimizer, set --optim Adam8bitBnb. To use 8-bit optimizer implemented in torchao, set --optim Adam8bitAo
 # To profile and export chrome trace, set --profile
 # To enable cosine learning rate scheduler, set --cosine_lr_scheduler

@@ -25,7 +25,7 @@
 from torchvision.transforms import v2
 from tqdm import tqdm

-from torchao.prototype.optim_8bit import AdamDTQ8bit
+from torchao.prototype.optim_8bit import Adam8bit


 class CosineSchedule:

@@ -161,8 +161,8 @@ def evaluate_model(model, args):

 OPTIM_MAP = dict(
     Adam=torch.optim.Adam,
-    AdamBnb8bit=bnb.optim.Adam8bit,
-    AdamDTQ8bit=AdamDTQ8bit,
+    Adam8bitBnb=bnb.optim.Adam8bit,
+    Adam8bitAo=Adam8bit,
 )
 optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
 lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
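
For reference, the renamed keys are what a user now passes on the command line. Below is a minimal sketch of the benchmark's name-based dispatch; the argparse scaffolding and the stand-in model are illustrative, not the benchmark's exact code:

```python
import argparse

import bitsandbytes as bnb
import torch

from torchao.prototype.optim_8bit import Adam8bit

# Same mapping as in the benchmark: the optimizer is selected by name.
OPTIM_MAP = dict(
    Adam=torch.optim.Adam,
    Adam8bitBnb=bnb.optim.Adam8bit,  # bitsandbytes implementation
    Adam8bitAo=Adam8bit,             # torchao implementation
)

parser = argparse.ArgumentParser()
parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys())
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0.0)
args = parser.parse_args()

model = torch.nn.Linear(32, 2)  # stand-in for the benchmark's real model
optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
```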

test/prototype/test_optim_8bit.py

Lines changed: 5 additions & 8 deletions
@@ -9,7 +9,7 @@
     parametrize,
     run_tests,
 )
-from torchao.prototype.optim_8bit import AdamDTQ8bit, AdamWDTQ8bit
+from torchao.prototype import optim_8bit
 from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
 from torchao.utils import TORCH_VERSION_AFTER_2_3

@@ -50,17 +50,14 @@ def test_quantize_8bit_with_qmap_compile(self, device):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
 @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
 class TestOptim8bit(TestCase):
-    @parametrize("optim_cls,bnb_optim_cls", [
-        (AdamDTQ8bit, bnb.optim.Adam8bit),
-        (AdamWDTQ8bit, bnb.optim.AdamW8bit),
-    ])
-    def test_adam_8bit_correctness(self, optim_cls, bnb_optim_cls):
+    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
+    def test_adam_8bit_correctness(self, optim_name):
         device = "cuda"
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
         model2 = copy.deepcopy(model1)

-        optim1 = bnb_optim_cls(model1.parameters())
-        optim2 = optim_cls(model2.parameters())
+        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
+        optim2 = getattr(optim_8bit, optim_name)(model2.parameters())

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
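
The switch to string-based parametrization works because, after this commit, the torchao prototype module exposes the same class names as `bnb.optim`, so a single name resolves to the matching class in either library. A small sketch of that symmetry (the loop and print are illustrative):

```python
import bitsandbytes as bnb

from torchao.prototype import optim_8bit

# One name now selects the matching implementation from each library.
for optim_name in ["Adam8bit", "AdamW8bit"]:
    bnb_cls = getattr(bnb.optim, optim_name)   # e.g. bnb.optim.Adam8bit
    ao_cls = getattr(optim_8bit, optim_name)   # e.g. optim_8bit.Adam8bit
    print(optim_name, "->", bnb_cls, ao_cls)
```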

torchao/prototype/optim_8bit/README.md

Lines changed: 3 additions & 3 deletions
@@ -7,15 +7,15 @@ This folder implements 8-bit optimizers using dynamic tree quantization as outli
 This is a drop-in replacement for `torch.optim.Adam`

 ```python
-from torchao.prototype.optim_8bit import AdamDTQ8bit
+from torchao.prototype.optim_8bit import Adam8bit

 model = ...
-optim = AdamDTQ8bit(model.parameters())
+optim = Adam8bit(model.parameters())
 ```

 You can also change quantization block size (default 2048) by passing `block_size=value` to the optimizer.

-**Other optimizers**: AdamW is also available as `AdamWDTQ8bit`.
+**Other optimizers**: AdamW is also available as `AdamW8bit`.

 NOTE: this requires PyTorch >= 2.3
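Putting the README's pieces together, usage with the renamed class and the `block_size` knob might look like this (a sketch; the model, learning rate, and block size are illustrative values):

```python
import torch

from torchao.prototype.optim_8bit import Adam8bit

# Requires CUDA and PyTorch >= 2.3 (see the NOTE above).
model = torch.nn.Linear(128, 64).cuda()
# block_size controls the quantization block size (default 2048).
optim = Adam8bit(model.parameters(), lr=1e-3, block_size=256)
```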
torchao/prototype/optim_8bit/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-from .adam import AdamDTQ8bit
-from .adamw import AdamWDTQ8bit
+from .adam import Adam8bit
+from .adamw import AdamW8bit

torchao/prototype/optim_8bit/adam.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from .subclass import maybe_new_zero_buffer


-class AdamDTQ8bit(Optimizer):
+class Adam8bit(Optimizer):
     def __init__(
         self,
         params,

torchao/prototype/optim_8bit/adamw.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from .subclass import maybe_new_zero_buffer


-class AdamWDTQ8bit(Optimizer):
+class AdamW8bit(Optimizer):
     def __init__(
         self,
         params,
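
With all six files updated, the prototype's public surface is the two renamed classes, `Adam8bit` and `AdamW8bit`. A quick smoke-test sketch under the stated constraints (CUDA, PyTorch >= 2.3), with the model shape borrowed from the test file:

```python
import torch
import torch.nn as nn

# Both renamed classes are importable from the package.
from torchao.prototype.optim_8bit import Adam8bit, AdamW8bit

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
optim = AdamW8bit(model.parameters())  # same call shape as torch.optim.AdamW

x = torch.randn(4, 32, device="cuda")
model(x).sum().backward()
optim.step()
optim.zero_grad()
```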
