-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【Infer Symbolic Shape BUAA No.10 No.73 No.78】Add mode , margin_cross_entropy , box_clip. op #66730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
a319455
86191e6
90aa9a6
8a9ca8c
c25fede
1190f7b
f14d22c
7a8d74f
a559de4
43ff6e9
43d5ce5
313c60a
6380d80
a8ab8ac
1c7d7ec
f4c5a12
f086511
1c4e9aa
217614f
69c72a5
591ae2e
2ef4ce1
58f5989
1f86f70
d6b9b1d
4e8ccc0
1fecfe5
520141f
6f64112
a3d35d8
d1b6a03
569964f
cdb5a9c
c24e47a
9f19580
7609187
b5c30b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -720,6 +720,46 @@ bool MatmulOpInferSymbolicShape(pir::Operation *op, | |
| return true; | ||
| } | ||
|
|
||
| bool MarginCrossEntropyOpInferSymbolicShape( | ||
| pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
| const auto &logits_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
| const auto &labels_shape_or_data = | ||
| infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
|
||
| std::vector<symbol::DimExpr> logits_dims = logits_shape_or_data.shape(); | ||
| std::vector<symbol::DimExpr> labels_dims = labels_shape_or_data.shape(); | ||
|
|
||
| size_t logits_rank = logits_dims.size(); | ||
| auto axis = logits_rank - 1; | ||
|
|
||
| for (size_t i = 0; i < logits_rank; i++) { | ||
| if (i != axis) { | ||
| infer_context->AddBroadcastableCstr(logits_dims[i], labels_dims[i]); | ||
|
||
| } | ||
| } | ||
|
|
||
| const auto one = symbol::DimExpr{1}; | ||
|
|
||
| if (labels_dims.size() > 1) { | ||
| infer_context->AddEqualCstr(logits_dims[axis], one); | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里逻辑和infer meta不太符合,和之前的相等约束一并修改后再合并吧
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
|
|
||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(0), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(logits_dims)}); | ||
|
|
||
| logits_dims[axis] = symbol::DimExpr(1); | ||
|
|
||
| infer_context->SetShapeOrDataForValue( | ||
| op->result(1), | ||
| symbol::ShapeOrDataDimExprs{ | ||
| symbol::TensorShapeOrDataDimExprs(logits_dims)}); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| bool MvOpInferSymbolicShape(pir::Operation *op, | ||
| pir::InferSymbolicShapeContext *infer_context) { | ||
| const auto &x_shape_or_data = | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.