@@ -792,6 +792,7 @@ def _get_dtype(
     sharded_metadata: Optional[dict],
     state_dict: Optional[dict],
     weights_only: bool,
+    hf_quantizer: Optional[HfQuantizer] = None,
 ) -> tuple[PreTrainedConfig, torch.dtype]:
     """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
     inferred dtype. We do the following:
@@ -840,6 +841,9 @@ def _get_dtype(
         # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
         dtype = torch.get_default_dtype()
 
+    if hf_quantizer is not None:
+        hf_quantizer.update_dtype(dtype)
+
     # Get the main dtype
     if isinstance(dtype, dict):
         main_dtype = dtype.get("", torch.get_default_dtype())
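
For context, `dtype` can arrive here either as a single `torch.dtype` or as a dict keyed by sub-component name, with the empty-string key holding the main dtype. A minimal sketch of the dict lookup performed above (the `vision_config` key is purely illustrative):

```python
import torch

# Dict-form dtype: "" holds the main dtype; other keys (here a hypothetical
# "vision_config") would target sub-components.
dtype = {"": torch.bfloat16, "vision_config": torch.float32}

# Same fallback as in `_get_dtype`: use torch's default dtype (usually fp32)
# when no main entry is provided.
main_dtype = dtype.get("", torch.get_default_dtype())
print(main_dtype)  # torch.bfloat16
```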
@@ -1433,7 +1437,7 @@ def tp_plan(self, plan: dict[str, str] | None):
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
         self._pp_plan = plan
 
-    def dequantize(self):
+    def dequantize(self, dtype=None):
         """
         Potentially dequantize the model in case it has been quantized by a quantization method that support
         dequantization.
@@ -1443,7 +1447,7 @@ def dequantize(self):
         if hf_quantizer is None:
             raise ValueError("You need to first quantize your model in order to dequantize it")
 
-        return hf_quantizer.dequantize(self)
+        return hf_quantizer.dequantize(self, dtype=dtype)
 
     def _backward_compatibility_gradient_checkpointing(self):
         if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
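
A hedged usage sketch of the new `dtype` argument: load a checkpoint quantized with bitsandbytes, then dequantize back to an explicit dtype. The model name and quantization config are illustrative, and this assumes the active quantizer supports dequantization and honors the requested dtype:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Dequantize back to fp16; whether the target dtype is honored depends on the
# quantizer backing `model.hf_quantizer`.
model = model.dequantize(dtype=torch.float16)
```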
@@ -3875,8 +3879,8 @@ def from_pretrained(
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
-        hf_quantizer, config, dtype, device_map = get_hf_quantizer(
-            config, quantization_config, dtype, device_map, weights_only, user_agent
+        hf_quantizer, config, device_map = get_hf_quantizer(
+            config, quantization_config, device_map, weights_only, user_agent
         )
 
         if gguf_file:
@@ -3923,7 +3927,9 @@ def from_pretrained(
         ]
 
         # Find the correct dtype based on current state
-        config, dtype = _get_dtype(dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only)
+        config, dtype = _get_dtype(
+            dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
+        )
 
         config.name_or_path = pretrained_model_name_or_path
         model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
@@ -3932,22 +3938,18 @@ def from_pretrained(
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)
 
+        if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
+            hf_quantizer.preprocess_model(
+                model=model,
+                dtype=dtype,
+                device_map=device_map,
+                checkpoint_files=checkpoint_files,
+                use_kernels=use_kernels,
+            )
+
         # Obtain the weight conversion mapping for this model if any are registered
         weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
 
-        # make sure we use the model's config since the __init__ call might have copied it
-        config = model.config
-
-        if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
-            hf_quantizer.preprocess_model(
-                model=model,
-                device_map=device_map,
-                keep_in_fp32_modules=model._keep_in_fp32_modules,  # TODO prob no longer needed?
-                config=config,
-                checkpoint_files=checkpoint_files,
-                use_kernels=use_kernels,
-            )
-
         if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
             model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
 
@@ -3994,7 +3996,9 @@ def from_pretrained(
 
         if hf_quantizer is not None:
             model.hf_quantizer = hf_quantizer
-            hf_quantizer.postprocess_model(model, config=config)  # usually a no-op but sometimes needed
+            hf_quantizer.postprocess_model(
+                model
+            )  # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
 
         if _adapter_model_path is not None:
             adapter_kwargs["key_mapping"] = key_mapping
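
Taken together, the hunks above fix the order in which `from_pretrained` drives the quantizer hooks: `update_dtype` during dtype resolution, `preprocess_model` right after the model skeleton is built (now receiving `dtype`), and `postprocess_model` once weights are loaded and `model.hf_quantizer` is attached. A hypothetical skeleton showing that call surface, assuming the `HfQuantizer` base class from `transformers.quantizers`; this is a sketch of the hooks seen in this diff, not the recommended subclassing recipe:

```python
from transformers.quantizers import HfQuantizer

class MyQuantizer(HfQuantizer):  # hypothetical, never instantiated here
    requires_calibration = False

    def update_dtype(self, dtype):
        # called from `_get_dtype` before the main dtype is resolved
        return dtype

    def preprocess_model(self, model, **kwargs):
        # kwargs now carry dtype, device_map, checkpoint_files and use_kernels;
        # this is where modules get swapped for quantized ones (weights untouched)
        return model

    def postprocess_model(self, model, **kwargs):
        # usually a no-op; e.g. drop the quantization config when dequantizing
        return model
```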