[FEATURE] BaseQuantLinear add SUPPORTED_DEVICES (ModelCloud#174)
* Check QuantLinear Device

* cleanup

* REFACTOR check_cuda by introducing SUPPORTED_DEVICES into BaseQuantLinear

* make device type cuda/cpu an enum

* cleanup

* cleanup
ZX-ModelCloud authored Jul 5, 2024
1 parent 95c57fd commit 590b992
Showing 31 changed files with 185 additions and 160 deletions.
6 changes: 3 additions & 3 deletions examples/benchmark/generation_speed.py
@@ -7,7 +7,7 @@

import torch
from datasets import Dataset, load_dataset
from gptqmodel import Backend, GPTQModel, QuantizeConfig, get_backend
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_backend
from tqdm import tqdm
from transformers import AutoTokenizer, GenerationConfig
from transformers.generation.logits_process import LogitsProcessor
@@ -143,7 +143,7 @@ def tokenize(examples):

def load_model_tokenizer(
model_name_or_path: str,
backend: Backend,
backend: BACKEND,
tokenizer_name_or_path: Optional[str] = None,
from_pretrained: bool = False,
max_memory: Optional[dict] = None,
@@ -280,7 +280,7 @@ def main():
logger.info(f"quantize config: {model.quantize_config.to_dict()}")
logger.info(f"model device map: {model.hf_device_map}")

if args.backend == Backend.TRITON:
if args.backend == BACKEND.TRITON:
logger.info("warmup triton, this may take a while.")
model.warmup_triton()

2 changes: 1 addition & 1 deletion examples/benchmark/perplexity.py
@@ -51,7 +51,7 @@
)
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument("--backend", choices=['AUTO', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'], help="Whether to use Backend format")
parser.add_argument("--backend", choices=['AUTO', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'], help="Whether to use BACKEND format")
args = parser.parse_args()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
6 changes: 3 additions & 3 deletions examples/quantization/basic_usage_bitblas.py
@@ -1,13 +1,13 @@
import torch
from gptqmodel import Backend, GPTQModel
from gptqmodel import BACKEND, GPTQModel
from gptqmodel.quantization import QuantizeConfig
from transformers import AutoTokenizer, TextGenerationPipeline

backend = Backend.BITBLAS
backend = BACKEND.BITBLAS
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "./facebook/opt-125m-4bit-128g"

if backend == Backend.BITBLAS:
if backend == BACKEND.BITBLAS:
quantized_model_dir += "-bitblas"

def main():
2 changes: 1 addition & 1 deletion gptqmodel/__init__.py
@@ -1,5 +1,5 @@
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
from .utils import BACKEND, get_backend
from .utils.exllama import exllama_set_max_input_length
from .version import __version__
8 changes: 4 additions & 4 deletions gptqmodel/integration/optimum/quantizer.py
@@ -39,7 +39,7 @@
from accelerate.hooks import remove_hook_from_module

from ...quantization import FORMAT, FORMAT_FIELD_JSON, GPTQ, QuantizeConfig
from ...utils.backend import Backend
from ...utils.backend import BACKEND
from ...utils.exllama import exllama_set_max_input_length
from ...utils.importer import select_quant_linear
from ...utils.model import convert_gptq_v1_to_v2_format, convert_gptq_v2_to_v1_format, gptqmodel_post_init
@@ -644,11 +644,11 @@ def pack_model(

def select_quantlinear(self):
if self.exllama_version == ExllamaVersion.ONE:
backend = Backend.EXLLAMA
backend = BACKEND.EXLLAMA
elif self.exllama_version == ExllamaVersion.TWO:
backend = Backend.EXLLAMA_V2
backend = BACKEND.EXLLAMA_V2
else:
backend = Backend.AUTO
backend = BACKEND.AUTO

QuantLinear = select_quant_linear(
sym=self.sym,
13 changes: 13 additions & 0 deletions gptqmodel/models/_const.py
@@ -1,8 +1,21 @@
from torch import device
from enum import Enum

CPU = device("cpu")
CUDA = device("cuda")
CUDA_0 = device("cuda:0")

class DEVICE(Enum):
    CPU = "cpu"
    CUDA = "cuda"


def get_device_by_type(type_value: str):
    for enum_constant in DEVICE:
        if enum_constant.value == type_value:
            return enum_constant
    raise ValueError(f"Invalid type_value str: {type_value}")

SUPPORTED_MODELS = [
"bloom",
"gptj",
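
For reference, a minimal usage sketch of the new device helper introduced above; the import path follows the file location (gptqmodel/models/_const.py) and the lookup values are illustrative:

    from gptqmodel.models._const import DEVICE, get_device_by_type

    assert get_device_by_type("cuda") is DEVICE.CUDA
    assert get_device_by_type("cpu") is DEVICE.CPU
    # Any other string raises ValueError("Invalid type_value str: ...").
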
4 changes: 2 additions & 2 deletions gptqmodel/models/auto.py
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Union

from ..utils import Backend
from ..utils import BACKEND
from ..utils.model import check_and_get_model_type
from .baichuan import BaiChuanGPTQ
from .base import BaseGPTQModel, QuantizeConfig
@@ -109,7 +109,7 @@ def from_quantized(
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
backend: Backend = Backend.AUTO,
backend: BACKEND = BACKEND.AUTO,
quantize_config: Optional[QuantizeConfig | Dict] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
55 changes: 26 additions & 29 deletions gptqmodel/models/base.py
@@ -21,19 +21,20 @@
from ..quantization import GPTQ, QuantizeConfig
from ..quantization.config import (FORMAT, FORMAT_FIELD_JSON, META_FIELD_QUANTIZER,
META_QUANTIZER_GPTQMODEL, MIN_VERSION_WITH_V2, QUANTIZE_BLACK_LIST)
from ..utils.backend import Backend
from ..utils.backend import BACKEND
from ..utils.bitblas import convert_to_bitblas, prepare_model_for_bitblas_load
from ..utils.data import collate_data
from ..utils.importer import select_quant_linear
from ..utils.marlin import (_validate_marlin_compatibility,
_validate_marlin_device_support, prepare_model_for_marlin_load)
from ..utils.model import (auto_dtype_from_config, check_cuda, convert_gptq_v1_to_v2_format,
from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format,
convert_gptq_v2_to_v1_format, find_layers, get_checkpoints, get_device,
get_module_by_name_prefix, get_module_by_name_suffix, get_moe_layer_modules,
gptqmodel_post_init, make_quant, move_to, nested_move_to, pack_model, simple_dispatch_model,
verify_model_hash, verify_sharded_model_hashes)
from ..utils.device import check_cuda
from ..version import __version__
from ._const import CPU, CUDA_0, SUPPORTED_MODELS
from ._const import CPU, DEVICE, CUDA_0, SUPPORTED_MODELS

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
@@ -420,7 +421,7 @@ def tmp(_, inp, out):
quantizers=quantizers,
bits=self.quantize_config.bits,
group_size=self.quantize_config.group_size,
backend=Backend.AUTO,
backend=BACKEND.AUTO,
desc_act=self.quantize_config.desc_act,
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
@@ -737,7 +738,7 @@ def from_quantized(
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
backend: Backend = Backend.AUTO,
backend: BACKEND = BACKEND.AUTO,
torch_dtype: [str | torch.dtype] = "auto",
quantize_config: Optional[QuantizeConfig] = None,
model_basename: Optional[str] = None,
@@ -749,12 +750,8 @@
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
):
# TODO REFRACTOR check_cuda by introducing SUPPORTED_DEVICE into BaseQuantLinear
if backend != Backend.QBITS and backend != Backend.AUTO:
check_cuda()

if backend == Backend.QBITS:
device = torch.device("cpu")
if backend == BACKEND.QBITS:
device = CPU
try:
pass
except Exception as e:
@@ -765,8 +762,8 @@
if torch_dtype is None or torch_dtype == "auto":
torch_dtype = qbits_dtype()

if backend != Backend.QBITS and not torch.cuda.is_available():
raise EnvironmentError("Load pretrained model to do quantization requires CUDA gpu. Please set backend=Backend.QBITS for cpu only quantization and inference.")
if backend != BACKEND.QBITS and not torch.cuda.is_available():
raise EnvironmentError("Load pretrained model to do quantization requires CUDA gpu. Please set backend=BACKEND.QBITS for cpu only quantization and inference.")

"""load quantized model from local disk"""
if cls.require_trust_remote_code and not trust_remote_code:
@@ -823,24 +820,24 @@

if quantize_config.format == FORMAT.MARLIN:
# format marlin requires marlin kernel
if backend != Backend.MARLIN and backend != Backend.AUTO:
raise TypeError(f"FORMAT.MARLIN requires Backend.AUTO or Backend.MARLIN: actual = `{backend}`.")
backend = Backend.MARLIN
if backend != BACKEND.MARLIN and backend != BACKEND.AUTO:
raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.")
backend = BACKEND.MARLIN

marlin_compatible = False if backend == Backend.QBITS else _validate_marlin_device_support()
marlin_compatible = False if backend == BACKEND.QBITS else _validate_marlin_device_support()

if backend != Backend.MARLIN:
if backend != BACKEND.MARLIN:
unsupported = _validate_marlin_compatibility(quantize_config)
if unsupported is None and marlin_compatible:
logger.info(
"You passed a model that is compatible with the Marlin int4*fp16 GPTQ kernel but backend is not Backend.MARLIN. We recommend using `backend=Backend.MARLIN` to use the optimized Marlin kernels for inference. Example: `model = GPTQModel.from_quantized(..., backend=Backend.MARLIN)`."
"You passed a model that is compatible with the Marlin int4*fp16 GPTQ kernel but backend is not BACKEND.MARLIN. We recommend using `backend=BACKEND.MARLIN` to use the optimized Marlin kernels for inference. Example: `model = GPTQModel.from_quantized(..., backend=BACKEND.MARLIN)`."
)

if quantize_config.format == FORMAT.BITBLAS:
# format bitblas requires bitblas kernel
if backend != Backend.BITBLAS and backend != Backend.AUTO:
raise TypeError(f"FORMAT.BITBLAS requires Backend.AUTO or Backend.BITBLAS: actual = `{backend}`.")
backend = Backend.BITBLAS
if backend != BACKEND.BITBLAS and backend != BACKEND.AUTO:
raise TypeError(f"FORMAT.BITBLAS requires BACKEND.AUTO or BACKEND.BITBLAS: actual = `{backend}`.")
backend = BACKEND.BITBLAS

if model_basename is None:
if quantize_config.model_file_base_name:
@@ -936,7 +933,7 @@ def skip(*args, **kwargs):
layers,
quantize_config.bits,
quantize_config.group_size,
backend=backend.AUTO if backend == Backend.MARLIN or backend == Backend.BITBLAS else backend,
backend=backend.AUTO if backend == BACKEND.MARLIN or backend == BACKEND.BITBLAS else backend,
format=FORMAT.GPTQ_V2,
desc_act=quantize_config.desc_act,
)
@@ -961,7 +958,7 @@ def skip(*args, **kwargs):
if device is not None:
device = torch.device(device)
if not max_memory and not device_map:
device_map = {"": device.index if device.type == "cuda" else device.type}
device_map = {"": device.index if device.type == DEVICE.CUDA else device.type}
if not isinstance(device_map, dict) and device_map != "sequential":
max_memory = accelerate.utils.get_balanced_memory(
model=model,
@@ -1004,14 +1001,14 @@ def skip(*args, **kwargs):
load_checkpoint_in_model = True
quantize_config.format = FORMAT.GPTQ_V2

if backend == Backend.MARLIN:
if backend == BACKEND.MARLIN:
if is_sharded:
raise ValueError(
"The loading of sharded checkpoints with Marlin is currently not supported."
)
if not _validate_marlin_device_support():
raise ValueError(
f'Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `back=Backend.MARLIN`.'
f'Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `back=BACKEND.MARLIN`.'
)

# Validate the model can run in Marlin.
@@ -1034,7 +1031,7 @@ def skip(*args, **kwargs):
load_checkpoint_in_model=load_checkpoint_in_model,
)

if backend == Backend.BITBLAS:
if backend == BACKEND.BITBLAS:
if is_sharded:
raise ValueError(
"The loading of sharded checkpoints with BitBLAS is currently not supported. Please raise an issue in GPTQModel repository.")
@@ -1055,7 +1052,7 @@

# If we use marlin or bitblas to load the quantized model, the model is already a converted model,
# and we no longer need to call load_checkpoint_in_model()
if not load_checkpoint_in_model and backend != Backend.MARLIN and backend != Backend.BITBLAS:
if not load_checkpoint_in_model and backend != BACKEND.MARLIN and backend != BACKEND.BITBLAS:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
@@ -1095,7 +1092,7 @@ def skip(*args, **kwargs):
model.eval()

# == step6: (optional) warmup triton == #
if backend == Backend.TRITON and warmup_triton:
if backend == BACKEND.TRITON and warmup_triton:
from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

TritonV2QuantLinear.warmup(model, seqlen=model.seqlen)
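
As a quick orientation for the renamed enum, a hedged loading sketch; the positional model-path argument and the local directory name are assumptions, since they are collapsed out of the hunks above:

    from gptqmodel import BACKEND, GPTQModel

    quantized_dir = "./opt-125m-4bit-128g"  # placeholder path

    # BACKEND.AUTO is the default; from_quantized() forces MARLIN or BITBLAS
    # automatically when the checkpoint format requires that kernel.
    model = GPTQModel.from_quantized(quantized_dir, device="cuda:0", backend=BACKEND.AUTO)

    # CPU-only inference goes through the QBITS backend, which skips the CUDA requirement.
    # model = GPTQModel.from_quantized(quantized_dir, backend=BACKEND.QBITS)
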
30 changes: 24 additions & 6 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -1,16 +1,32 @@
import torch.nn as nn

from ...models._const import DEVICE, get_device_by_type
from ...utils.device import check_cuda

class BaseQuantLinear(nn.Module):

    SUPPORTED_BITS = []
    SUPPORTED_GROUP_SIZE = []
    SUPPORTED_DESC_ACT = [True, False]
    SUPPORTED_SYM = [True, False]
    SUPPORTED_SHARDS: bool = True
    SUPPORTED_DEVICES = [DEVICE.CUDA]

    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, *args, **kwargs):
        super().__init__()
        _, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
        if err:
            raise NotImplementedError(err)

        if DEVICE.CUDA in self.SUPPORTED_DEVICES:
            check_cuda()

    @classmethod
    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_error: bool = True) -> bool:
    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool) -> bool:
        validate, _ = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
        return validate

    @classmethod
    def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, ):
        validate = True
        err = ""
        if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
@@ -25,11 +41,13 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
        elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
            validate = False
            err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
        return validate, err

        if not validate and raise_error:
            raise NotImplementedError(err)

        return validate
    @classmethod
    def validate_device(cls, device_type: str):
        device = get_device_by_type(device_type)
        if cls.SUPPORTED_DEVICES and device not in cls.SUPPORTED_DEVICES:
            raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTED_DEVICES}` bits: actual device = `{device}`")

    # override me
    def post_init(self):
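
A short sketch of how the new SUPPORTED_DEVICES attribute and validate_device() interact; the subclass below is hypothetical and only illustrates the check, it is not part of the commit:

    from gptqmodel.models._const import DEVICE
    from gptqmodel.nn_modules.qlinear import BaseQuantLinear

    class CpuOnlyQuantLinear(BaseQuantLinear):  # hypothetical kernel class
        SUPPORTED_DEVICES = [DEVICE.CPU]        # no DEVICE.CUDA, so __init__ skips check_cuda()

    CpuOnlyQuantLinear.validate_device("cpu")    # passes
    # CpuOnlyQuantLinear.validate_device("cuda") # raises NotImplementedError
    # Concrete kernels call validate_device(self.qweight.device.type) in post_init(),
    # as the BitBLAS and Exllama changes below show.
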
7 changes: 3 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -92,8 +92,8 @@ def __init__(
self,
bits: int,
group_size: int,
sym: bool,
desc_act: bool,
sym: bool,
infeatures: int,
outfeatures: int,
bias: bool,
@@ -104,13 +104,11 @@
layout: str = "nt",
**kwargs,
):
super().__init__()
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

# TODO: remove delayed import after bitblas whl support for 11.7, 11.8, 12.0 are added
import_bitblas()

self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)

self._validate_parameters(group_size, infeatures, outfeatures)

self.bits = bits
@@ -243,6 +241,7 @@ def reset_parameters(self):
self.q_params = None

def post_init(self):
self.validate_device(self.qweight.device.type)
# eliminate runtime overhead like exllama state
param_list = [self.qweight, self.scales, self.zeros]
if self.bitblas_matmul.config.with_bias:
7 changes: 3 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -40,9 +40,8 @@ class ExllamaQuantLinear(BaseQuantLinear):

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
super().__init__()
self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
@@ -92,7 +91,7 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
self.bias = None

def post_init(self):
assert self.qweight.device.type == "cuda"
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded