Skip to content

Commit 1f822d3

Browse files
committed
migrate nf4 to configs
Summary: 1. Migrate nf4 to configs. 2. Create a separate, temporary workflow file for nf4 to get around circular import errors. A future PR should properly integrate nf4 as a workflow in torchao's directory structure, but that is out of scope for this PR. Test Plan: ``` pytest test/dtypes/test_nf4.py ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 160d039 ghstack-comment-id: 2707042209 Pull Request resolved: #1857
1 parent ed695a1 commit 1f822d3

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

test/dtypes/test_nf4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626

2727
import torchao
2828
from packaging import version
29+
from torchao.dtypes._nf4tensor_api import nf4_weight_only
2930
from torchao.dtypes.nf4tensor import (
3031
_INNER_TENSOR_NAMES_FOR_SHARDING,
3132
NF4Tensor,
3233
linear_nf4,
33-
nf4_weight_only,
3434
to_nf4,
3535
)
3636
from torchao.testing.utils import skip_if_rocm

torchao/dtypes/_nf4tensor_api.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.dtypes.nf4tensor import NF4Tensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
8+
9+
10+
@dataclass
class NF4WeightOnlyConfig(AOBaseConfig):
    """
    Configuration for converting a module's weight to an ``NF4Tensor``.

    Declared as a dataclass so that the annotated fields below become
    constructor keyword arguments, e.g.
    ``NF4WeightOnlyConfig(block_size=32, scaler_block_size=128)``. Without
    the generated ``__init__`` the fields would only be class-level defaults
    and could not be overridden per instance at construction time.

    Note: the file location of this workflow is temporary.
    TODO(future PR): integrate this properly into torchao's directory structure

    Attributes:
        block_size: forwarded as the ``block_size`` argument of
            ``NF4Tensor.from_tensor``. Defaults to 64.
        scaler_block_size: forwarded as the ``scaler_block_size`` argument of
            ``NF4Tensor.from_tensor``. Defaults to 256.
    """

    block_size: int = 64
    scaler_block_size: int = 256
18+
19+
20+
# Backward-compatibility alias: keeps the old `nf4_weight_only` name importable from this module.
21+
nf4_weight_only = NF4WeightOnlyConfig
22+
23+
24+
@register_quantize_module_handler(NF4WeightOnlyConfig)
def _nf4_weight_only_transform(
    module: torch.nn.Module,
    config: NF4WeightOnlyConfig,
) -> torch.nn.Module:
    """Replace ``module.weight`` with an NF4 version of itself.

    Registered as the quantize-module handler for ``NF4WeightOnlyConfig``.

    Args:
        module: module whose ``weight`` attribute is converted; it is
            modified in place.
        config: supplies ``block_size`` and ``scaler_block_size`` for
            ``NF4Tensor.from_tensor``.

    Returns:
        The same ``module`` object, with ``weight`` rebound to a new
        ``torch.nn.Parameter`` created with ``requires_grad=False``.
    """
    # Read both settings up front, before touching the module.
    block_size, scaler_block_size = config.block_size, config.scaler_block_size

    module.weight = torch.nn.Parameter(
        NF4Tensor.from_tensor(module.weight, block_size, scaler_block_size),
        requires_grad=False,
    )
    return module

torchao/dtypes/nf4tensor.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -985,15 +985,6 @@ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
985985
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
986986

987987

988-
def nf4_weight_only(block_size: int = 64, scaler_block_size: int = 256):
    """Build a linear-subclass inserter that converts tensors to NF4.

    Args:
        block_size: passed to ``NF4Tensor.from_tensor`` as ``block_size``.
        scaler_block_size: passed to ``NF4Tensor.from_tensor`` as
            ``scaler_block_size``.

    Returns:
        Whatever ``_get_linear_subclass_inserter`` returns for the NF4
        conversion callable defined below.
    """
    # Imported inside the function rather than at module level (the commit
    # message mentions circular import errors around this code).
    from torchao.quantization.quant_api import _get_linear_subclass_inserter

    def _quantize(weight: torch.Tensor):
        # Closes over the two block sizes supplied by the caller.
        return NF4Tensor.from_tensor(weight, block_size, scaler_block_size)

    inserter = _get_linear_subclass_inserter(_quantize)
    return inserter
995-
996-
997988
NF4_TORCH_FUNCTIONS = {}
998989

999990

0 commit comments

Comments
 (0)