93 changes: 48 additions & 45 deletions gptqmodel/__init__.py
@@ -13,48 +13,51 @@
 patch_safetensors_save_file()
 patch_triton_autotuner()
 
-from .utils.env import env_flag
-from .utils.logger import setup_logger
-
-
-DEBUG_ON = env_flag("DEBUG")
-
-from .utils.linalg_warmup import run_torch_linalg_warmup
-from .utils.threadx import DeviceThreadPool
-
-
-DEVICE_THREAD_POOL = DeviceThreadPool(
-    inference_mode=True,
-    warmups={
-        "cuda": run_torch_linalg_warmup,
-        "xpu": run_torch_linalg_warmup,
-        "mps": run_torch_linalg_warmup,
-        "cpu": run_torch_linalg_warmup,
-    },
-    workers={
-        "cuda:per": 4,
-        "xpu:per": 1,
-        "mps": 8,
-        "cpu": min(12, max(1, (os.cpu_count() or 1) // 2)),
-        "model_loader:cpu": 2,
-    },
-    empty_cache_every_n=512,
-)
-
-from .models import GPTQModel, get_best_device
-from .models.auto import ASCII_LOGO
-from .quantization import BaseQuantizeConfig, QuantizeConfig
-from .utils import BACKEND
-from .utils.exllama import exllama_set_max_input_length
-from .version import __version__
-
-
-setup_logger().info("\n%s", ASCII_LOGO)
-
-
-if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']:
-    try:
-        from modelscope.utils.hf_util.patcher import patch_hub
-        patch_hub()
-    except Exception:
-        raise ModuleNotFoundError("you have set GPTQMODEL_USE_MODELSCOPE env, but doesn't have modelscope? install it with `pip install modelscope`")
+if os.environ.get("GPTQMODEL_SKIP_INIT", "0") == "1":
+    __all__ = []
+else:
+    from .utils.env import env_flag
+    from .utils.logger import setup_logger
+
+
+    DEBUG_ON = env_flag("DEBUG")
+
+    from .utils.linalg_warmup import run_torch_linalg_warmup
+    from .utils.threadx import DeviceThreadPool
+
+
+    DEVICE_THREAD_POOL = DeviceThreadPool(
+        inference_mode=True,
+        warmups={
+            "cuda": run_torch_linalg_warmup,
+            "xpu": run_torch_linalg_warmup,
+            "mps": run_torch_linalg_warmup,
+            "cpu": run_torch_linalg_warmup,
+        },
+        workers={
+            "cuda:per": 4,
+            "xpu:per": 1,
+            "mps": 8,
+            "cpu": min(12, max(1, (os.cpu_count() or 1) // 2)),
+            "model_loader:cpu": 2,
+        },
+        empty_cache_every_n=512,
+    )
+
+    from .models import GPTQModel, get_best_device
+    from .models.auto import ASCII_LOGO
+    from .quantization import BaseQuantizeConfig, QuantizeConfig
+    from .utils import BACKEND
+    from .utils.exllama import exllama_set_max_input_length
+    from .version import __version__
+
+
+    setup_logger().info("\n%s", ASCII_LOGO)
+
+
+    if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']:
+        try:
+            from modelscope.utils.hf_util.patcher import patch_hub
+            patch_hub()
+        except Exception:
+            raise ModuleNotFoundError("you have set GPTQMODEL_USE_MODELSCOPE env, but doesn't have modelscope? install it with `pip install modelscope`")
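
The new gate lets downstream tooling import the package without paying for the thread-pool spin-up, device warmups, or the logo banner. A minimal usage sketch; the flag name and the empty __all__ come from the diff above, the surrounding calls are illustrative:

import os

# Must be set before gptqmodel is imported for the first time;
# the check runs once at module import.
os.environ["GPTQMODEL_SKIP_INIT"] = "1"

import gptqmodel

# In the gated branch the package defines nothing beyond __all__ = [],
# so DEVICE_THREAD_POOL, the warmups, and the logo banner are all skipped.
print(gptqmodel.__all__)  # -> []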
14 changes: 10 additions & 4 deletions gptqmodel/looper/awq_processor.py
@@ -68,6 +68,10 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset
         self.export_compatible = False
 
         self.version = qcfg.format
+        if self.version == FORMAT.GEMM:
+            self.version = FORMAT.GEMM_V2
+        elif self.version == FORMAT.GEMV:
+            self.version = FORMAT.GEMV_V2
 
         # TODO Can it be configured?
         # The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len.
@@ -675,13 +679,13 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time
 
         linear_layer.weight.data = wq
 
-        if self.version == "gemm":
+        if self.version in ("gemm", "gemm_v2"):
             scales = scales.t().contiguous()
             if zeros is not None:
                 zeros = zeros.t().contiguous()
             q_linear_module = WQLinear_GEMM
 
-        elif self.version == "gemv":
+        elif self.version in ("gemv", "gemv_v2"):
             q_linear_module = WQLinear_GEMV
 
         elif self.version == "marlin":
@@ -790,9 +794,11 @@ def submodule_finalize(self, module: NamedModule, **kwargs):
         module.state.pop("w", None)  # no need for original weights now
 
     def finalize(self, model: BaseQModel, **kwargs):
-        if model.quantize_config.format == FORMAT.GEMM:
+        if model.quantize_config.format in (FORMAT.GEMM, FORMAT.GEMM_V2):
+            model.quantize_config.format = FORMAT.GEMM_V2
             model.qlinear_kernel = AwqGEMMQuantLinear
-        elif model.quantize_config.format == FORMAT.GEMV:
+        elif model.quantize_config.format in (FORMAT.GEMV, FORMAT.GEMV_V2):
+            model.quantize_config.format = FORMAT.GEMV_V2
             model.qlinear_kernel = AwqGEMVQuantLinear
         elif model.quantize_config.format == FORMAT.GEMV_FAST:
             model.qlinear_kernel = AwqGEMVFastQuantLinear
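
The constructor now upgrades legacy v1 layout names to their v2 equivalents up front, while _apply_quant and finalize accept either spelling. A small sketch of that normalization pattern, using a stand-in enum rather than the real FORMAT constants from gptqmodel.quantization:

from enum import Enum

class FORMAT(str, Enum):
    # Illustrative subset; the real enum carries more members.
    GEMM = "gemm"
    GEMM_V2 = "gemm_v2"
    GEMV = "gemv"
    GEMV_V2 = "gemv_v2"

def canonicalize(fmt: FORMAT) -> FORMAT:
    # Legacy v1 names map onto their v2 equivalents; v2 names pass
    # through, mirroring the __init__ and finalize() upgrades above.
    return {FORMAT.GEMM: FORMAT.GEMM_V2, FORMAT.GEMV: FORMAT.GEMV_V2}.get(fmt, fmt)

assert canonicalize(FORMAT.GEMM) is FORMAT.GEMM_V2
assert canonicalize(FORMAT.GEMV_V2) is FORMAT.GEMV_V2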
7 changes: 3 additions & 4 deletions gptqmodel/models/base.py
@@ -841,11 +841,10 @@ def quantize(
         )
 
         if self.quantize_config.quant_method == METHOD.AWQ:
-            if self.quantize_config.format == FORMAT.GEMV_FAST:
-                # AWQ GEMV_FAST only supports pack_dtype is torch.int16
-                log.info("Quantize Model: Auto fix `pack_dtype` to `torch.int16`")
+            if self.quantize_config.format in (FORMAT.GEMM, FORMAT.GEMM_V2, FORMAT.GEMV, FORMAT.GEMV_V2, FORMAT.GEMV_FAST):
+                log.info("Quantize Model: Auto fix `pack_dtype` to `torch.int16` for AWQ layout")
                 self.quantize_config.pack_dtype = torch.int16
-            elif self.quantize_config.format == FORMAT.MARLIN:
+            if self.quantize_config.format == FORMAT.MARLIN:
                 # AWQ MARLIN only supports zero_point is false
                 log.info("Quantize Model: Auto fix `zero_point` to `False`")
                 self.quantize_config.zero_point = False
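
The pack_dtype auto-fix previously applied only to GEMV_FAST; it now covers every AWQ layout, and the Marlin zero-point fix becomes an independent if rather than an elif. A hedged sketch of the resulting behavior, with a throwaway namespace standing in for the real QuantizeConfig:

import types
import torch

# Stand-ins for the FORMAT members named in the diff.
AWQ_INT16_FORMATS = {"gemm", "gemm_v2", "gemv", "gemv_v2", "gemv_fast"}

def auto_fix_awq_config(cfg) -> None:
    # Mirrors the widened guard: every AWQ layout forces int16 packing.
    if cfg.format in AWQ_INT16_FORMATS:
        cfg.pack_dtype = torch.int16
    # Now an independent check, evaluated even when the first fix ran
    # (the formats are mutually exclusive in practice).
    if cfg.format == "marlin":
        cfg.zero_point = False

cfg = types.SimpleNamespace(format="gemv", pack_dtype=torch.int32, zero_point=True)
auto_fix_awq_config(cfg)
assert cfg.pack_dtype is torch.int16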
2 changes: 1 addition & 1 deletion gptqmodel/models/loader.py
@@ -387,7 +387,7 @@ def from_quantized(
 
         if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
             if backend == BACKEND.VLLM:
-                if qcfg.format != FORMAT.GPTQ and qcfg.format != FORMAT.GEMM:
+                if qcfg.format not in (FORMAT.GPTQ, FORMAT.GEMM, FORMAT.GEMM_V2):
                     raise ValueError(f"{backend} backend only supports FORMAT.GPTQ or FORMAT.GEMM: actual = {qcfg.format}")
             elif backend == BACKEND.SGLANG:
                 if qcfg.format != FORMAT.GPTQ:
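
Practically, a checkpoint whose config is already stamped GEMM_V2 no longer fails vLLM backend validation. A hypothetical load call, assuming the usual GPTQModel.load entry point and a placeholder model id:

from gptqmodel import BACKEND, GPTQModel

# "ModelCloud/some-awq-gemm-v2-model" is a placeholder; any checkpoint whose
# quantize config format is GEMM_V2 now passes the vLLM format check above.
model = GPTQModel.load("ModelCloud/some-awq-gemm-v2-model", backend=BACKEND.VLLM)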
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -1030,6 +1030,8 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_
         # print("self qw", self.qweight, self.scales, self.qzeros)
 
 class AWQuantLinear(BaseQuantLinear):
+    REQUIRES_FORMAT_V2 = False
+
     def __init__(self,
                  bias: bool = False,
                  register_buffers: bool = False,
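
The new class attribute defaults to False on the AWQ base class, giving individual kernels a declarative hook to state that they require the v2 checkpoint layout. A hypothetical opt-in; no subclass in this diff actually sets the flag:

from gptqmodel.nn_modules.qlinear import AWQuantLinear

class HypotheticalV2OnlyKernel(AWQuantLinear):
    # Assumption: consumers of the flag would check it before dispatching
    # to this kernel; only the default (False) is visible in the diff.
    REQUIRES_FORMAT_V2 = True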