Skip to content

【Infer Symbolic Shape No.139,140,141,142】[BUAA] Add generate_proposals,grid_sample,gru,gru_unit op #67413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 30, 2024

Conversation

Luohongzhige
Copy link
Contributor

@Luohongzhige Luohongzhige commented Aug 14, 2024

PR Category

CINN

PR Types

Improvements

Description

#66444
添加generate_proposals,grid_sample,gru,gru_unit算子符号推导接口
gru,gru_unit有op_test但未开启check_pir
generate_proposals关闭check_symbol_infer

Copy link

paddle-bot bot commented Aug 14, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@luotao1 luotao1 added contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务 labels Aug 14, 2024
// }
bool GenerateProposalsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return GenerateProposalsV2OpInferSymbolicShape(op, infer_context);
Copy link
Contributor

@gongshaotian gongshaotian Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yaml定义里没有GenerateProposalsV2这个op,直接用op名就行

Comment on lines 1437 to 1438
// If bias is used, check its dimensions
if (op->num_operands() > 3) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改为判断是否为null value

Comment on lines 1418 to 1419
auto input_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const auto &,下同

Comment on lines 1312 to 1313
std::vector<symbol::DimExpr> rpn_rois_shape = {out_unknown_1,
symbol::DimExpr(4)};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该是batch_size能确定吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的, @gongshaotian 说的kernel代码之后还对变量进行了一次resize,使用的参数就是运行时确定的了

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel实现是可能做了点特殊处理,但至少两个out_unknown应该是同一个值

Comment on lines 1395 to 1407
if (!is_test) {
symbol::TensorShapeOrDataDimExprs batch_gate_shape(input_shape);
infer_context->SetShapeOrDataForValue(op->result(0), batch_gate_shape);

symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape(
{input_shape[0], frame_size});
infer_context->SetShapeOrDataForValue(op->result(1),
batch_reset_hidden_prev_shape);

symbol::TensorShapeOrDataDimExprs batch_hidden_shape(
{input_shape[0], frame_size});
infer_context->SetShapeOrDataForValue(op->result(2), batch_hidden_shape);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

非is_test直接用静态shape赋值接口上新符号吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不太明白此处需要如何修改,是不创建临时变量直接写进SetShapeOrDataForValue函数参数里边吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

每个输出都需要设置symbolic shape,非is_test模式也需要,用infer_context->SetSymbolForValueByStaticShape(xxx)

Comment on lines 1387 to 1391
} else {
infer_context->SetSymbolForValueByStaticShape(op->result(0));
infer_context->SetSymbolForValueByStaticShape(op->result(1));
infer_context->SetSymbolForValueByStaticShape(op->result(2));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

test为true时,这几个输出看起来维度分别为 input 和 hidden 的维度

Copy link
Contributor

@gongshaotian gongshaotian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

恭喜解决了第一个Kernel复用Meta推导结果的算子🎉

bool GenerateProposalsV2OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
symbol::DimExpr out_unknown_1 = infer_context->GetNextSymName();
symbol::DimExpr out_unknown_2 = infer_context->GetNextSymName();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从kernel看这两个符号表示的维度是相等的,使用一个即可

// // pass
// return true;
// }
bool GridSampleOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for GridSampleOpInferSymbolicShape

return true;
}

bool GruUnitOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for GruUnitOpInferSymbolicShape

// // pass
// return true;
// }
bool GenerateProposalsOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for GenerateProposalsOp

infer_context->AddEqualCstr(bias_shape[1], frame_size * 3);
}

if (!is_test) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分逻辑修改为:
if(is_test){
// 参考kernel的 is_test 分支实现
image
}esle{
// 参考meta的 !is_test 实现
image
}

原因为Meta在kernel执行前推导了!is_test分支下部分输出的shape,GruKernel直接使用了Meta的推导结果

luotao1
luotao1 previously approved these changes Aug 28, 2024
@luotao1
Copy link
Contributor

luotao1 commented Aug 28, 2024

冲突了

@luotao1 luotao1 merged commit 3a95bcc into PaddlePaddle:develop Aug 30, 2024
29 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants