
Conversation

@zhangboSJTU
Contributor

PR types

Bug fixes

PR changes

OPs

Description

Fix the compile error in the XPU2 KP build.

@paddle-bot

paddle-bot bot commented May 6, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@zhangboSJTU zhangboSJTU requested a review from AnnaTrainingG May 6, 2023 04:57
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
#ifdef PADDLE_WITH_XPU_KP
Contributor

Is the vec_size at line 89 related only to out? Before your change, the code depended on both in and out; I'm not sure whether this could hide a performance problem.

Contributor Author
@zhangboSJTU zhangboSJTU May 8, 2023

For reference: elementwise also takes only the out vec_size.

Contributor

elementwise can do that because its dims are identical, whereas for broadcast the input and output dims may differ…

Contributor Author

vec_size was originally min(in, out, 4) and is now min(out, 4), so the new value is >= the previous one and should not cause a performance regression. Is there some other consideration that requires adding the inputs back?
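
To make the comparison concrete, here is a minimal, self-contained C++ sketch of the two policies being debated (helper names are invented for illustration; this is not Paddle's actual code). Because a minimum taken over a subset of buffers can only be greater than or equal to the minimum over all of them, the outputs-only vec_size is never smaller than the old one:

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: widest vector width (in elements) a buffer supports,
// derived from its address alignment, capped at 4.
int VecWidthOf(const void* ptr, int elem_bytes) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  for (int w = 4; w > 1; w /= 2) {
    if (addr % (static_cast<std::uintptr_t>(w) * elem_bytes) == 0) return w;
  }
  return 1;
}

// Old policy: minimum over inputs and outputs.
int VecSizeInsAndOuts(const std::vector<const void*>& ins,
                      const std::vector<const void*>& outs, int elem_bytes) {
  int v = 4;
  for (const void* p : ins) v = std::min(v, VecWidthOf(p, elem_bytes));
  for (const void* p : outs) v = std::min(v, VecWidthOf(p, elem_bytes));
  return v;
}

// New policy: minimum over outputs only; never smaller than the old value.
int VecSizeOutsOnly(const std::vector<const void*>& outs, int elem_bytes) {
  int v = 4;
  for (const void* p : outs) v = std::min(v, VecWidthOf(p, elem_bytes));
  return v;
}

This is also why the reviewer asks about hidden costs: an output-derived width is always valid for the stores, but input-side accesses might not benefit from it.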

__simd__ ArgsT args[VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];

#ifdef PADDLE_WITH_XPU_KP
Contributor

Why does KP need to be special-cased here?

Contributor

XPU KP's broadcast behavior is the same as GPU's, isn't it?

Contributor Author

Earlier, Mingshu added a specialized optimization to the GPU broadcast path (removing the redundant fast_divmod computation inside it). To keep that optimization effective, the GPU path has to be split out here.
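
For context on the optimization being preserved: fast_divmod replaces hardware integer division with a precomputed multiply-and-shift, and the GPU specialization avoids recomputing that state for every element when decomposing linear indices into broadcast coordinates. A minimal sketch of the technique follows (the standard magic-number method, not Paddle's exact implementation; it assumes divisors below 2^31, which tensor dims satisfy):

#include <cstdint>

// For divisor d, pick s = ceil(log2(d)) and m = ceil(2^(32+s) / d); then
// floor(n / d) == ((((n * (m - 2^32)) >> 32) + n) >> s) for all 32-bit n.
struct FastDivModSketch {
  uint32_t divisor;
  uint32_t shift;       // s = ceil(log2(divisor))
  uint32_t multiplier;  // m - 2^32; fits in 32 bits for divisor < 2^31

  explicit FastDivModSketch(uint32_t d) : divisor(d), shift(0) {
    while ((1ull << shift) < d) ++shift;
    uint64_t m = ((1ull << (32 + shift)) + d - 1) / d;  // ceil(2^(32+s) / d)
    multiplier = static_cast<uint32_t>(m - (1ull << 32));
  }

  uint32_t Div(uint32_t n) const {
    uint64_t hi = (static_cast<uint64_t>(n) * multiplier) >> 32;
    return static_cast<uint32_t>((hi + n) >> shift);
  }

  void Divmod(uint32_t n, uint32_t* q, uint32_t* r) const {
    *q = Div(n);
    *r = n - *q * divisor;  // remainder via one multiply, no second division
  }
};

Each broadcast coordinate then costs a multiply and two shifts instead of a hardware division, which is why hoisting the repeated computation out of the inner loop is worth a GPU-only code path.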

} else {
  BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
-     ins, args, configs, use_broadcast, block_offset, num, numel);
+     ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
Contributor

read_lens is for XPU KP. This code is already inside the else branch, so why does read_lens still need to be added here?

Contributor Author

As in the previous comment: the GPU part is specialized, but KP and GPU share the same non-specialized function, so the parameter lists have to stay consistent.
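
A schematic, self-contained sketch of that constraint (identifiers are invented, not Paddle's actual code): the GPU-only specialization sits behind the macro, while the generic fallback is instantiated on both GPU and XPU KP and therefore must accept read_lens on both backends, even though GPU always passes read_lens == VecSize.

#include <algorithm>

// Shared non-specialized path: one signature for GPU and KP.
template <int VecSize>
void GenericLoad(const float* src, float* dst, int num, int read_lens) {
  // On GPU read_lens == VecSize; on XPU KP it is a runtime value.
  for (int i = 0; i < std::min(num, read_lens); ++i) dst[i] = src[i];
}

template <int VecSize>
void Load(const float* src, float* dst, int num, int read_lens,
          bool use_specialized_path) {
#ifndef PADDLE_WITH_XPU_KP
  if (use_specialized_path) {
    // GPU-only specialization: compile-time width, precomputed index math.
    for (int i = 0; i < VecSize; ++i) dst[i] = src[i];
    return;
  }
#endif
  GenericLoad<VecSize>(src, dst, num, read_lens);
}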

}
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
std::get<Index>(dst[idx]) = in_temp[0];
Contributor

This is wrong! read_lens is the vectorized read width, which means read_lens elements should end up stored in dst, but this loop only ever writes in_temp[0]. Wrong!!!

Contributor Author

There was no IsBoundary condition here, and I was unsure of the relationship between read_lens and Nx, so I didn't know exactly how the data was being read. It's clear now, and this has been fixed.
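
For completeness, a hedged sketch of the corrected store (the committed code may differ in detail): each of the read_lens vectorized elements is scattered into its own dst slot instead of broadcasting in_temp[0].

#include <cstddef>
#include <tuple>

template <size_t Index, typename ArgsT, typename T>
void ScatterVectorized(ArgsT* dst, const T* in_temp, int read_lens) {
  for (int idx = 0; idx < read_lens; ++idx) {
    std::get<Index>(dst[idx]) = in_temp[idx];  // was in_temp[0] before the fix
  }
}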

@zhangboSJTU zhangboSJTU requested a review from AnnaTrainingG May 9, 2023 02:42
Contributor
@ZzSean ZzSean left a comment

LGTM for CI-OP-Benchmark

@AnnaTrainingG AnnaTrainingG changed the title fix kp compile bug Fix xpu2 kp compile error May 9, 2023
@AnnaTrainingG AnnaTrainingG merged commit 8d340ee into PaddlePaddle:develop May 9, 2023
zhangboSJTU added a commit to zhangboSJTU/Paddle that referenced this pull request May 9, 2023
XiaoguangHu01 pushed a commit that referenced this pull request May 10, 2023
…to Release/2.5 (#53623)

* Support different dtypes of inputs for broadcast for dropout optimization  (#52093)

* change judgement for DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize and after this Loaddata to be refined

* pass unittest

* use same unroller with XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errs in ROCms

* PR comment

* dropout_nd_optimization (#51479)

* with printf

* add DropOutNdForwardKernel

* PR comment

* Dropout optimize & clean broadcast inT and ElementwiseType (#52969)

* change judgement for DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize and after this Loaddata to be refined

* pass unittest

* use same unroller with XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errs in ROCms

* clean ElementwiseT and InT for BroadcastKernel

* default axis and clean inT

* remove redundant fast divmod computation

* optimize drop_nd & drop_nd_grad

* optimize BroadcastDataLoader bf16 fp16

* rm InT etc. after merge develop

* delete constexpr for windows ci

* fix conflict

* fix conflic with develop

* fix conflic

* new clean

* clean

* Fix xpu2 kp compile error (#53548)

* fix conflict

* conflict
