diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index 33cfd32e9307..99b1411f0650 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -230,12 +230,12 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
 
         if self.is_pipelinemodel:
             self.model._single_to_pp_mapping = None
-        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
             )
-        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
diff --git a/paddlenlp/quantization/quantization_utils.py b/paddlenlp/quantization/quantization_utils.py
index 1bbc3c4ade9b..fe46efd2a2fa 100644
--- a/paddlenlp/quantization/quantization_utils.py
+++ b/paddlenlp/quantization/quantization_utils.py
@@ -109,9 +109,12 @@ def convert_to_quantize_state_dict_with_check(state_dict, quantization_linear_li
                 raise ValueError(
                     f"{quant_weight_name} should be {paddle.int8} in state_dict but received dtype {state_dict[quant_weight_name].dtype}."
                 )
-            if state_dict[quant_scale_name].dtype != paddle.float32:
+            if (
+                state_dict[quant_scale_name].dtype != paddle.float16
+                and state_dict[quant_scale_name].dtype != paddle.bfloat16
+            ):
                 raise ValueError(
-                    f"{quant_scale_name} should be {paddle.float32} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
+                    f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
                 )
         elif weight_name in state_dict:
             target_weight = state_dict.pop(weight_name).cast(dtype)
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index 1df4c8ddd995..b27dce92116d 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -1776,9 +1776,15 @@ def _load_pretrained_model(
                 loaded_keys, quantization_linear_list, config.quantization_config
             )
             if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = ["quant_scale"]
+                keep_in_fp32_modules = (
+                    ["quant_scale"] if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] else None
+                )
             else:
-                keep_in_fp32_modules += ["quant_scale"]
+                keep_in_fp32_modules = (
+                    keep_in_fp32_modules + ["quant_scale"]
+                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+                    else keep_in_fp32_modules
+                )
 
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
@@ -2200,7 +2206,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             logger.info("Loaded weights file from disk, setting weights to model.")
 
         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and dtype == "float16"
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            dtype == "float16" or dtype == "bfloat16"
+        )
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
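
For reviewers, below is a minimal standalone sketch of the tightened quant-scale dtype check in `convert_to_quantize_state_dict_with_check`: after this patch, scales stored as `float16` or `bfloat16` pass, while the previously expected `float32` is rejected. The helper name `check_quant_scale_dtype` and the toy tensors are illustrative only and are not part of the patch.

```python
import paddle


def check_quant_scale_dtype(quant_scale, quant_scale_name="linear_0.quant_scale"):
    # Mirrors the patched validation: only half-precision quant scales are accepted.
    if quant_scale.dtype != paddle.float16 and quant_scale.dtype != paddle.bfloat16:
        raise ValueError(
            f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} "
            f"but received dtype {quant_scale.dtype}."
        )


check_quant_scale_dtype(paddle.ones([4], dtype="float16"))                    # accepted
check_quant_scale_dtype(paddle.ones([4], dtype="float32").cast("bfloat16"))   # accepted
try:
    check_quant_scale_dtype(paddle.ones([4], dtype="float32"))                # rejected after this patch
except ValueError as err:
    print(err)
```

This lines up with the `model_utils.py` changes: `use_keep_in_fp32_modules` now also triggers for `bfloat16` models, and `quant_scale` is only forced into fp32 when `weight_quantize_algo` is `nf4` or `fp4`, so other quantization schemes keep their scales in the model's half-precision dtype.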