[REFACTOR] remove use_cuda_fp16 argument (#97)
* Remove use_cuda_fp16 arg

* fix test_serialization.py
ZX-ModelCloud authored Jun 28, 2024
1 parent 4ada383 commit 9238f99
Showing 7 changed files with 14 additions and 47 deletions.
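In practice, callers simply stop passing `use_cuda_fp16`; the cuda-old kernel now derives the fp16 fast path from the dtype of the weights it is built with (see qlinear_cuda_old.py below). A minimal loading sketch under that assumption — the model id is a placeholder, the import path is assumed from the repository layout, and I am assuming the top-level `from_quantized` forwards `torch_dtype` as `BaseGPTQModel.from_quantized` declares it:

```python
import torch

from gptqmodel import GPTQModel  # assumed public import path

# Before this commit: GPTQModel.from_quantized(..., use_cuda_fp16=True)
# After: no flag; loading in float16 is what selects the fp16 CUDA kernel path.
model = GPTQModel.from_quantized(
    "some-org/some-gptq-model",  # placeholder model id
    device="cuda:0",
    torch_dtype=torch.float16,   # float16 weights => fp16 kernel path
)
```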
1 change: 0 additions & 1 deletion examples/benchmark/generation_speed.py
@@ -172,7 +172,6 @@ def load_model_tokenizer(
model = GPTQModel.from_quantized(
model_name_or_path,
max_memory=max_memory,
- use_cuda_fp16=True,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
2 changes: 0 additions & 2 deletions gptqmodel/models/auto.py
@@ -117,7 +117,6 @@ def from_quantized(
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
backend: Backend = Backend.AUTO,
- use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig | Dict] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
@@ -138,7 +137,6 @@
max_memory=max_memory,
device=device,
backend=backend,
- use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
13 changes: 0 additions & 13 deletions gptqmodel/models/base.py
@@ -152,10 +152,6 @@ def quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
-
- # TODO: remove use_cuda_fp16 arg..why? doesn't pass smell test @ZX-ModelCloud
- use_cuda_fp16: bool = True,
-
autotune_warmup_after_quantized: bool = False,
calibration_enable_gpu_cache: bool = True,
):
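A sketch of a `quantize()` call against the new signature. The calibration sample is a toy placeholder, and I am assuming the model comes from the usual unquantized entry point (e.g. `GPTQModel.from_pretrained`) rather than the loader above:

```python
# Toy calibration data in the List[Dict[...]] shape declared above; real data
# would be produced by a tokenizer over a calibration corpus.
calibration_dataset = [
    {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1]},
]

# Before this commit: model.quantize(calibration_dataset, use_cuda_fp16=True, ...)
# After: no kernel-precision flag at quantize time.
model.quantize(
    calibration_dataset,
    batch_size=1,
    calibration_enable_gpu_cache=True,
)
```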
@@ -424,7 +420,6 @@ def tmp(_, inp, out):
bits=self.quantize_config.bits,
group_size=self.quantize_config.group_size,
backend=Backend.AUTO,
- use_cuda_fp16=use_cuda_fp16,
desc_act=self.quantize_config.desc_act,
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
@@ -773,11 +768,8 @@ def from_quantized(
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
-
backend: Backend = Backend.AUTO,
-
torch_dtype: [str | torch.dtype] = "auto",
- use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
@@ -917,10 +909,6 @@ def from_quantized(
def skip(*args, **kwargs):
pass

- if torch_dtype != torch.float16:
- logger.warning("Overriding use_cuda_fp16 to False since torch_dtype is not torch.float16.")
- use_cuda_fp16 = False
-
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
@@ -962,7 +950,6 @@ def skip(*args, **kwargs):
quantize_config.group_size,
backend=backend.AUTO if backend == Backend.MARLIN or backend == Backend.BITBLAS else backend,
format=FORMAT.GPTQ_V2,
- use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
)
model.tie_weights()
5 changes: 2 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
@@ -25,7 +25,6 @@ def __init__(
infeatures: int ,
outfeatures: int,
bias: bool,
- use_cuda_fp16=True,
kernel_switch_threshold=128,
weight_dtype=torch.float16,
**kwargs,
@@ -71,7 +70,7 @@ def __init__(
self.bias = None
self.half_indim = self.infeatures // 2

- self.use_cuda_fp16 = use_cuda_fp16 if bits != 8 else False
+ self.use_cuda_fp16 = weight_dtype == torch.float16 if bits != 8 else False

# is performed by unpacking the weights and using torch.matmul
if self.bits in [2, 4, 8]:
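That single assignment is now the only place the fp16 fast path is decided. A standalone sketch of the same rule, just to spell out the behaviour for a few dtypes:

```python
import torch


def cuda_fp16_path(bits: int, weight_dtype: torch.dtype) -> bool:
    # Mirrors the assignment above: 8-bit never takes the fp16 kernel path;
    # otherwise it is taken exactly when the weights are float16.
    return weight_dtype == torch.float16 if bits != 8 else False


assert cuda_fp16_path(4, torch.float16) is True
assert cuda_fp16_path(4, torch.bfloat16) is False  # non-fp16 weights fall back
assert cuda_fp16_path(8, torch.float16) is False   # 8-bit always falls back
```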
@@ -195,7 +194,7 @@ def forward(self, x):
if self.use_cuda_fp16:
if x_dtype != torch.float16:
logger.warning_once(
f"The cuda-old kernel for GPTQ with use_cuda_fp16=True requires a float16 input activation, while {x_dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
f"The cuda-old kernel for GPTQ requires a float16 input activation, while {x_dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
)

if self.bits == 2:
36 changes: 10 additions & 26 deletions gptqmodel/utils/model.py
@@ -109,7 +109,6 @@ def make_quant(
format: str,
desc_act: bool = False,
sym: bool = True,
- use_cuda_fp16: bool = True,
pack: bool = False,
) -> BaseQuantLinear:
select_quant_linear_func = select_quant_linear_with_pack if pack else select_quant_linear
@@ -143,29 +142,16 @@
raise NotImplementedError(f"Unsupported module {submodule}")

bias = submodule.bias is not None
- if (not (desc_act) or group_size == -1) and backend != Backend.TRITON:
- new_layer = QuantLinear(
- bits=bits,
- group_size=group_size,
- desc_act=desc_act,
- sym=sym,
- infeatures=in_features,
- outfeatures=out_features,
- bias=bias,
- use_cuda_fp16=use_cuda_fp16,
- weight_dtype=submodule.weight.dtype,
- )
- else:
- new_layer = QuantLinear(
- bits=bits,
- group_size=group_size,
- desc_act=desc_act,
- sym=sym,
- infeatures=in_features,
- outfeatures=out_features,
- bias=bias,
- weight_dtype=submodule.weight.dtype,
- )
+ new_layer = QuantLinear(
+ bits=bits,
+ group_size=group_size,
+ desc_act=desc_act,
+ sym=sym,
+ infeatures=in_features,
+ outfeatures=out_features,
+ bias=bias,
+ weight_dtype=submodule.weight.dtype,
+ )
new_layer.device = ori_layer_device
recurse_setattr(module, name, new_layer.to(ori_layer_device))
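With the two construction branches collapsed into one, the only precision-related input the quant-linear layer receives is the dtype of the `nn.Linear` it replaces. A small sketch of the keyword set that now reaches the selected `QuantLinear` class (layer sizes and quantization settings are hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a submodule being replaced inside make_quant.
submodule = nn.Linear(512, 1024, bias=False, dtype=torch.float16)

quant_linear_kwargs = dict(
    bits=4,
    group_size=128,
    desc_act=False,
    sym=True,
    infeatures=submodule.in_features,
    outfeatures=submodule.out_features,
    bias=submodule.bias is not None,
    weight_dtype=submodule.weight.dtype,  # takes over the role of use_cuda_fp16
)
```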

@@ -268,7 +254,6 @@ def pack_model(
format: str,
desc_act=False,
sym: bool = True,
- use_cuda_fp16=True,
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False,
):
@@ -295,7 +280,6 @@
group_size,
backend=backend,
format=format,
- use_cuda_fp16=use_cuda_fp16,
desc_act=desc_act,
pack=True,
)
1 change: 0 additions & 1 deletion tests/test_q4_cuda.py
@@ -577,7 +577,6 @@ def test_cuda_old(self, use_half2: bool):
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear.qzeros += 0b00010001000100010001000100010001 # for new weight format
- linear.use_cuda_fp16 = use_half2

# We cast twice just for the seed.
inp = torch.rand(1, m, k, dtype=torch.float16).to(device).to(weight_dtype)
3 changes: 2 additions & 1 deletion tests/test_serialization.py
@@ -44,9 +44,10 @@ def test_gptq_v1_to_v2_runtime_convert(self):

def test_gptq_v1_serialization(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0")
+ model.quantize_config.format = FORMAT.GPTQ

with tempfile.TemporaryDirectory() as tmpdir:
- model.save_quantized(tmpdir, format="gptq")
+ model.save_quantized(tmpdir)

with open(os.path.join(tmpdir, "quantize_config.json"), "r") as f:
quantize_config = json.load(f)
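Outside the test, the new pattern is to set the target format on `quantize_config` instead of passing it to `save_quantized`. Roughly (the `FORMAT` import path is my assumption, and `model` is a loaded quantized model as in the test):

```python
import tempfile

from gptqmodel.quantization import FORMAT  # assumed import path for the FORMAT enum

# Select the on-disk format via quantize_config rather than a save_quantized() argument.
model.quantize_config.format = FORMAT.GPTQ

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_quantized(tmpdir)  # quantize_config.json is written next to the weights
```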
