Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GCU] Support llama for GCU #8445

Merged
merged 1 commit into from
May 17, 2024

Conversation

EnflameGCU
Copy link
Contributor

PR types

New features

PR changes

Models

Description

Support llama for GCU

Copy link

paddle-bot bot commented May 15, 2024

Thanks for your contribution!

Copy link

codecov bot commented May 15, 2024

Codecov Report

Attention: Patch coverage is 36.00000% with 16 lines in your changes missing coverage. Please review.

Project coverage is 54.29%. Comparing base (5170664) to head (32d66ef).
Report is 208 commits behind head on develop.

Files Patch % Lines
paddlenlp/transformers/llama/fusion_ops.py 10.00% 9 Missing ⚠️
paddlenlp/transformers/llama/modeling.py 54.54% 5 Missing ⚠️
paddlenlp/generation/utils.py 50.00% 1 Missing ⚠️
paddlenlp/utils/tools.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8445      +/-   ##
===========================================
- Coverage    55.42%   54.29%   -1.14%     
===========================================
  Files          617      617              
  Lines        96286    96340      +54     
===========================================
- Hits         53367    52303    -1064     
- Misses       42919    44037    +1118     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -1528,7 +1535,7 @@ def forward(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
is_casual = False
if self.config.use_flash_attention:
if self.config.use_flash_attention and get_env_device() != "gcu":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在attention mask的处理上,GCU不一样的地方是什么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

基于 use_flash_attention kernel 的实现,is_casual 情况下也是需要当前与输入相同dtypeattention_mask,而不是None或者bool类型的mask

@@ -297,6 +303,7 @@ def do_generation():
parser = get_eval_parser()
args = parser.parse_args()
paddle.set_default_dtype(args.dtype)
paddle.set_device(args.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在训练初始的位置设置set_device,这里再重新设置的原因是什么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单独使用的eval.py测试,没有训练初始位置? 当然不设置默认应该也是当前的device,可以去除

@@ -934,7 +941,7 @@ def forward(
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一定要加 cos_sin 的优化吗?代码改动很大,而且会导致其他设备性能下降,凭空多了很多开销。

或者你们需要的时候再自己去造一个 cos_sin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里主要是因为算子的实现与paper或者vllm一致,使用了与这里不同的sin/cos。关于其他设备性能开销,一方面,应该大多table的计算只在初始化阶段,另一方面,我们将按照第一个issue的建议,在特定设备进行计算,这里仅仅只会多返回一个None

@wawltor wawltor merged commit d9dcd9a into PaddlePaddle:develop May 17, 2024
8 of 11 checks passed
@ZHUI
Copy link
Collaborator

ZHUI commented May 20, 2024

https://xly.bce.baidu.com/paddlepaddle/Paddle-NLP/newipipe/detail/10720664/job/26276076

这个PR的 rope 接口改动,貌似导致自动并行代码挂了

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants