remove triton warmup (#200)
Qubitium authored Jul 10, 2024
1 parent 4d622bd commit 146c1e5
Showing 7 changed files with 3 additions and 91 deletions.
6 changes: 0 additions & 6 deletions examples/benchmark/generation_speed.py
@@ -176,7 +176,6 @@ def load_model_tokenizer(
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=False,
backend=backend,
)

@@ -279,11 +278,6 @@ def main():
logger.info(f"model quantized: {model.quantized}")
logger.info(f"quantize config: {model.quantize_config.to_dict()}")
logger.info(f"model device map: {model.hf_device_map}")

if args.backend == BACKEND.TRITON:
logger.info("warmup triton, this may take a while.")
model.warmup_triton()

logger.info("loading data")
examples = load_data(
tokenizer,
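A minimal sketch of how loading looks after this change (import paths assumed from the examples, placeholder model id): `warmup_triton` is no longer a keyword argument of `from_quantized`, and no explicit `model.warmup_triton()` call follows loading; the Triton kernels autotune on first use instead.

```python
# Sketch only: assumed top-level imports and a placeholder model id.
from gptqmodel import GPTQModel, BACKEND

model = GPTQModel.from_quantized(
    "your-org/your-quantized-model",  # placeholder model id
    backend=BACKEND.TRITON,
    use_safetensors=True,
    trust_remote_code=False,
)
# Previously the benchmark called model.warmup_triton() here; that step is gone.
```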
2 changes: 0 additions & 2 deletions gptqmodel/models/auto.py
@@ -116,7 +116,6 @@ def from_quantized(
model_basename: Optional[str] = None,
use_safetensors: bool = True,
trust_remote_code: bool = False,
warmup_triton: bool = False,
# verify weight files matches predefined hash during loading
# usage: hash_format:hash_value, example: md5:ugkdh232
# supports all hashlib hash methods
@@ -136,7 +135,6 @@
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
verify_hash=verify_hash,
**kwargs,
)
23 changes: 2 additions & 21 deletions gptqmodel/models/base.py
@@ -153,21 +153,18 @@ def quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
autotune_warmup_after_quantized: bool = False,
calibration_enable_gpu_cache: bool = True,
):
if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
self._quantize(calibration_dataset, batch_size, autotune_warmup_after_quantized,
calibration_enable_gpu_cache)
self._quantize(calibration_dataset, batch_size, calibration_enable_gpu_cache)
else:
with torch.inference_mode():
self._quantize(calibration_dataset, batch_size, autotune_warmup_after_quantized, calibration_enable_gpu_cache)
self._quantize(calibration_dataset, batch_size, calibration_enable_gpu_cache)

def _quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
autotune_warmup_after_quantized: bool = False,
calibration_enable_gpu_cache: bool = True,
):
if self.quantized:
@@ -551,7 +548,6 @@ def tmp(_, inp, out):
# triton can support 2, 4, 8bits while exllama packer only supports 4bits
backend=BACKEND.TRITON if not isinstance(self.quantize_config, AutoRoundQuantizeConfig) and self.quantize_config.format in [FORMAT.GPTQ, FORMAT.GPTQ_V2] and self.quantize_config.bits != 4 else BACKEND.AUTO,
desc_act=self.quantize_config.desc_act,
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
format=self.quantize_config.format,
)
@@ -879,7 +875,6 @@ def from_quantized(
model_basename: Optional[str] = None,
use_safetensors: bool = True,
trust_remote_code: bool = False,
warmup_triton: bool = False,
format: Optional[FORMAT] = None,
allow_unsafe_loading: bool = False,
verify_hash: Optional[Union[str, List[str]]] = None,
@@ -1248,27 +1243,13 @@ def skip(*args, **kwargs):

model.eval()

# == step6: (optional) warmup triton == #
if backend == BACKEND.TRITON and warmup_triton:
from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

TritonV2QuantLinear.warmup(model, seqlen=model.seqlen)

return cls(
model,
quantized=True,
quantize_config=quantize_config,
qlinear_kernel=qlinear_kernel,
)

def warmup_triton(self, enabled: bool = True):
if not enabled:
return

from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

TritonV2QuantLinear.warmup(self.model, seqlen=self.model.seqlen)

def __getattr__(self, item):
try:
return super().__getattr__(item)
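A minimal sketch of the post-change quantization call, not the library's documented workflow: the `autotune_warmup_after_quantized` argument is gone from `quantize()`. Here `model` is assumed to be an already-loaded, not-yet-quantized GPTQModel, and the calibration dict keys follow common GPTQ calibration examples rather than anything shown in this diff.

```python
# Sketch only: `model` is assumed to exist; dict keys are illustrative.
calibration_dataset = [
    {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1]},
]

model.quantize(
    calibration_dataset,
    batch_size=1,
    calibration_enable_gpu_cache=True,
)
```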
43 changes: 0 additions & 43 deletions gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -135,48 +135,5 @@ def forward(self, x):
out = out + self.bias if self.bias is not None else out
return out

@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
"""
Pre-tunes the quantized kernel
"""
from tqdm import tqdm

kn_values = {}

for _, m in model.named_modules():
if not isinstance(m, cls):
continue

k = m.infeatures
n = m.outfeatures

if (k, n) not in kn_values:
kn_values[(k, n)] = (
m.qweight,
m.scales,
m.qzeros,
m.g_idx,
m.bits,
m.maxq,
)

logger.info(f"Found {len(kn_values)} unique KN Linear values.")
logger.info("Warming up autotune cache ...")
with torch.no_grad():
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2**m
for (k, n), (
qweight,
scales,
qzeros,
g_idx,
bits,
maxq,
) in kn_values.items():
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
del kn_values


__all__ = ["TritonV2QuantLinear"]
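If ahead-of-time tuning is still wanted, a caller could reproduce the gist of the removed classmethod outside the library. This is a sketch under assumptions: the attribute names and the `quant_matmul_248` call are taken from the removed code above, but the exact import path of `quant_matmul_248` is assumed and may differ.

```python
import math
import torch

from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear
# Exact module path of the kernel helper is assumed here.
from gptqmodel.nn_modules.triton_utils.kernels import quant_matmul_248


def warm_triton_autotune_cache(model, seqlen=2048):
    """Exercise each unique (in_features, out_features) Triton layer at
    power-of-two batch sizes so the autotuner caches its choices up front."""
    seen = {}
    for _, module in model.named_modules():
        if isinstance(module, TritonV2QuantLinear):
            key = (module.infeatures, module.outfeatures)
            seen.setdefault(key, (module.qweight, module.scales, module.qzeros,
                                  module.g_idx, module.bits, module.maxq))

    with torch.no_grad():
        for exp in range(math.ceil(math.log2(seqlen)) + 1):
            batch = 2 ** exp
            for (k, _), (qweight, scales, qzeros, g_idx, bits, maxq) in seen.items():
                a = torch.randn(batch, k, dtype=torch.float16, device=model.device)
                quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
```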
11 changes: 1 addition & 10 deletions gptqmodel/nn_modules/triton_utils/custom_autotune.py
@@ -141,16 +141,7 @@ def prune_configs(self, kwargs):
return pruned_configs

def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
pass


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
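With `warmup` reduced to a no-op, tuning happens lazily on the first real kernel launch for a given problem size. The toy class below is purely illustrative of that pattern and is not the repository's autotuner: benchmark every config the first time a key is seen, cache the winner, and reuse it afterwards.

```python
import time


class LazyAutotuner:
    """Toy illustration of lazy tuning; not the library's Autotuner class."""

    def __init__(self, fn, configs):
        self.fn = fn                # kernel-like callable
        self.configs = configs      # list of dicts of launch parameters
        self.best = {}              # problem key -> chosen config

    def warmup(self, *args, **kwargs):
        pass  # mirrors the stubbed-out method above

    def __call__(self, key, *args):
        if key not in self.best:
            timings = []
            for cfg in self.configs:
                start = time.perf_counter()
                self.fn(*args, **cfg)
                timings.append((time.perf_counter() - start, cfg))
            self.best[key] = min(timings, key=lambda t: t[0])[1]
        return self.fn(*args, **self.best[key])
```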
6 changes: 0 additions & 6 deletions gptqmodel/utils/model.py
@@ -258,7 +258,6 @@ def pack_model(
format: str,
desc_act=False,
sym: bool = True,
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False,
):
QuantLinear = select_quant_linear_with_pack(
@@ -313,11 +312,6 @@

logger.info("Model packed.")

if backend == BACKEND.TRITON and warmup_triton:
logger.warning(
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
)
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
return QuantLinear

def verify_model_hash(file_path: str, verify_hash: str):
3 changes: 0 additions & 3 deletions tests/test_triton.py
@@ -66,12 +66,9 @@ def get_model_and_tokenizer(

model = GPTQModel.from_quantized(
model_id,
disable_exllamav2=True,
disable_exllama=True,
**model_kwargs,
)

model.warmup_triton()
return model, tokenizer


