
fix lora (PaddlePaddle#7824)
lugimzzz committed Jan 11, 2024
1 parent d4acbfc commit b312634
Showing 3 changed files with 18 additions and 7 deletions.
4 changes: 2 additions & 2 deletions paddlenlp/peft/lora/lora_model.py
@@ -231,12 +231,12 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal

         if self.is_pipelinemodel:
             self.model._single_to_pp_mapping = None
-        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
             )
-        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
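
The one-character fix above matters: the LoRA config exposes tensor_parallel_degree, not tensor_parallel_degre, so the old condition would raise AttributeError as soon as save_pretrained was called with merge_tensor_parallel=True on a quantized or pipeline-parallel model. A minimal sketch of the failure mode, using a hypothetical stand-in config rather than the real LoRAConfig:

from dataclasses import dataclass


@dataclass
class StandInLoRAConfig:
    # Hypothetical stand-in for the real LoRAConfig; only the field that
    # matters for this diff is modeled here.
    tensor_parallel_degree: int = 2


config = StandInLoRAConfig()

try:
    # Old spelling: the attribute does not exist, so the check itself crashes
    # before merge_tensor_parallel can be disabled.
    if config.tensor_parallel_degre > 1:
        pass
except AttributeError as err:
    print(f"old check fails: {err}")

# Fixed spelling: the condition evaluates normally and the warning path runs.
if config.tensor_parallel_degree > 1:
    print("merge_tensor_parallel would be set to False with a warning")
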
7 changes: 5 additions & 2 deletions paddlenlp/quantization/quantization_utils.py
@@ -109,9 +109,12 @@ def convert_to_quantize_state_dict_with_check(state_dict, quantization_linear_li
                 raise ValueError(
                     f"{quant_weight_name} should be {paddle.int8} in state_dict but received dtype {state_dict[quant_weight_name].dtype}."
                 )
-            if state_dict[quant_scale_name].dtype != paddle.float32:
+            if (
+                state_dict[quant_scale_name].dtype != paddle.float16
+                and state_dict[quant_scale_name].dtype != paddle.bfloat16
+            ):
                 raise ValueError(
-                    f"{quant_scale_name} should be {paddle.float32} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
+                    f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
                 )
         elif weight_name in state_dict:
             target_weight = state_dict.pop(weight_name).cast(dtype)
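
The change above relaxes the scale-dtype validation: quantization scales loaded from a checkpoint are now expected in half precision (float16 or bfloat16) instead of float32. A standalone sketch of the new check; check_quant_scale_dtype is invented here for illustration and is not a PaddleNLP function:

import paddle


def check_quant_scale_dtype(state_dict, quant_scale_name):
    # Mirrors the updated validation: reject anything that is not fp16/bf16.
    scale_dtype = state_dict[quant_scale_name].dtype
    if scale_dtype not in (paddle.float16, paddle.bfloat16):
        raise ValueError(
            f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} "
            f"in state_dict but received dtype {scale_dtype}."
        )


# A float16 scale passes the new check ...
check_quant_scale_dtype({"linear.quant_scale": paddle.ones([4], dtype="float16")}, "linear.quant_scale")

# ... while a float32 scale, which the old code required, is now rejected.
try:
    check_quant_scale_dtype({"linear.quant_scale": paddle.ones([4], dtype="float32")}, "linear.quant_scale")
except ValueError as err:
    print(err)
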
14 changes: 11 additions & 3 deletions paddlenlp/transformers/model_utils.py
@@ -1759,9 +1759,15 @@ def _load_pretrained_model(
                 loaded_keys, quantization_linear_list, config.quantization_config
             )
             if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = ["quant_scale"]
+                keep_in_fp32_modules = (
+                    ["quant_scale"] if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] else None
+                )
             else:
-                keep_in_fp32_modules += ["quant_scale"]
+                keep_in_fp32_modules = (
+                    keep_in_fp32_modules + ["quant_scale"]
+                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+                    else keep_in_fp32_modules
+                )

         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
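
The reworked branch above makes the fp32 upcast of quant_scale conditional on the quantization algorithm: only the 4-bit schemes ("nf4", "fp4") keep their scales in float32, while other weight-quantize algorithms leave keep_in_fp32_modules untouched, which lines up with the fp16/bf16 scale check in quantization_utils.py. A small illustrative helper; resolve_keep_in_fp32_modules is invented for this sketch and is not a PaddleNLP API:

def resolve_keep_in_fp32_modules(keep_in_fp32_modules, weight_quantize_algo):
    # Only 4-bit quantization ("nf4"/"fp4") forces quant_scale to stay in fp32;
    # every other algorithm returns the list unchanged (possibly None).
    if weight_quantize_algo in ["nf4", "fp4"]:
        return (keep_in_fp32_modules or []) + ["quant_scale"]
    return keep_in_fp32_modules


print(resolve_keep_in_fp32_modules(None, "nf4"))               # ['quant_scale']
print(resolve_keep_in_fp32_modules(["norm"], "fp4"))           # ['norm', 'quant_scale']
print(resolve_keep_in_fp32_modules(None, "weight_only_int8"))  # None
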
@@ -2173,7 +2179,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             logger.info("Loaded weights file from disk, setting weights to model.")

         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and dtype == "float16"
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            dtype == "float16" or dtype == "bfloat16"
+        )

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
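
The second hunk extends the same idea to the load path: modules listed in _keep_in_fp32_modules are now upcast not only when loading in float16 but also in bfloat16. A brief sketch of the new condition; the helper name is invented for illustration:

def should_apply_keep_in_fp32(keep_in_fp32_modules, dtype):
    # fp32 upcasting now triggers for both half-precision load dtypes.
    return (keep_in_fp32_modules is not None) and dtype in ("float16", "bfloat16")


print(should_apply_keep_in_fp32(["quant_scale"], "bfloat16"))  # True (new behavior)
print(should_apply_keep_in_fp32(["quant_scale"], "float16"))   # True (unchanged)
print(should_apply_keep_in_fp32(["quant_scale"], "float32"))   # False
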
