Skip to content

migrate nf4 to configs #1857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 65 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
73ea701
Update
vkuzo Mar 7, 2025
4f2c69d
Update
vkuzo Mar 7, 2025
75fb698
Update
vkuzo Mar 7, 2025
5796efc
Update
vkuzo Mar 7, 2025
3b016e7
Update
vkuzo Mar 7, 2025
11bce08
Update
vkuzo Mar 7, 2025
3a91a06
Update
vkuzo Mar 7, 2025
91c350f
Update
vkuzo Mar 7, 2025
125546b
Update
vkuzo Mar 7, 2025
130a52e
Update
vkuzo Mar 7, 2025
70e0e53
Update
vkuzo Mar 7, 2025
049d750
Update
vkuzo Mar 7, 2025
7b8fb8d
Update
vkuzo Mar 7, 2025
1ea0e2f
Update
vkuzo Mar 7, 2025
03ea2e4
Update
vkuzo Mar 7, 2025
cf3ad33
Update
vkuzo Mar 7, 2025
19ac99d
Update
vkuzo Mar 7, 2025
5deed22
Update
vkuzo Mar 7, 2025
2a1f7b2
Update
vkuzo Mar 7, 2025
5fa0e27
Update
vkuzo Mar 7, 2025
160cc29
Update
vkuzo Mar 7, 2025
7e40b15
Update
vkuzo Mar 7, 2025
0ecb02d
Update
vkuzo Mar 7, 2025
7cb810c
Update
vkuzo Mar 7, 2025
eb567cd
Update
vkuzo Mar 7, 2025
cd97b30
Update
vkuzo Mar 7, 2025
de38b6e
Update
vkuzo Mar 7, 2025
bfba1d9
Update
vkuzo Mar 7, 2025
9ac2334
Update
vkuzo Mar 7, 2025
6f3d127
Update
vkuzo Mar 8, 2025
cda5d18
Update
vkuzo Mar 8, 2025
96d74a3
Update
vkuzo Mar 8, 2025
c83c029
Update
vkuzo Mar 8, 2025
6f1c92d
Update
vkuzo Mar 8, 2025
95be23e
Update
vkuzo Mar 8, 2025
0776629
Update
vkuzo Mar 8, 2025
fdb292e
Update
vkuzo Mar 8, 2025
706ff1f
Update
vkuzo Mar 8, 2025
ac2314e
Update
vkuzo Mar 8, 2025
8002c39
Update
vkuzo Mar 8, 2025
a4dfaa1
Update
vkuzo Mar 8, 2025
ecdab3b
Update
vkuzo Mar 8, 2025
0506d32
Update
vkuzo Mar 8, 2025
237a72a
Update
vkuzo Mar 8, 2025
7183e83
Update
vkuzo Mar 8, 2025
10d0dff
Update
vkuzo Mar 8, 2025
be63b3c
Update
vkuzo Mar 8, 2025
2ec7827
Update
vkuzo Mar 8, 2025
3f10bc5
Update
vkuzo Mar 8, 2025
fa8c0f1
Update
vkuzo Mar 8, 2025
e8ee9a1
Update
vkuzo Mar 8, 2025
50a7f9f
Update
vkuzo Mar 8, 2025
8c8d7e4
Update
vkuzo Mar 8, 2025
5b62372
Update
vkuzo Mar 8, 2025
34ce5f4
Update
vkuzo Mar 10, 2025
c038451
Update
vkuzo Mar 10, 2025
c3815dc
Update
vkuzo Mar 10, 2025
4f195b6
Update
vkuzo Mar 10, 2025
0d388c8
Update
vkuzo Mar 10, 2025
db3e3d3
Update
vkuzo Mar 12, 2025
5f742de
Update
vkuzo Mar 12, 2025
6500a26
Update
vkuzo Mar 12, 2025
bdf1aea
Update
vkuzo Mar 12, 2025
c2577a2
Update
vkuzo Mar 12, 2025
82b6281
Update
vkuzo Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@

import torchao
from packaging import version
from torchao.dtypes._nf4tensor_api import nf4_weight_only
from torchao.dtypes.nf4tensor import (
_INNER_TENSOR_NAMES_FOR_SHARDING,
NF4Tensor,
linear_nf4,
nf4_weight_only,
to_nf4,
)
from torchao.testing.utils import skip_if_rocm
Expand Down
34 changes: 34 additions & 0 deletions torchao/dtypes/_nf4tensor_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.dtypes.nf4tensor import NF4Tensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


@dataclass
class NF4WeightOnlyConfig(AOBaseConfig):
    """
    Configuration for the NF4 weight-only workflow.

    The handler registered for this config type passes both fields,
    positionally and in this order, to ``NF4Tensor.from_tensor`` after the
    module's weight.

    The class is a dataclass so that the fields can be supplied at
    construction time, e.g. ``NF4WeightOnlyConfig(block_size=32)``. Without
    the generated ``__init__`` the annotated fields would only be class-level
    defaults and any constructor argument would raise ``TypeError``.

    Args:
        block_size: second argument given to ``NF4Tensor.from_tensor``.
        scaler_block_size: third argument given to ``NF4Tensor.from_tensor``.

    Note: the file location of this workflow is temporary.
    TODO(future PR): integrate this properly into torchao's directory structure
    """

    block_size: int = 64
    scaler_block_size: int = 256


# Backward-compatibility alias: keeps the older public name `nf4_weight_only`
# importable. Calling it now constructs an `NF4WeightOnlyConfig` instance.
nf4_weight_only = NF4WeightOnlyConfig


@register_quantize_module_handler(NF4WeightOnlyConfig)
def _nf4_weight_only_transform(
    module: torch.nn.Module,
    config: NF4WeightOnlyConfig,
) -> torch.nn.Module:
    """Replace ``module.weight`` with an NF4 version of itself, in place.

    Registered as the handler for ``NF4WeightOnlyConfig``.

    Args:
        module: module whose ``weight`` attribute is converted.
        config: supplies ``block_size`` and ``scaler_block_size`` for the
            conversion.

    Returns:
        The same ``module`` object, with ``weight`` rebound to a new
        ``torch.nn.Parameter`` that does not require grad.
    """
    nf4_weight = NF4Tensor.from_tensor(
        module.weight,
        config.block_size,
        config.scaler_block_size,
    )
    # The converted weight is wrapped as a parameter with gradients disabled.
    module.weight = torch.nn.Parameter(nf4_weight, requires_grad=False)
    return module
9 changes: 0 additions & 9 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,15 +985,6 @@ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)


def nf4_weight_only(block_size: int = 64, scaler_block_size: int = 256):
    """Build a linear-subclass inserter that converts weights to NF4.

    Args:
        block_size: second argument given to ``NF4Tensor.from_tensor``.
        scaler_block_size: third argument given to ``NF4Tensor.from_tensor``.

    Returns:
        Whatever ``_get_linear_subclass_inserter`` returns for the NF4
        conversion closure.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably this avoids a circular import with
    # torchao.quantization -- confirm.
    from torchao.quantization.quant_api import _get_linear_subclass_inserter

    def _quantize_to_nf4(tensor: torch.Tensor):
        # Closes over the two block sizes chosen by the caller.
        return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)

    inserter = _get_linear_subclass_inserter(_quantize_to_nf4)
    return inserter


# Module-level dict, created empty here.
# NOTE(review): presumably a registry of torch-function overrides for
# NF4Tensor that is populated further down the file -- confirm.
NF4_TORCH_FUNCTIONS = {}


Expand Down
Loading