Skip to content

Commit c796972

Browse files
authored
[Infer Symbolic Shape No.157] Add symbol_infer_interface for matrix_rank_tol (#68975)
* Add symbol_infer_interface for matrix_rank_tol * Fix comparison * Refine logic * empty commit * empty commit * Fix ci * remove useless function
1 parent e949764 commit c796972

File tree

3 files changed

+60
-7
lines changed

3 files changed

+60
-7
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,12 +1333,64 @@ bool LuUnpackOpInferSymbolicShape(
13331333
return true;
13341334
}
13351335

1336-
// bool MatrixRankTolOpInferSymbolicShape(pir::Operation *op,
1337-
// pir::InferSymbolicShapeContext
1338-
// *infer_context) {
1339-
// // pass
1340-
// return true;
1341-
// }
1336+
bool MatrixRankTolOpInferSymbolicShape(
1337+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1338+
const auto &x_shape =
1339+
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
1340+
const auto &tol_shape_or_data =
1341+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
1342+
std::vector<symbol::DimExpr> tol_shape = tol_shape_or_data.shape();
1343+
size_t x_rank = x_shape.size();
1344+
PADDLE_ENFORCE_GE(x_rank,
1345+
2,
1346+
common::errors::InvalidArgument(
1347+
"The dims of input must be greater than 2"));
1348+
bool hermitian = GetBoolAttr(op, "hermitian");
1349+
if (hermitian) {
1350+
infer_context->AddEqualCstr(x_shape[x_rank - 2], x_shape[x_rank - 1]);
1351+
}
1352+
std::vector<symbol::DimExpr> x_shape_batch = [&] {
1353+
std::vector<symbol::DimExpr> x_shape_batch;
1354+
for (size_t i = 0; i < x_rank - 2; ++i) {
1355+
x_shape_batch.push_back(x_shape[i]);
1356+
}
1357+
return x_shape_batch;
1358+
}();
1359+
1360+
int diff = x_shape_batch.size() - tol_shape.size();
1361+
if (diff > 0) {
1362+
for (int i = 0; i < diff; i++) {
1363+
tol_shape.emplace(tol_shape.begin(), 1);
1364+
}
1365+
} else {
1366+
for (int i = 0; i < -diff; i++) {
1367+
x_shape_batch.emplace(x_shape_batch.begin(), 1);
1368+
}
1369+
}
1370+
1371+
const std::vector<symbol::DimExpr> shapes = [&] {
1372+
std::vector<symbol::DimExpr> shapes;
1373+
symbol::DimExprBuilder builder;
1374+
for (size_t i = 0; i < x_shape_batch.size(); i++) {
1375+
if (x_shape_batch[i] == tol_shape[i]) {
1376+
shapes.emplace_back(x_shape_batch[i]);
1377+
} else if (x_shape_batch[i] == 1) {
1378+
shapes.emplace_back(tol_shape[i]);
1379+
} else if (tol_shape[i] == 1) {
1380+
shapes.emplace_back(x_shape_batch[i]);
1381+
} else {
1382+
shapes.emplace_back(builder.Broadcast(x_shape_batch[i], tol_shape[i]));
1383+
infer_context->AddBroadcastableCstr(x_shape_batch[i], tol_shape[i]);
1384+
}
1385+
}
1386+
return shapes;
1387+
}();
1388+
1389+
symbol::ShapeOrDataDimExprs shape_data{
1390+
symbol::TensorShapeOrDataDimExprs(shapes)};
1391+
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
1392+
return true;
1393+
}
13421394

13431395
bool MaskedSelectOpInferSymbolicShape(
13441396
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
7171
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogLoss)
7272
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lstsq)
7373
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LuUnpack)
74-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
74+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
7575
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
7676
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
7777
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixNms)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3292,6 +3292,7 @@
32923292
kernel :
32933293
func : matrix_rank_tol
32943294
traits : paddle::dialect::ForwardOnlyTrait
3295+
interfaces : paddle::dialect::InferSymbolicShapeInterface
32953296

32963297
- op : max
32973298
args : (Tensor x, IntArray axis={}, bool keepdim=false)

0 commit comments

Comments
 (0)