Commit 294a78a

fix dtype, passing test
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent fb739cf commit 294a78a

2 files changed: +11 -4 lines changed

src/llmcompressor/entrypoints/weights_ptq/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -7,12 +7,12 @@
 import torch
 import tqdm
 from compressed_tensors.quantization import QuantizationScheme
+from compressed_tensors.utils.match import _match_name
 from loguru import logger
 from safetensors.torch import load_file, save_file
 
 from llmcompressor.entrypoints.weights_ptq.helpers import (
     gpu_if_available,
-    is_match_name,
     validate_scheme,
 )
 from llmcompressor.entrypoints.weights_ptq.lifecycle import (
@@ -59,6 +59,7 @@ def ptq_weights(
         if is_weights_file(file_path):
             logger.warning(f"Skipping weights file {file_path}")
             save_path.parent.mkdir(parents=True, exist_ok=True)
+            logger.info(f"Copying {file_path} {save_path}")
             shutil.copyfile(resolved_path, save_path)
 
         # 1-4. quantize and compress weights
@@ -89,7 +90,11 @@ def _process_file(
     tensors = load_file(file_path)
 
     for name in list(tensors.keys()):
-        if not is_match_name(name, ["re:.*weight$"], ignore):
+        module_name, param_name = name.rsplit(".", 1)
+        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
+        is_weight = param_name == "weight"
+        if is_ignored or not is_weight:
+            print(f"skip {name}")
             continue
 
         # 1. initialize module with qparams (on device)
@@ -103,7 +108,7 @@ def _process_file(
 
         # 4. save compressed data (on cpu)
         del tensors[name]
-        prefix = name.rsplit(".", 1)[0] + "."
+        prefix = module_name + "."
         for key, value in module.state_dict(prefix=prefix).items():
             tensors[key] = value.to("cpu")
 
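The new skip logic splits each checkpoint key into a module name and a parameter name, then quantizes only "weight" tensors whose module is not in the ignore list. Below is a minimal standalone sketch of that decision, using a hypothetical matches() helper in place of compressed_tensors.utils.match._match_name (assumed here to treat "re:"-prefixed entries as regex patterns and everything else as exact module names):

import re


def matches(module_name: str, pattern: str) -> bool:
    # Assumption: "re:..." entries are regex patterns, anything else is an exact name.
    if pattern.startswith("re:"):
        return re.match(pattern[3:], module_name) is not None
    return module_name == pattern


def should_quantize(tensor_name: str, ignore: list[str]) -> bool:
    # "model.layers.0.mlp.down_proj.weight" -> ("model.layers.0.mlp.down_proj", "weight")
    module_name, param_name = tensor_name.rsplit(".", 1)
    is_ignored = any(matches(module_name, ign) for ign in ignore)
    return param_name == "weight" and not is_ignored


assert should_quantize("model.layers.0.mlp.down_proj.weight", ["lm_head"])
assert not should_quantize("lm_head.weight", ["lm_head"])
assert not should_quantize("model.layers.0.input_layernorm.bias", [])

Compared to the old is_match_name(name, ["re:.*weight$"], ignore) call, matching on the module name rather than the full tensor name also lets module_name be reused as the state_dict prefix in the last hunk above.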
src/llmcompressor/entrypoints/weights_ptq/lifecycle.py

Lines changed: 3 additions & 1 deletion
@@ -22,7 +22,9 @@ def initialize_quantized_linear(
     weight: torch.Tensor, scheme: QuantizationScheme, device: str | torch.device
 ) -> torch.nn.Module:
     out_features, in_features = weight.shape
-    module = torch.nn.Linear(in_features, out_features, bias=False, device=device)
+    module = torch.nn.Linear(
+        in_features, out_features, bias=False, device=device, dtype=weight.dtype
+    )
     module.weight.data.copy_(weight)
     initialize_module_for_quantization(module, scheme, force_zero_point=False)
 
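The dtype fix matters because torch.nn.Linear allocates its parameters in the current default dtype (float32 unless torch.set_default_dtype was changed), and Tensor.copy_ silently casts the source to match, so a bfloat16 checkpoint weight would be upcast before quantization. A torch-only sketch of the behavior, not code from this repo:

import torch

weight = torch.randn(8, 16, dtype=torch.bfloat16)  # checkpoint weight in bf16
out_features, in_features = weight.shape

# Without dtype: the Linear is created in float32 and copy_() upcasts the weight.
upcast = torch.nn.Linear(in_features, out_features, bias=False)
upcast.weight.data.copy_(weight)
assert upcast.weight.dtype == torch.float32

# With dtype=weight.dtype: the module keeps the checkpoint dtype.
kept = torch.nn.Linear(in_features, out_features, bias=False, dtype=weight.dtype)
kept.weight.data.copy_(weight)
assert kept.weight.dtype == torch.bfloat16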