Skip to content

Commit

Permalink
fix npu sft ckpt load bug and no FA bug
Browse files Browse the repository at this point in the history
  • Loading branch information
NINGBENZHE committed May 14, 2024
1 parent 05acad5 commit 8a0d92a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 57 deletions.
89 changes: 38 additions & 51 deletions llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,43 @@ def main():
weight_double_quant_block_size=model_args.weight_double_quant_block_size,
)

model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention

model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
model_config.recompute_granularity = model_args.recompute_granularity
model_config.virtual_pp_degree = model_args.virtual_pp_degree
model_config.sequence_parallel = model_args.sequence_parallel
model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
model_config.use_fused_rope = model_args.use_fused_rope

model_config.no_recompute_layers = model_args.no_recompute_layers
model_config.pp_recompute_interval = model_args.pp_recompute_interval
model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
model_config.use_recompute = training_args.recompute

model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
model_config.tensor_parallel_rank = training_args.tensor_parallel_rank

# Config for model using dropout, such as GPT.
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = True
model_config.seq_length = data_args.max_length

if training_args.pipeline_parallel_degree > 1:
if data_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
Expand All @@ -145,63 +182,13 @@ def main():
if not training_args.autotuner_benchmark:
model = AutoModelForCausalLMPipe.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
use_flash_attention=model_args.use_flash_attention,
dtype=dtype,
config=model_config,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
else:
# NOTE(gongenlei): new add autotuner_benchmark
model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
model = AutoModelForCausalLMPipe.from_config(model_config, dtype=dtype)
else:
model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention

model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
model_config.recompute_granularity = model_args.recompute_granularity
model_config.virtual_pp_degree = model_args.virtual_pp_degree
model_config.sequence_parallel = model_args.sequence_parallel
model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
model_config.use_fused_rope = model_args.use_fused_rope

model_config.no_recompute_layers = model_args.no_recompute_layers
model_config.pp_recompute_interval = model_args.pp_recompute_interval
model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
model_config.use_recompute = training_args.recompute

model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
model_config.tensor_parallel_rank = training_args.tensor_parallel_rank

# Config for model using dropout, such as GPT.
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = True
model_config.seq_length = data_args.max_length
if not training_args.autotuner_benchmark:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
Expand Down
12 changes: 6 additions & 6 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,9 +1381,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
input_shape, past_key_values_length=past_key_values_length
)
if get_env_device() == "npu":
expanded_attn_mask = expanded_attn_mask.astype("bool")
combined_attention_mask = combined_attention_mask.astype("bool")
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")

Check warning on line 1384 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L1384

Added line #L1384 was not covered by tests
else:
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
Expand All @@ -1394,9 +1394,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
if get_env_device() == "npu":
x = paddle.to_tensor(0.0, dtype="float16")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = expanded_attn_mask.astype("float32")

Check warning on line 1399 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L1397-L1399

Added lines #L1397 - L1399 were not covered by tests
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
Expand Down
6 changes: 6 additions & 0 deletions paddlenlp/transformers/llama/modeling_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
from paddle.distributed.fleet.utils import recompute

from paddlenlp.utils.tools import get_env_device
from paddlenlp.transformers.model_utils import PipelinePretrainedModel

from .modeling import (
Expand Down Expand Up @@ -153,6 +154,11 @@ def forward(self, args):
attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
)
attention_mask.stop_gradient = True
if get_env_device() == "npu":
attention_mask = attention_mask.astype("bool")
elif get_env_device() == "npu":
attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
attention_mask.stop_gradient = True

Check warning on line 161 in paddlenlp/transformers/llama/modeling_pp.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_pp.py#L157-L161

Added lines #L157 - L161 were not covered by tests

if self.config.alibi and attention_mask is None:
attention_mask = LlamaModel._prepare_decoder_attention_mask(
Expand Down

0 comments on commit 8a0d92a

Please sign in to comment.