Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DCU] new features #63721

Merged
merged 27 commits into from
Jun 13, 2024
Merged

[DCU] new features #63721

merged 27 commits into from
Jun 13, 2024

Conversation

yuguo-Jack
Copy link
Contributor

@yuguo-Jack yuguo-Jack commented Apr 21, 2024

PR Category

Performance Optimization

PR Types

New features

Description

surpport multiclass_nms3 op for DCU(单测通过)
surpport miopen bn for DCU when FLAGS_cudnn_batchnorm_spatial_persistent is 1(test_batch_norm_op/test_batch_norm_op_v2单测通过)
surpport gemm fp16 compute type for DCU when FLAGS_gemm_use_half_precision_compute_type is 1

支持flash attention(mha,gqa前反向,单测通过)
支持block attention相关算子(支持prefix precache,单测通过)
支持a8w8相关算子(单测通过)
支持quant_linear相关算子(单测通过)
支持kv cache int8相关算子(单测通过)
支持weight only量化反量化相关算子(单测通过)
支持fused rope相关算子(单测通过)

Copy link

paddle-bot bot commented Apr 21, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 21, 2024
Copy link

paddle-ci-bot bot commented May 2, 2024

Sorry to inform you that 3d36c6f's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

Copy link

paddle-ci-bot bot commented May 29, 2024

Sorry to inform you that 5326497's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

qili93
qili93 previously approved these changes Jun 12, 2024
Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@qili93 qili93 requested a review from XiaoguangHu01 June 12, 2024 07:53
@@ -106,7 +106,12 @@ void per_channel_quant(int8_t* output,
static_cast<float>(current_weight_row[input_idx]);
const float scaled_weight = round(weight_elt / col_scale);
int int_weight = static_cast<int>(scaled_weight);
#ifdef PADDLE_WITH_HIP
const int8_t clipped_weight =
std::max(-7, std::min(7, int_weight)) + 8;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要+8的原因是什么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为int4反量化时会将高低四位按uint8解析,所以量化时将数值移至1-15,反量化再减去,会方便int4反量化kernel的实现,否则判断符号位这类操作会产生线程束分化。因为weight only在dcu上没有优化强转,所以weight only量化反量化流程和nv有较大不同。

@@ -255,25 +255,42 @@ void FlashAttnUnpaddedGradBaseKernel(
kdq = &dq_tmp;
}

#ifdef PADDLE_WITH_HIP
std::initializer_list<int64_t> dk_dv_input_shape = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里dk_dv_input_shape和dk_dv_shape的区别是什么,形状不相同的原因是什么?
是否可以用相同的变量名?

Copy link
Contributor Author

@yuguo-Jack yuguo-Jack Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gqa逻辑是参考我们适配的fa 2.0.4的cpp接口进行的修改,以定长接口为例:
image
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经按照建议修复

Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@qili93 qili93 merged commit c66533f into PaddlePaddle:develop Jun 13, 2024
30 of 33 checks passed
qili93 added a commit that referenced this pull request Jun 13, 2024
yuanlehome pushed a commit that referenced this pull request Jun 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants