-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[CINN]add matmul grad fuse pattern #67466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN]add matmul grad fuse pattern #67466
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
… add_matmul_grad_fuse_pattern
Sorry to inform you that 7fd453c's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
… add_matmul_grad_fuse_pattern
bool have_sum_op = false; | ||
pir::Value sum_output; | ||
pir::Value sum_input; | ||
for (auto user_it = dout.use_begin(); user_it != dout.use_end(); | ||
++user_it) { | ||
if (!user_it->owner()) { | ||
continue; | ||
} | ||
if (auto sum_op = user_it->owner()->dyn_cast<paddle::dialect::SumOp>()) { | ||
have_sum_op = true; | ||
sum_output = sum_op->result(0); | ||
sum_input = sum_op->operand_source(0); | ||
rewriter.SetInsertionPointAfter(sum_op); | ||
break; | ||
} | ||
} | ||
|
||
if (!have_sum_op) { | ||
return false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个sum op的作用是什么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个sum op 是计算linear中bias的grad的,对应的是组合算子拆解前的 add_grad的算子,add_grad 被拆成了 assign op + sum op
PR Category
CINN
PR Types
Others
Description
pcard-76996
优化matmul grad的逻辑,适配编译器场景