@@ -702,8 +702,14 @@ def _prepare_weights(self, model_name_or_path: str,
 
         return hf_weights_files, matched_pattern == "*.safetensors"
 
+    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+        if use_safetensors:
+            return safetensors_weights_iterator(hf_weights_files)
+        else:
+            return pt_weights_iterator(hf_weights_files)
+
     def _get_quantized_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str]
+        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
     ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
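
Both underlying iterators stream (name, tensor) pairs, so the new _hf_weight_iter helper lets the two quantization paths below consume weights without branching on the checkpoint format. A minimal usage sketch (the shard file names are hypothetical, and the import path assumes vLLM's post-refactor loader layout):

    from vllm.model_executor.model_loader.weight_utils import (
        pt_weights_iterator, safetensors_weights_iterator)

    # Hypothetical shard list; weights are streamed lazily, one file at a time.
    files = ["model-00001-of-00002.safetensors",
             "model-00002-of-00002.safetensors"]
    for name, tensor in safetensors_weights_iterator(files):
        print(name, tuple(tensor.shape))
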
@@ -712,6 +718,7 @@ def _get_quantized_weights_iterator(
         # only load the bitsandbytes module when needed
         try:
             import bitsandbytes
+            from bitsandbytes.functional import QuantState
             if bitsandbytes.__version__ < "0.42.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.42.0.")
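
For context on what QuantState reassembles below: a bitsandbytes-serialized checkpoint stores, alongside each packed .weight tensor, sidecar tensors holding the quantization metadata. A rough sketch of that layout with a hypothetical parameter name (the key set follows bitsandbytes' packed QuantState serialization and may vary by version):

    # Hypothetical layer; the packed NF4 payload lives in ".weight" and the
    # sidecar keys below carry the scales and metadata.
    name = "model.layers.0.mlp.down_proj"
    sidecar_keys = [
        name + ".weight.absmax",     # per-block scaling factors
        name + ".weight.quant_map",  # 16-entry NF4 codebook
        name + ".weight.quant_state.bitsandbytes__nf4",  # packed metadata blob
    ]
    # _parse_quant_state() below gathers every key sharing the param prefix
    # and hands that sub-dict to QuantState.from_dict(..., device="cuda").
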
@@ -725,17 +732,63 @@ def _get_quantized_weights_iterator(
             model_name_or_path, revision)
 
         quant_state_dict = {}
-        if use_safetensors:
-            weight_iterator = safetensors_weights_iterator(hf_weights_files)
-        else:
-            weight_iterator = pt_weights_iterator(hf_weights_files)
 
-        def generator():
+        def quantized_checkpoint() -> Generator:
+            # First, iterate over all quant-state tensors
+            weight_iterator = self._hf_weight_iter(hf_weights_files,
+                                                   use_safetensors)
+            temp_state_dict = {}
             for weight_name, weight_tensor in weight_iterator:
+                if weight_name.endswith(".weight"):
+                    continue
+                # TODO: only nf4 quantization is supported for now
+                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
+                    raise NotImplementedError(
+                        "Only bitsandbytes_nf4 quantization "
+                        f"is supported for now. {weight_name} is fp4 quantized"
+                    )
+                temp_state_dict[weight_name] = weight_tensor
+
+            # Closure to parse the quant_state of each pre-quantized weight
+            def _parse_quant_state(param_name: str,
+                                   temp_state_dict: Dict) -> QuantState:
+                quant_state = {}
+                for k in temp_state_dict:
+                    if param_name + "." in k:
+                        quant_state[k] = temp_state_dict[k]
+                # the bitsandbytes library requires
+                # weight.quant_state.bitsandbytes__nf4 to be on CPU
+                quant_state[param_name +
+                            ".quant_state.bitsandbytes__nf4"] = quant_state[
+                                param_name +
+                                ".quant_state.bitsandbytes__nf4"].cpu().data
+                return QuantState.from_dict(quant_state, device="cuda")
+
+            # Second, iterate over all pre-quantized and normal weights;
+            # pre-quantized weights carry a quant_state
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
+                # Filter out all weights whose suffix is not ".weight"
+                if not weight_name.endswith(".weight"):
+                    continue
+                if weight_name + ".quant_state.bitsandbytes__nf4" \
+                        in temp_state_dict:
+                    quant_state = _parse_quant_state(weight_name,
+                                                     temp_state_dict)
+                    weight_name = weight_name.replace(".weight", ".qweight")
+                    quant_state_dict[weight_name] = quant_state
+                    yield weight_name, weight_tensor
+                else:
+                    yield weight_name, weight_tensor
+
+        def generator() -> Generator:
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
-                # bitsandbytes requires data in GPU
+                    # bitsandbytes requires data in GPU
                     loaded_weight = weight_tensor.cuda().data
                     with set_default_torch_dtype(torch.float32):
                         processed_weight, quant_state = quantize_4bit(
@@ -749,6 +802,8 @@ def generator():
 
                 yield weight_name, processed_weight
 
+        if pre_quant:
+            return quantized_checkpoint(), quant_state_dict
         return generator(), quant_state_dict
 
     def _load_weights(self, model_config: ModelConfig,
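
When the checkpoint is not pre-quantized, generator() quantizes each target weight on the fly with quantize_4bit, exactly as in the hunk above. A standalone sketch of that call (the shape is arbitrary; needs a CUDA device and bitsandbytes>=0.42.0):

    import torch
    from bitsandbytes.functional import quantize_4bit

    weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    # Returns the packed payload (two 4-bit values per byte) plus the
    # QuantState needed to dequantize it later.
    packed, quant_state = quantize_4bit(weight,
                                        compress_statistics=True,
                                        quant_type="nf4")
    print(packed.dtype, packed.numel())  # torch.uint8, ~half the elements
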
@@ -766,12 +821,21 @@ def _load_weights(self, model_config: ModelConfig,
         logger.info("Loading weights with BitsAndBytes quantization. "
                     " May take a while ...")
 
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision))
+        is_quantized_checkpoint = False
+        quant_config = getattr(model_config.hf_config, "quantization_config",
+                               None)
+        if quant_config is not None and quant_config.get(
+                'quant_method') == "bitsandbytes":
+            is_quantized_checkpoint = True
+
+        qweight_iterator, quant_state_dict = \
+            self._get_quantized_weights_iterator(
+                model_config.model, model_config.revision,
+                is_quantized_checkpoint)
 
         model.load_weights(qweight_iterator)
 
+        torch.cuda.empty_cache()
+
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
         for quant_param_name in quant_state_dict:
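
The pre-quantized branch keys off the quantization_config that transformers writes into a checkpoint's config.json. A sketch of the same detection outside the loader (the repo id is hypothetical; when read from disk, quantization_config is a plain dict, hence .get):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("some-org/llama-7b-bnb-nf4")  # hypothetical
    quant_config = getattr(cfg, "quantization_config", None)
    is_quantized_checkpoint = (quant_config is not None and
                               quant_config.get("quant_method") == "bitsandbytes")
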
@@ -809,9 +873,9 @@ def _load_weights(self, model_config: ModelConfig,
                     f"pack_factor not set for parameter {param_name}.")
 
             num_elements = [0] * len(quant_states)
-            for seq, quant_state in enumerate(quant_states.items()):
+            for seq, quant_state in quant_states.items():
                 num_elements[seq] = math.prod(
-                    quant_state[1].shape) // pack_ratio
+                    quant_state.shape) // pack_ratio
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
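
With the fixed loop, quant_states is consumed as a mapping from shard index to QuantState, so seq indexes num_elements directly. A worked sketch of the offset arithmetic for two hypothetical stacked shards, with pack_ratio = 2 because NF4 packs two 4-bit values per uint8:

    import math

    import numpy as np

    pack_ratio = 2  # two 4-bit elements per byte
    shard_shapes = {0: (4096, 4096), 1: (1024, 4096)}  # hypothetical q/k shards
    num_elements = [0] * len(shard_shapes)
    for seq, shape in shard_shapes.items():
        num_elements[seq] = math.prod(shape) // pack_ratio
    offsets = np.concatenate(([0], np.cumsum(num_elements)))
    # offsets == [0, 8388608, 10485760]; shard i occupies
    # packed_qweight[offsets[i]:offsets[i + 1]] in the flat 4-bit buffer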