@@ -1333,12 +1333,64 @@ bool LuUnpackOpInferSymbolicShape(
1333
1333
return true ;
1334
1334
}
1335
1335
1336
- // bool MatrixRankTolOpInferSymbolicShape(pir::Operation *op,
1337
- // pir::InferSymbolicShapeContext
1338
- // *infer_context) {
1339
- // // pass
1340
- // return true;
1341
- // }
1336
+ bool MatrixRankTolOpInferSymbolicShape (
1337
+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1338
+ const auto &x_shape =
1339
+ infer_context->GetShapeOrDataForValue (op->operand_source (0 )).shape ();
1340
+ const auto &tol_shape_or_data =
1341
+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
1342
+ std::vector<symbol::DimExpr> tol_shape = tol_shape_or_data.shape ();
1343
+ size_t x_rank = x_shape.size ();
1344
+ PADDLE_ENFORCE_GE (x_rank,
1345
+ 2 ,
1346
+ common::errors::InvalidArgument (
1347
+ " The dims of input must be greater than 2" ));
1348
+ bool hermitian = GetBoolAttr (op, " hermitian" );
1349
+ if (hermitian) {
1350
+ infer_context->AddEqualCstr (x_shape[x_rank - 2 ], x_shape[x_rank - 1 ]);
1351
+ }
1352
+ std::vector<symbol::DimExpr> x_shape_batch = [&] {
1353
+ std::vector<symbol::DimExpr> x_shape_batch;
1354
+ for (size_t i = 0 ; i < x_rank - 2 ; ++i) {
1355
+ x_shape_batch.push_back (x_shape[i]);
1356
+ }
1357
+ return x_shape_batch;
1358
+ }();
1359
+
1360
+ int diff = x_shape_batch.size () - tol_shape.size ();
1361
+ if (diff > 0 ) {
1362
+ for (int i = 0 ; i < diff; i++) {
1363
+ tol_shape.emplace (tol_shape.begin (), 1 );
1364
+ }
1365
+ } else {
1366
+ for (int i = 0 ; i < -diff; i++) {
1367
+ x_shape_batch.emplace (x_shape_batch.begin (), 1 );
1368
+ }
1369
+ }
1370
+
1371
+ const std::vector<symbol::DimExpr> shapes = [&] {
1372
+ std::vector<symbol::DimExpr> shapes;
1373
+ symbol::DimExprBuilder builder;
1374
+ for (size_t i = 0 ; i < x_shape_batch.size (); i++) {
1375
+ if (x_shape_batch[i] == tol_shape[i]) {
1376
+ shapes.emplace_back (x_shape_batch[i]);
1377
+ } else if (x_shape_batch[i] == 1 ) {
1378
+ shapes.emplace_back (tol_shape[i]);
1379
+ } else if (tol_shape[i] == 1 ) {
1380
+ shapes.emplace_back (x_shape_batch[i]);
1381
+ } else {
1382
+ shapes.emplace_back (builder.Broadcast (x_shape_batch[i], tol_shape[i]));
1383
+ infer_context->AddBroadcastableCstr (x_shape_batch[i], tol_shape[i]);
1384
+ }
1385
+ }
1386
+ return shapes;
1387
+ }();
1388
+
1389
+ symbol::ShapeOrDataDimExprs shape_data{
1390
+ symbol::TensorShapeOrDataDimExprs (shapes)};
1391
+ infer_context->SetShapeOrDataForValue (op->result (0 ), shape_data);
1392
+ return true ;
1393
+ }
1342
1394
1343
1395
bool MaskedSelectOpInferSymbolicShape (
1344
1396
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
0 commit comments