
Optimization for llm/gpt-3 #6570

Merged: 15 commits merged into PaddlePaddle:develop on Aug 3, 2023

Conversation

@DrownFish19 (Collaborator) commented Aug 1, 2023

PR types

Function optimization

PR changes

APIs and Docs

Description

Optimization for llm/gpt-3

  1. Update the README.md file.
  2. Replace `self.*` parameters with `config.*` parameters in modeling.py (see the sketch after this list).
  3. Add `_init_weights` for `GPTPretrainedModel`.
  4. Use `output_attentions` (replacing `need_weights`) to control whether attention weights are returned.
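
For item 2, a minimal sketch of the direction, assuming a MultiHeadAttention module that now receives only the shared config object (field names other than hidden_size / num_attention_heads are illustrative, not the exact diff):

    import paddle.nn as nn

    class MultiHeadAttention(nn.Layer):
        """Sketch only: hyperparameters are read from the shared config instead of
        being stored as standalone constructor arguments (embed_dim, kdim, vdim, ...)."""

        def __init__(self, config):
            super().__init__()
            self.config = config
            # derived values come straight from config
            self.head_dim = config.hidden_size // config.num_attention_heads
            assert (
                self.head_dim * config.num_attention_heads == config.hidden_size
            ), "hidden_size must be divisible by num_attention_heads"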

Replace parameters with config in MHA; replace GPTEmbedding ParamAttr initializer with _init_weights; modify fuse_attention_qkv parameter
paddle-bot (bot) commented Aug 1, 2023

Thanks for your contribution!

codecov (bot) commented Aug 1, 2023

Codecov Report

Merging #6570 (b072aa6) into develop (1324998) will not change coverage.
Report is 2 commits behind head on develop.
The diff coverage is n/a.

@@           Coverage Diff            @@
##           develop    #6570   +/-   ##
========================================
  Coverage    62.94%   62.94%           
========================================
  Files          531      531           
  Lines        77727    77727           
========================================
  Hits         48923    48923           
  Misses       28804    28804           

need_weights=False, #
weight_attr=None, #
bias_attr=None, #
do_recompute=False,
@ZHUI (Collaborator) commented Aug 1, 2023:

Take a look at how these parameters are used; they should all be removable.

        kdim=None, #
        vdim=None, #
        need_weights=False, # 
        weight_attr=None, #
        bias_attr=None, #
        do_recompute=False,

@DrownFish19 (Collaborator, Author) replied:

Removed kdim, vdim, need_weights, and bias_attr; kept weight_attr and do_recompute as the TransformerDecoderLayer parameter interface.
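
A minimal sketch of the slimmed-down constructor after that change (the defaults are assumptions; only the surviving parameters matter here):

    import paddle.nn as nn

    class MultiHeadAttention(nn.Layer):
        # kdim, vdim, need_weights and bias_attr are gone; weight_attr and
        # do_recompute stay so TransformerDecoderLayer can keep passing them down.
        def __init__(self, config, weight_attr=None, do_recompute=False):
            super().__init__()
            self.config = config
            self.weight_attr = weight_attr
            self.do_recompute = do_recompute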

embed_dim = config.hidden_size
self.embed_dim = config.hidden_size
self.kdim = kdim if kdim is not None else config.hidden_size
self.vdim = vdim if vdim is not None else config.hidden_size
@ZHUI (Collaborator) commented:

kdim and vdim should not be passed in separately; use config.hidden_size directly.

@DrownFish19 (Collaborator, Author) replied:

Both places now use hidden_size directly; kdim and vdim have been removed.

if num_partitions > 1:
if config.tensor_parallel_degree > 1:
@ZHUI (Collaborator) commented:

    assert self.num_heads % config.tensor_parallel_degree == 0
    self.num_heads = self.num_heads // config.tensor_parallel_degree

if isinstance(layer, (nn.Linear,
                      nn.Embedding,
                      fleet.meta_parallel.VocabParallelEmbedding)):
    # In the dygraph mode, use the `set_value` to reset the parameter directly,
@ZHUI (Collaborator) commented:

This is incomplete; see the llama implementation for reference.
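
For reference, a hedged sketch of an _init_weights that also covers the tensor-parallel layers, modeled on the llama implementation the reviewer points to (the exact class list and the use of config.initializer_range are carried over from llama as assumptions, not confirmed GPT code):

    import paddle
    import paddle.nn as nn
    from paddle.distributed import fleet

    def _init_weights(self, layer):
        # Method of GPTPretrainedModel (sketch). Besides nn.Linear / nn.Embedding /
        # VocabParallelEmbedding, also initialize the tensor-parallel linear layers,
        # as the llama model does.
        if isinstance(
            layer,
            (
                nn.Linear,
                nn.Embedding,
                fleet.meta_parallel.VocabParallelEmbedding,
                fleet.meta_parallel.ColumnParallelLinear,
                fleet.meta_parallel.RowParallelLinear,
            ),
        ):
            # In dygraph mode, `set_value` resets the parameter in place.
            if isinstance(layer.weight, paddle.Tensor):
                layer.weight.set_value(
                    paddle.tensor.normal(
                        mean=0.0,
                        std=self.config.initializer_range,
                        shape=layer.weight.shape,
                    )
                )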

@@ -682,6 +669,17 @@ def get_tensor_parallel_split_mappings(num_layers):
"layers.0.linear2.weight": partial(fn, is_column=False),
}

if config.fuse_attention_qkv:
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
@ZHUI (Collaborator) commented:

Please verify forward precision: single card vs. tp=2.

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.use_flash_attn = config.use_flash_attn if flash_attention else None
@ZHUI (Collaborator) commented:

use_flash_attention

Using the script below, you can continue training on top of llama-7b.
Notes:
1. Training requires the paddle develop build; install the missing wheel packages, e.g. `pip install tool_helpers visualdl==2.5.3`.
2. `use_flash_attn` must be enabled on an A100 machine; otherwise the loss may become abnormal (quickly dropping to 0.00x, abnormally small). A CUDA 11.8 environment is recommended.
@ZHUI (Collaborator) commented:

Suggested change
2. `use_flash_attn` must be enabled on an A100 machine; otherwise the loss may become abnormal (quickly dropping to 0.00x, abnormally small). A CUDA 11.8 environment is recommended.
2. `use_flash_attention` must be enabled on an A100 machine; otherwise the loss may become abnormal (quickly dropping to 0.00x, abnormally small). A CUDA 11.8 environment is recommended.

export PYTHONPATH="../../PaddleNLP/"
export FLAGS_cudnn_deterministic=True
log_dir="log"
rm -rf $log_dir

python -u -m paddle.distributed.launch \
--gpus "0" \
--gpus "6,7" \
@ZHUI (Collaborator) commented:

Suggested change
--gpus "6,7" \
--gpus "0" \


if config.tensor_parallel_degree > 1:
assert config.num_attention_heads % config.tensor_parallel_degree == 0
config.num_attention_heads = config.num_attention_heads // config.tensor_parallel_degree
@ZHUI (Collaborator) commented:

Suggested change
config.num_attention_heads = config.num_attention_heads // config.tensor_parallel_degree
self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_degree

Where the original variable is being modified, assign the result to a new variable instead; do not modify config directly.
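
In other words, keep the shard-local head count on the instance and leave the config untouched. A minimal sketch of the suggested pattern inside __init__ (the else branch is an added assumption for the single-card case):

    # Do not mutate config: it keeps the full (unsplit) head count;
    # the per-shard value lives on the instance instead.
    if config.tensor_parallel_degree > 1:
        assert config.num_attention_heads % config.tensor_parallel_degree == 0
        self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_degree
    else:
        self.num_attention_heads = config.num_attention_heads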

@@ -270,10 +233,10 @@ def gen_cache(self, key, value=None, type=Cache):
return self.StaticCache(k, v)
elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like(
input=key, shape=[-1, self.num_heads, 0, self.head_dim], dtype=key.dtype, value=0
input=key, shape=[-1, self.config.num_attention_heads, 0, self.head_dim], dtype=key.dtype, value=0
@ZHUI (Collaborator) commented:

Suggested change
input=key, shape=[-1, self.config.num_attention_heads, 0, self.head_dim], dtype=key.dtype, value=0
input=key, shape=[-1, self.num_attention_heads, 0, self.head_dim], dtype=key.dtype, value=0

)
v = layers.fill_constant_batch_size_like(
input=key, shape=[-1, self.num_heads, 0, self.head_dim], dtype=key.dtype, value=0
input=key, shape=[-1, self.config.num_attention_heads, 0, self.head_dim], dtype=key.dtype, value=0
@ZHUI (Collaborator) commented:

Suggested change
input=key, shape=[-1, self.config.num_attention_heads, 0, self.head_dim], dtype=key.dtype, value=0
input=key, shape=[-1, self.num_attention_heads, 0, self.head_dim], dtype=key.dtype, value=0

# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False

config.use_flash_attention = config.use_flash_attention if flash_attention else None
@ZHUI (Collaborator) commented:

Suggested change
config.use_flash_attention = config.use_flash_attention if flash_attention else None
self.use_flash_attention = config.use_flash_attention if flash_attention else None


out = paddle.matmul(weights, v)

# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, -1])

return (out, weights) if self.need_weights else out
return (out, weights) if self.config.need_weights else out
@ZHUI (Collaborator) commented:

https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L299C9-L299C26

Following this reference, replace self.config.need_weights with an output_attentions parameter on the forward function.
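
A hedged sketch of how the attention forward could take the flag per call instead of reading it from the config (signature details beyond output_attentions are assumptions, and the attention computation itself is elided):

    import paddle
    from paddle import tensor

    def forward(self, query, key, value, attn_mask=None, cache=None, output_attentions=False):
        # ... projections, scaled dot-product and softmax producing `weights` and `v` elided ...
        out = paddle.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, -1])

        # expose attention weights per call (as output_attentions does in
        # huggingface/transformers GPT-2) instead of via self.config.need_weights
        return (out, weights) if output_attentions else out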

@ZHUI (Collaborator) left a comment:

LGTM

@ZHUI ZHUI merged commit 435cb4f into PaddlePaddle:develop Aug 3, 2023
6 checks passed