
Commit ef43746

BlackSamorez (Andrey Panferov), SunMarc, and MekkCyber authored and committed
New HIGGS quantization interfaces, JIT kernel compilation support. (huggingface#36148)
* new flute
* new higgs working
* small adjustments
* progress and quality
* small updates
* style

---------

Co-authored-by: Andrey Panferov <panferov.andrey3@wb.ru>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
1 parent 4bcc64f commit ef43746

File tree: 5 files changed, +55 / -81 lines changed

src/transformers/integrations/higgs.py

Lines changed: 11 additions & 13 deletions
@@ -28,15 +28,12 @@
 
 
 if is_flute_available():
-    import flute.utils
+    from flute.integrations.higgs import prepare_data_transposed
+    from flute.tune import TuneMetaData, qgemm_v2
 
 if is_hadamard_available():
     from fast_hadamard_transform import hadamard_transform
 
-
-if is_flute_available():
-    import flute.utils
-    from flute.integrations.higgs import prepare_data_transposed
-
 
 def pad_to_block(tensor, dims, had_block_size, value=0):
     pad_dims = [0 for _ in range(2 * len(tensor.shape))]
@@ -464,14 +461,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
 
     # Quantize
     codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
-    for i in range(0, weight.shape[0], 64):
-        codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
+    for i in range(0, weight.shape[0], 16):
+        codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
     del weight
 
     codes = codes.reshape(codes.shape[0], -1)
     scales = scales / sqrt(hadamard_size)
 
-    weight, scales, tables, tables2 = prepare_data_transposed(
+    weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
        codes,
        torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
        grid.to(dtype),
@@ -480,13 +477,15 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
         vector_size=p,
         dtype=dtype,
         device=device,
+        check_correctness=False,
     )
 
     return {
         "weight": weight,
         "scales": scales,
         "tables": tables,
         "tables2": tables2.view(dtype=torch.float16),
+        "tune_metadata": tune_metadata,
     }
 
 
@@ -508,7 +507,6 @@ def __init__(
         self.num_bits = num_bits
         self.group_size = group_size
         self.hadamard_size = hadamard_size
-        self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False)
 
         assert in_features % group_size == 0
         assert num_bits in [2, 3, 4]
@@ -531,23 +529,23 @@ def __init__(
             self.register_parameter("bias", None)
 
         self.workspace = None  # must be set externally to be reused among layers
+        self.tune_metadata: TuneMetaData = None  # must be set externally because architecture dependent
 
     def forward(self, x):
         x = pad_to_block(x, [-1], self.hadamard_size)
 
         if self.workspace is None:
             raise Exception("Workspace must be set before calling forward")
 
-        return flute.qgemm_hadamard(
+        return qgemm_v2(
             x,
             self.weight,
             self.scales,
             self.tables,
             self.tables2.view(dtype=torch.float32),
             self.workspace,
-            self.num_bits,
-            self.group_size,
-            self.hadamard_size,
+            self.tune_metadata,
+            hadamard_size=self.hadamard_size,
         )
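For context, a minimal sketch of the new return contract of quantize_with_higgs and of what HiggsLinear.forward now expects. The key names and the switch to flute.tune.qgemm_v2 come from the diff above; the weight shape, dtype, and CUDA device are illustrative assumptions, and flute-kernel >= 0.4.1 plus fast_hadamard_transform are assumed to be installed.

import torch

from transformers.integrations.higgs import quantize_with_higgs

# Illustrative stand-in weight; HIGGS/FLUTE kernels run on CUDA with half-precision inputs.
weight = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")

flute_dict = quantize_with_higgs(weight, bits=4, p=2, group_size=256)
print(sorted(flute_dict))
# ['scales', 'tables', 'tables2', 'tune_metadata', 'weight']
# "tune_metadata" is the new TuneMetaData entry. HiggsLinear.forward now hands it, together
# with the externally shared FLUTE workspace, to flute.tune.qgemm_v2 instead of calling the
# old flute.qgemm_hadamard with explicit num_bits/group_size arguments.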

src/transformers/quantizers/quantizer_higgs.py

Lines changed: 34 additions & 64 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from ..utils.logging import tqdm
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name
 
@@ -30,20 +31,6 @@
 logger = logging.get_logger(__name__)
 
 
-def get_num_sms_from_device(device):
-    target_device_cc = torch.cuda.get_device_capability(device=device)
-    if target_device_cc == (8, 6):
-        return 84
-    elif target_device_cc == (8, 0):
-        return 108
-    elif target_device_cc == (8, 9):
-        return 128
-    else:
-        raise NotImplementedError(
-            f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus"
-        )
-
-
 class HiggsHfQuantizer(HfQuantizer):
     """
     Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
@@ -115,26 +102,24 @@ def create_quantized_param(
             self.quantization_config.group_size,
             self.quantization_config.hadamard_size,
         )
-
         del param_value
 
-        module, tensor_name = get_module_from_name(model, param_name)
+        module, _ = get_module_from_name(model, param_name)
+        module_name = ".".join(param_name.split(".")[:-1])
         for key, value in flute_dict.items():
             if key in module._parameters:
                 module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
             elif key in module._buffers:
                 module._buffers[key] = torch.nn.Buffer(value)
+            elif key == "tune_metadata":
+                module.tune_metadata = value
+                self.quantization_config.tune_metadata[module_name] = value.to_dict()
             else:
                 raise ValueError(f"Unexpected key {key} in module {module}")
 
         if unexpected_keys is not None and param_name in unexpected_keys:
             unexpected_keys.remove(param_name)
 
-        module.num_sms_packed = torch.nn.Parameter(
-            torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
-            requires_grad=False,
-        )
-
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
@@ -149,57 +134,42 @@ def _process_model_before_weight_loading(
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        import flute.utils
+        from flute.tune import TuneMetaData, maybe_tune_and_repack
+        from flute.utils import make_workspace_streamk
 
         from ..integrations import HiggsLinear
 
         flute_workspaces = {}
-        for name, module in model.named_modules():
-            if isinstance(module, HiggsLinear):
-                # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
-                # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
-                if module.weight.device not in flute_workspaces:
-                    flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
-                        device=module.weight.device
-                    )
-                module.workspace = flute_workspaces[module.weight.device]
-
-                # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
-                # If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
-                if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device):
-                    new_device = module.weight.device
-                    new_num_sms = get_num_sms_from_device(new_device)
-                    module.weight.data = flute.utils.pack(
-                        flute.utils.unpack(
-                            weight=module.weight.data,
-                            scales=module.scales.data,
-                            workspace=module.workspace,
-                            num_bits=module.num_bits,
-                            group_size=module.group_size,
-                            num_sms_packed=module.num_sms_packed.item(),
-                        ).T.contiguous(),
-                        module.num_bits,
-                        module.group_size,
-                    )
-                    module.num_sms_packed = torch.nn.Parameter(
-                        torch.tensor(new_num_sms, device=new_device, dtype=torch.int32),
-                        requires_grad=False,
-                    )
+        flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
+        for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
+            # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
+            # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
+            if module.weight.device not in flute_workspaces:
+                flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
+            module.workspace = flute_workspaces[module.weight.device]
+
+            # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
+            # If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
+            module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
+            module.weight.data, module.tune_metadata = maybe_tune_and_repack(
+                weight=module.weight.data,
+                scales=module.scales.data,
+                metadata=module.tune_metadata,
+            )
+            self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
 
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
         from ..integrations import HiggsLinear
 
-        not_missing_keys = []
-        for name, module in model.named_modules():
-            if isinstance(module, HiggsLinear):
-                for missing in missing_keys:
-                    if (
-                        (name in missing or name in f"{prefix}.{missing}")
-                        and not missing.endswith(".weight")
-                        and not missing.endswith(".bias")
-                    ):
-                        not_missing_keys.append(missing)
-        return [k for k in missing_keys if k not in not_missing_keys]
+        higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
+
+        def should_update(key: str) -> bool:
+            if key.endswith(".weight") or key.endswith(".bias"):
+                return False
+            full_key = f"{prefix}.{key}"
+            return any(name in key or name in full_key for name in higgs_names)
+
+        return [key for key in missing_keys if not should_update(key)]
 
     @property
     def is_trainable(self, model: Optional["PreTrainedModel"] = None):
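For orientation, a hedged end-to-end sketch of the flow this quantizer now implements: per-module tune metadata is captured during in-flight quantization, stored in the quantization config, and re-tuned/repacked for the local GPU on reload (the "Repacking HIGGS modules" loop above). The model id and save path are illustrative; a CUDA device, flute-kernel >= 0.4.1, and serialization support for HIGGS checkpoints are assumed.

from transformers import AutoModelForCausalLM, HiggsConfig

# In-flight quantization: create_quantized_param() fills
# quantization_config.tune_metadata[<module name>] for every HiggsLinear it builds.
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-1B",
    quantization_config=HiggsConfig(bits=4, p=2),
    device_map="cuda",
)
model.save_pretrained("llama-3.2-1b-higgs")  # tune metadata travels inside the saved config

# Reloading (possibly on a different GPU) runs _process_model_after_weight_loading:
# each HiggsLinear gets a shared FLUTE workspace and is re-tuned/repacked through
# maybe_tune_and_repack when the stored metadata does not match the local device.
model = AutoModelForCausalLM.from_pretrained("llama-3.2-1b-higgs", device_map="cuda")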

src/transformers/utils/import_utils.py

Lines changed: 1 addition & 1 deletion
@@ -639,7 +639,7 @@ def is_flax_available():
 
 def is_flute_available():
     try:
-        return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.3.0"
+        return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1"
     except importlib.metadata.PackageNotFoundError:
         return False
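A hedged companion check for the new version floor; it only mirrors the gate above and is not part of the PR. packaging (already a transformers dependency) is used for the comparison.

import importlib.metadata
import importlib.util

from packaging import version

# The new HIGGS path (flute.tune.qgemm_v2, TuneMetaData, maybe_tune_and_repack)
# needs flute-kernel >= 0.4.1, as gated above.
if importlib.util.find_spec("flute") is None:
    print("flute is not installed; try `pip install flute-kernel`")
else:
    installed = version.parse(importlib.metadata.version("flute-kernel"))
    print(f"flute-kernel {installed}:", "ok" if installed >= version.parse("0.4.1") else "too old")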

src/transformers/utils/quantization_config.py

Lines changed: 6 additions & 0 deletions
@@ -1404,6 +1404,8 @@ class HiggsConfig(QuantizationConfigMixin):
             Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization.
         group_size (int, *optional*, defaults to 256):
             Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size.
+        tune_metadata ('dict', *optional*, defaults to {}):
+            Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default is an empty dictionary. Is set automatically during tuning.
     """
 
     def __init__(
@@ -1413,16 +1415,20 @@ def __init__(
         modules_to_not_convert: Optional[List[str]] = None,
         hadamard_size: int = 512,
         group_size: int = 256,
+        tune_metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
         if modules_to_not_convert is None:
             modules_to_not_convert = ["lm_head"]
+        if tune_metadata is None:
+            tune_metadata = {}
         self.quant_method = QuantizationMethod.HIGGS
         self.bits = bits
         self.p = p
         self.modules_to_not_convert = modules_to_not_convert
         self.hadamard_size = hadamard_size
         self.group_size = group_size
+        self.tune_metadata = tune_metadata
 
         self.post_init()
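A hedged sketch of how the new field behaves, using only the constructor arguments documented above: tune_metadata starts out empty, is filled module-by-module by the quantizer, and round-trips through the standard QuantizationConfigMixin to_dict/from_dict serialization.

from transformers import HiggsConfig

config = HiggsConfig(bits=4, p=2, group_size=256, hadamard_size=512)
print(config.tune_metadata)  # {} until a model has been quantized/tuned with this config

# Module-wise tuning results written by the quantizer survive serialization of the config.
restored = HiggsConfig.from_dict(config.to_dict())
print(restored.tune_metadata == config.tune_metadata)  # True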

tests/quantization/higgs/test_higgs.py

Lines changed: 3 additions & 3 deletions
@@ -65,12 +65,12 @@ def test_from_dict(self):
 @require_accelerate
 # @require_read_token
 class HiggsTest(unittest.TestCase):
-    model_name = "meta-llama/Meta-Llama-3.1-8B"
+    model_name = "unsloth/Llama-3.2-1B"
 
-    input_text = "A quick brown fox jumps over the"
+    input_text = "Font test: A quick brown fox jumps over the"
     max_new_tokens = 2
 
-    EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog"
+    EXPECTED_OUTPUT = "Font test: A quick brown fox jumps over the lazy dog"
 
     device_map = "cuda"
 
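A hedged reproduction of what the updated test exercises; the model id, prompt, max_new_tokens, and expected continuation are taken from the diff, while the rest assumes a CUDA machine with flute-kernel >= 0.4.1 and fast_hadamard_transform installed.

from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model_name = "unsloth/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=HiggsConfig(),
    device_map="cuda",
)

inputs = tokenizer("Font test: A quick brown fox jumps over the", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=2, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# The test expects: "Font test: A quick brown fox jumps over the lazy dog"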
