
add int4 non-gptq and bugfixes #119

Merged
1 commit merged on Apr 4, 2024
36 changes: 35 additions & 1 deletion test/quantization/test_quant_api.py
@@ -300,7 +300,6 @@ def test_gptq_quantizer_gpt_fast(self):
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_gptq_quantizer_int4wo(self):
        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
        # should be similar to TorchCompileDynamicQuantizer
        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -357,6 +356,41 @@ def test_gptq_quantizer_int4wo(self):
            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
        )

    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_quantizer_int4wo(self):
        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = SentencePieceProcessor( # pyre-ignore[28]
            model_file=str(tokenizer_path)
        )
        groupsize = 128
        quantizer = Int4WeightOnlyQuantizer(
            groupsize,
        )
        model = quantizer.quantize(model).cuda()
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
        )

    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_eval_wrapper(self):
        from torchao.quantization.GPTQ import TransformerEvalWrapper
106 changes: 94 additions & 12 deletions torchao/quantization/GPTQ.py
@@ -28,6 +28,7 @@
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
    pack_tinygemm_scales_and_zeros,
    groupwise_affine_quantize_tensor,
)
aten = torch.ops.aten

@@ -65,8 +66,8 @@

__all__ = [
    "MultiInput",
    "WeightOnlyInt4Linear",
    "Int4WeightOnlyGPTQQuantizer",
    "Int4WeightOnlyQuantizer",
] + add_ons

if lm_eval_available:
@@ -117,7 +118,10 @@ def __init__(

    @property
    def eot_token_id(self):
        return self._tokenizer.eos_id()
        try:
            return self._tokenizer.eos_id()
        except:
            return self._tokenizer.eos_id

    @property
    def max_length(self):
@@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs):
        # TODO: verify this for multi-batch as well
        tokens = self._tokenizer.encode(string)
        if hasattr(self._tokenizer, "bos_id"):
            tokens = [self._tokenizer.bos_id()] + tokens
            try:
                tokens = [self._tokenizer.bos_id()] + tokens
            except:
                tokens = [self._tokenizer.bos_id] + tokens
        return tokens

    def tok_decode(self, tokens):
@@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
    def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
        pass

def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
    k_divisible_by_groupsize = k % groupsize == 0
    if inner_k_tiles is not None:
        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
    return k_divisible_by_groupsize
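
A quick illustration of the rule this helper encodes (values chosen here for illustration, not part of the diff): when inner_k_tiles is given, k must be divisible by both groupsize and inner_k_tiles * 16.

# Illustration only, assuming _check_linear_int4_k above is in scope.
# With groupsize=64 and inner_k_tiles=8, k must be divisible by 64 and by 8 * 16 = 128.
assert _check_linear_int4_k(256, groupsize=64, inner_k_tiles=8)      # 256 % 64 == 0 and 256 % 128 == 0
assert not _check_linear_int4_k(192, groupsize=64, inner_k_tiles=8)  # 192 % 128 != 0
assert _check_linear_int4_k(192, groupsize=64)                       # no tile constraint when inner_k_tiles is None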

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    origin_x_size = x.size()
@@ -767,7 +780,7 @@ def __init__(
        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
    ) -> None:
        super().__init__()
        self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from model import find_multiple
            self.origin_in_features = in_features
@@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )


def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
    k_divisible_by_groupsize = k % groupsize == 0
    if inner_k_tiles is not None:
        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
    return k_divisible_by_groupsize

def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None):

    for name, child in module.named_children():
@@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)

class Int4WeightOnlyQuantizer(Quantizer):
    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = True,
        inner_k_tiles: Optional[int] = 8,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]

        self.inner_k_tiles = inner_k_tiles
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed

    @torch.no_grad()
    def _create_quantized_state_dict(
        self, model: torch.nn.Module
    ) -> Dict[str, torch.Tensor]:
        cur_state_dict = model.state_dict()
        for fqn, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                # assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                assert (
                    in_features % self.groupsize == 0
                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding_allowed:
                        from .utils import find_multiple
                        import torch.nn.functional as F
                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                    else:
                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
                              "and that groupsize and inner_k_tiles*16 evenly divide into it")
                        continue
                (
                    w_int4x8,
                    scales_and_zeros
                ) = groupwise_affine_quantize_tensor(
                    weight,
                    4, # n_bit
                    self.groupsize,
                )
                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
        return cur_state_dict

    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
        replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
        )
        return model

    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        state_dict = self._create_quantized_state_dict(model)
        model = self._convert_for_runtime(model)
        # TODO: make it strict
        model.load_state_dict(state_dict, strict=False)
        return model
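
For reference, a minimal usage sketch of the new Int4WeightOnlyQuantizer (not part of the diff). It mirrors the flow in test_quantizer_int4wo and assumes a CUDA GPU with int4 tinygemm kernel support plus bias-free Linear layers whose in_features satisfy the checks above.

# Sketch only, under the assumptions stated above.
import torch
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

# Toy bias-free model; in_features=4096 is divisible by groupsize and inner_k_tiles * 16.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to(dtype=torch.bfloat16, device="cuda").eval()

quantizer = Int4WeightOnlyQuantizer(groupsize=128)
model = quantizer.quantize(model).cuda()  # packs weights to int4 and swaps in WeightOnlyInt4Linear

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)  # inference now runs through the int4 weight-only path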

class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
    def __init__(
        self,
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -42,4 +42,6 @@
    "compute_error",
    "get_model_size_in_bytes",
    "WeightOnlyInt8QuantLinear",
    "Int4WeightOnlyGPTQQuantizer",
    "Int4WeightOnlyQuantizer",
]
2 changes: 2 additions & 0 deletions torchao/quantization/quant_api.py
@@ -32,6 +32,7 @@
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
    Int4WeightOnlyGPTQQuantizer,
    Int4WeightOnlyQuantizer,
)


@@ -45,6 +46,7 @@
    "Quantizer",
    "TwoStepQuantizer",
    "Int4WeightOnlyGPTQQuantizer",
    "Int4WeightOnlyQuantizer"
]

if TORCH_VERSION_AFTER_2_3:
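Together with the __init__.py change above, this exposes the new quantizer alongside the GPTQ one. A minimal import check (example only; it assumes the package root re-exports quant_api, as the __all__ additions suggest):

# Example only, not part of the diff.
from torchao.quantization.quant_api import Int4WeightOnlyGPTQQuantizer
from torchao.quantization import Int4WeightOnlyQuantizer  # assumes the root package re-exports quant_api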
1 change: 0 additions & 1 deletion torchao/quantization/quant_primitives.py
@@ -383,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros):

def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype == torch.float
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
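
Dropping the torch.float assert lets bfloat16 packed scales and zeros, as produced by groupwise_affine_quantize_tensor for bfloat16 weights, unpack without tripping a dtype check. A small sketch of that path (example only, not part of the diff):

# Sketch only: bfloat16 scales/zeros now round-trip through unpacking.
import torch
from torchao.quantization.quant_primitives import (
    groupwise_affine_quantize_tensor,
    unpack_tinygemm_scales_and_zeros,
)

w = torch.randn(256, 512, dtype=torch.bfloat16)
w_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(w, 4, 128)  # n_bit=4, groupsize=128
scales, zeros = unpack_tinygemm_scales_and_zeros(scales_and_zeros)        # previously required torch.float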

