 import torch
 import tqdm
 from compressed_tensors.quantization import QuantizationScheme
+from compressed_tensors.utils.match import _match_name
 from loguru import logger
 from safetensors.torch import load_file, save_file
 
 from llmcompressor.entrypoints.weights_ptq.helpers import (
     gpu_if_available,
-    is_match_name,
     validate_scheme,
 )
 from llmcompressor.entrypoints.weights_ptq.lifecycle import (
@@ -59,6 +59,7 @@ def ptq_weights(
         if is_weights_file(file_path):
             logger.warning(f"Skipping weights file {file_path}")
         save_path.parent.mkdir(parents=True, exist_ok=True)
+        logger.info(f"Copying {file_path} to {save_path}")
         shutil.copyfile(resolved_path, save_path)
 
     # 1-4. quantize and compress weights
@@ -89,7 +90,11 @@ def _process_file(
     tensors = load_file(file_path)
 
     for name in list(tensors.keys()):
-        if not is_match_name(name, ["re:.*weight$"], ignore):
+        module_name, param_name = name.rsplit(".", 1)
+        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
+        is_weight = param_name == "weight"
+        if is_ignored or not is_weight:
+            logger.debug(f"Skipping {name}")
             continue
 
         # 1. initialize module with qparams (on device)
@@ -103,7 +108,7 @@ def _process_file(
 
         # 4. save compressed data (on cpu)
         del tensors[name]
-        prefix = name.rsplit(".", 1)[0] + "."
+        prefix = module_name + "."
         for key, value in module.state_dict(prefix=prefix).items():
             tensors[key] = value.to("cpu")
 
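For reference, a minimal sketch of how the new per-tensor filter behaves on typical checkpoint keys. The tensor names and ignore entries below are hypothetical, and it assumes _match_name treats an "re:"-prefixed target as a regular expression and any other target as an exact module name, consistent with how it is called in the loop above.

from compressed_tensors.utils.match import _match_name

# Hypothetical ignore list and tensor names, for illustration only.
ignore = ["lm_head", "re:.*input_layernorm"]
names = [
    "model.layers.0.self_attn.q_proj.weight",   # quantized: ".weight" and not ignored
    "model.layers.0.self_attn.q_proj.bias",     # skipped: not a ".weight" parameter
    "model.layers.0.input_layernorm.weight",    # skipped: module matches the regex entry
    "lm_head.weight",                           # skipped: module matches the exact entry
]

for name in names:
    # Same split-and-match logic as the diff: compare the module name, not the full key.
    module_name, param_name = name.rsplit(".", 1)
    is_ignored = any(_match_name(module_name, ign) for ign in ignore)
    is_weight = param_name == "weight"
    print(f"{name}: {'quantize' if is_weight and not is_ignored else 'skip'}")

Splitting on the last "." means ignore entries are matched against module names (e.g. lm_head) rather than full parameter names, which is why the old "re:.*weight$" pattern is no longer needed.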