
[REFACTOR] remove use_cuda_fp16 argument #97

Merged: 2 commits, Jun 28, 2024
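At the usage level, the change is that callers stop passing use_cuda_fp16 and let the kernel infer it from the weight dtype. A minimal sketch of a caller before and after this PR (hypothetical; the model path is a placeholder and the import is assumed from this repo's package name):

import torch
from gptqmodel import GPTQModel  # package-level export assumed

model_id = "path/or/hub-id-of-a-gptq-quantized-model"  # placeholder

# Before this PR the flag was passed explicitly:
# model = GPTQModel.from_quantized(model_id, use_cuda_fp16=True)

# After this PR the argument is gone; precision is controlled via torch_dtype,
# and the cuda-old kernel enables its fp16 path only when weights are float16.
model = GPTQModel.from_quantized(model_id, torch_dtype=torch.float16)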
1 change: 0 additions & 1 deletion examples/benchmark/generation_speed.py
@@ -172,7 +172,6 @@ def load_model_tokenizer(
model = GPTQModel.from_quantized(
model_name_or_path,
max_memory=max_memory,
-use_cuda_fp16=True,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
2 changes: 0 additions & 2 deletions gptqmodel/models/auto.py
@@ -117,7 +117,6 @@ def from_quantized(
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
backend: Backend = Backend.AUTO,
-use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig | Dict] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
@@ -138,7 +137,6 @@
max_memory=max_memory,
device=device,
backend=backend,
-use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
13 changes: 0 additions & 13 deletions gptqmodel/models/base.py
@@ -152,10 +152,6 @@ def quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
-
-# TODO: remove use_cuda_fp16 arg..why? doesn't pass smell test @ZX-ModelCloud
-use_cuda_fp16: bool = True,
-
autotune_warmup_after_quantized: bool = False,
calibration_enable_gpu_cache: bool = True,
):
@@ -424,7 +420,6 @@ def tmp(_, inp, out):
bits=self.quantize_config.bits,
group_size=self.quantize_config.group_size,
backend=Backend.AUTO,
-use_cuda_fp16=use_cuda_fp16,
desc_act=self.quantize_config.desc_act,
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
@@ -773,11 +768,8 @@ def from_quantized(
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
-
backend: Backend = Backend.AUTO,
-
torch_dtype: [str | torch.dtype] = "auto",
-use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
@@ -917,10 +909,6 @@ def from_quantized(
def skip(*args, **kwargs):
pass

-if torch_dtype != torch.float16:
-logger.warning("Overriding use_cuda_fp16 to False since torch_dtype is not torch.float16.")
-use_cuda_fp16 = False
-
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
@@ -962,7 +950,6 @@ def skip(*args, **kwargs):
quantize_config.group_size,
backend=backend.AUTO if backend == Backend.MARLIN or backend == Backend.BITBLAS else backend,
format=FORMAT.GPTQ_V2,
-use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
)
model.tie_weights()
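The quantize() signature above also drops use_cuda_fp16. A hedged sketch of a calibration call after this change (QuantizeConfig arguments, from_pretrained, the model id, and the dataset rows are assumptions for illustration):

from gptqmodel import GPTQModel, QuantizeConfig  # exports assumed

quantize_config = QuantizeConfig(bits=4, group_size=128)  # constructor args assumed
model = GPTQModel.from_pretrained("facebook/opt-125m", quantize_config)

# calibration rows follow the List[Dict[...]] type hint shown in the diff
calibration_dataset = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]},  # placeholder tokens
]
model.quantize(calibration_dataset, batch_size=1)  # no use_cuda_fp16 kwarg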
5 changes: 2 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
@@ -25,7 +25,6 @@ def __init__(
infeatures: int ,
outfeatures: int,
bias: bool,
-use_cuda_fp16=True,
kernel_switch_threshold=128,
weight_dtype=torch.float16,
**kwargs,
@@ -71,7 +70,7 @@ def __init__(
self.bias = None
self.half_indim = self.infeatures // 2

-self.use_cuda_fp16 = use_cuda_fp16 if bits != 8 else False
+self.use_cuda_fp16 = weight_dtype == torch.float16 if bits != 8 else False

# is performed by unpacking the weights and using torch.matmul
if self.bits in [2, 4, 8]:
@@ -195,7 +194,7 @@ def forward(self, x):
if self.use_cuda_fp16:
if x_dtype != torch.float16:
logger.warning_once(
-f"The cuda-old kernel for GPTQ with use_cuda_fp16=True requires a float16 input activation, while {x_dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+f"The cuda-old kernel for GPTQ requires a float16 input activation, while {x_dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
)

if self.bits == 2:
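The behavioral core of the refactor is the replacement above: the cuda-old kernel now derives its fp16 fast path from the weight dtype instead of a caller flag. A standalone sketch of that rule (infer_use_cuda_fp16 is a hypothetical name, not part of the repo):

import torch

def infer_use_cuda_fp16(weight_dtype: torch.dtype, bits: int) -> bool:
    # Mirrors the new QuantLinear.__init__ logic: the fp16 CUDA kernel path is
    # used only for float16 weights, and never for 8-bit quantization.
    return weight_dtype == torch.float16 if bits != 8 else False

assert infer_use_cuda_fp16(torch.float16, 4) is True    # fp16 weights, 4-bit -> fast path
assert infer_use_cuda_fp16(torch.bfloat16, 4) is False  # non-fp16 weights -> fallback path
assert infer_use_cuda_fp16(torch.float16, 8) is False   # 8-bit never uses the fp16 path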
36 changes: 10 additions & 26 deletions gptqmodel/utils/model.py
@@ -109,7 +109,6 @@ def make_quant(
format: str,
desc_act: bool = False,
sym: bool = True,
-use_cuda_fp16: bool = True,
pack: bool = False,
) -> BaseQuantLinear:
select_quant_linear_func = select_quant_linear_with_pack if pack else select_quant_linear
@@ -143,29 +142,16 @@
raise NotImplementedError(f"Unsupported module {submodule}")

bias = submodule.bias is not None
-if (not (desc_act) or group_size == -1) and backend != Backend.TRITON:
-new_layer = QuantLinear(
-bits=bits,
-group_size=group_size,
-desc_act=desc_act,
-sym=sym,
-infeatures=in_features,
-outfeatures=out_features,
-bias=bias,
-use_cuda_fp16=use_cuda_fp16,
-weight_dtype=submodule.weight.dtype,
-)
-else:
-new_layer = QuantLinear(
-bits=bits,
-group_size=group_size,
-desc_act=desc_act,
-sym=sym,
-infeatures=in_features,
-outfeatures=out_features,
-bias=bias,
-weight_dtype=submodule.weight.dtype,
-)
+new_layer = QuantLinear(
+bits=bits,
+group_size=group_size,
+desc_act=desc_act,
+sym=sym,
+infeatures=in_features,
+outfeatures=out_features,
+bias=bias,
+weight_dtype=submodule.weight.dtype,
+)
new_layer.device = ori_layer_device
recurse_setattr(module, name, new_layer.to(ori_layer_device))

@@ -268,7 +254,6 @@ def pack_model(
format: str,
desc_act=False,
sym: bool = True,
-use_cuda_fp16=True,
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False,
):
@@ -295,7 +280,6 @@
group_size,
backend=backend,
format=format,
-use_cuda_fp16=use_cuda_fp16,
desc_act=desc_act,
pack=True,
)
1 change: 0 additions & 1 deletion tests/test_q4_cuda.py
@@ -577,7 +577,6 @@ def test_cuda_old(self, use_half2: bool):
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear.qzeros += 0b00010001000100010001000100010001 # for new weight format
-linear.use_cuda_fp16 = use_half2

# We cast twice just for the seed.
inp = torch.rand(1, m, k, dtype=torch.float16).to(device).to(weight_dtype)
3 changes: 2 additions & 1 deletion tests/test_serialization.py
@@ -44,9 +44,10 @@ def test_gptq_v1_to_v2_runtime_convert(self):

def test_gptq_v1_serialization(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0")
+model.quantize_config.format = FORMAT.GPTQ

with tempfile.TemporaryDirectory() as tmpdir:
-model.save_quantized(tmpdir, format="gptq")
+model.save_quantized(tmpdir)

with open(os.path.join(tmpdir, "quantize_config.json"), "r") as f:
quantize_config = json.load(f)
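A usage-level reading of the test change above (a hedged sketch; the FORMAT import path and the model id are assumptions): the export format is now chosen by setting quantize_config.format on the loaded model instead of passing format= to save_quantized().

import tempfile

from gptqmodel import GPTQModel  # package-level export assumed
from gptqmodel.quantization import FORMAT  # import path assumed

model_id = "path/or/hub-id-of-a-gptq-quantized-model"  # placeholder

model = GPTQModel.from_quantized(model_id, device="cuda:0")
model.quantize_config.format = FORMAT.GPTQ  # export in GPTQ (v1) layout

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_quantized(tmpdir)  # format= kwarg is no longer passed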