
Commit c77a263

From00 and Zjq9409 authored
Add yaml for matrix rank op (#41466)
* modify matrix_rank
* add matrix_rank shape
* Add yaml for matrix_rank OP
* Add UT

Co-authored-by: zhoujianqian <15205085056@163.com>
1 parent 5516f18 commit c77a263

7 files changed (+161 -2 lines changed)

paddle/phi/infermeta/binary.cc

Lines changed: 51 additions & 0 deletions
@@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
   }
 }
 
+// Used in MatrixRankTolInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+
 }  // namespace detail
 
 void AllValueCompareInferMeta(const MetaTensor& x,
@@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be greater than 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  auto dim_tol = atol_tensor.dims();
+  if (dim_x_batch == dim_tol) {
+    out->set_dims(dim_x_batch);
+  } else {
+    int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
+    int axis = std::abs(dim_x_batch.size() - dim_tol.size());
+    std::vector<int> x_batch_dims_array(max_dim);
+    std::vector<int> tol_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
+                                       dim_tol,
+                                       x_batch_dims_array.data(),
+                                       tol_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    out->set_dims(phi::make_ddim(out_dims_array));
+  }
+  out->share_lod(x);
+}
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_vec = vec.dims();
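For intuition, the shape rule above can be sketched in a few lines of NumPy-flavored Python (a minimal illustration, not Paddle code; matrix_rank_tol_out_shape is a hypothetical helper name): the trailing two matrix dimensions of x are dropped (a plain 2-D matrix maps to shape (1,)), and the remaining batch shape is broadcast against the tolerance tensor's shape.

    import numpy as np

    def matrix_rank_tol_out_shape(x_shape, tol_shape):
        # Drop the trailing (rows, cols) dims; a single matrix yields shape (1,).
        batch = (1,) if len(x_shape) == 2 else tuple(x_shape[:-2])
        # Broadcast the batch shape against the tolerance shape, mirroring
        # what GetBroadcastDimsArrays does in the C++ above.
        return np.broadcast_shapes(batch, tuple(tol_shape))

    print(matrix_rank_tol_out_shape((3, 4, 5, 6), (3, 1)))  # (3, 4)
    print(matrix_rank_tol_out_shape((5, 6), (1,)))          # (1,)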

paddle/phi/infermeta/binary.h

Lines changed: 6 additions & 0 deletions
@@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
                                 int y_num_col_dims,
                                 MetaTensor* out);
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out);
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
 
 void PReluInferMeta(const MetaTensor& x,

paddle/phi/infermeta/unary.cc

Lines changed: 35 additions & 0 deletions
@@ -31,6 +31,18 @@ limitations under the License. */
 
 namespace phi {
 
+namespace detail {
+// Used in MatrixRankInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+}  // namespace detail
+
 void ArgMinMaxInferMeta(const MetaTensor& x,
                         int64_t axis,
                         bool keepdims,
@@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
   out->set_dtype(x.dtype());
 }
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be greater than 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  out->set_dims(dim_x_batch);
+  out->share_lod(x);
+}
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
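The non-tolerance variant applies the same batch-dimension rule without broadcasting: the output holds one rank per matrix in the batch. NumPy's matrix_rank behaves the same way on stacked matrices, which is what the unit tests below compare against; a quick sanity check:

    import numpy as np

    x = np.random.rand(3, 4, 5, 6)           # a (3, 4) batch of 5x6 matrices
    print(np.linalg.matrix_rank(x).shape)    # (3, 4): one rank per matrix
    print(np.linalg.matrix_rank(np.eye(3)))  # 3: a single matrix gives one rank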

paddle/phi/infermeta/unary.h

Lines changed: 5 additions & 0 deletions
@@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,
 
 void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out);
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,

python/paddle/fluid/tests/unittests/test_matrix_rank_op.py

Lines changed: 28 additions & 1 deletion
@@ -30,8 +30,13 @@
 np.random.seed(SEED)
 
 
+def matrix_rank_wraper(x, tol=None, use_default_tol=True, hermitian=False):
+    return paddle.linalg.matrix_rank(x, tol, hermitian)
+
+
 class TestMatrixRankOP(OpTest):
     def setUp(self):
+        self.python_api = matrix_rank_wraper
         self.op_type = "matrix_rank"
         self.init_data()
         self.inputs = {'X': self.x}
@@ -44,7 +49,7 @@ def setUp(self):
         self.outputs = {'Out': self.out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def init_data(self):
         self.x = np.eye(3, dtype=np.float32)
@@ -110,6 +115,28 @@ def init_data(self):
                                          self.hermitian)
 
 
+class TestMatrixRankOP6(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
+        self.tol_tensor = None
+        self.tol = None
+        self.use_default_tol = False
+        self.hermitian = False
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
+class TestMatrixRankOP7(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.eye(200, dtype=np.float64)
+        self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype)
+        self.tol = None
+        self.use_default_tol = True
+        self.hermitian = True
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
 class TestMatrixRankAPI(unittest.TestCase):
     def test_dygraph(self):
         paddle.disable_static()

python/paddle/tensor/linalg.py

Lines changed: 19 additions & 1 deletion
@@ -1284,8 +1284,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
             # [1, 1, 1, 1]]
 
     """
+    if in_dygraph_mode():
+        if isinstance(tol, Variable):
+            if tol.dtype != x.dtype:
+                tol_tensor = cast(tol, x.dtype)
+            else:
+                tol_tensor = tol
+            use_default_tol = False
+            return _C_ops.final_state_matrix_rank_tol(
+                x, tol_tensor, use_default_tol, hermitian)
 
-    if paddle.in_dynamic_mode():
+        if tol is None:
+            tol_attr = 0.0
+            use_default_tol = True
+        else:
+            tol_attr = float(tol)
+            use_default_tol = False
+        return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol,
+                                              hermitian)
+
+    if _in_legacy_dygraph():
         if tol is None:
             tol_tensor = None
             tol_attr = 0.0
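In user-facing terms, the new eager branch dispatches on the type of tol: a Tensor routes to matrix_rank_tol (with per-matrix tolerances broadcast over the batch), while None or a float routes to matrix_rank. A short usage sketch (assuming eager mode in a Paddle build with the final-state ops enabled):

    import paddle

    x = paddle.rand([3, 4, 5, 6])
    print(paddle.linalg.matrix_rank(x).shape)           # [3, 4], default tolerance
    print(paddle.linalg.matrix_rank(x, tol=0.1).shape)  # [3, 4], scalar tolerance

    tol = paddle.full([3, 1], 0.05)                     # per-batch-row tolerance Tensor
    print(paddle.linalg.matrix_rank(x, tol=tol).shape)  # [3, 4] after broadcasting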

python/paddle/utils/code_gen/api.yaml

Lines changed: 17 additions & 0 deletions
@@ -1157,6 +1157,23 @@
     func : matrix_power
   backward : matrix_power_grad
 
+- api : matrix_rank
+  args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankInferMeta
+    param : [x, use_default_tol, hermitian]
+  kernel :
+    func : matrix_rank
+
+- api : matrix_rank_tol
+  args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankTolInferMeta
+  kernel :
+    func : matrix_rank_tol
+
 - api : max
   args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
   output : Tensor(out)
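These yaml entries tie the pieces together: the code generator turns each api entry into a C++ API plus a Python hook named final_state_<api>, which is what the linalg.py branch above invokes. Roughly (an illustration of the generated call surface, with arguments following the yaml args order; not a verbatim excerpt):

    from paddle import _C_ops

    # matrix_rank: (x, tol, use_default_tol, hermitian)
    out = _C_ops.final_state_matrix_rank(x, 0.0, True, False)
    # matrix_rank_tol: (x, atol_tensor, use_default_tol, hermitian)
    out = _C_ops.final_state_matrix_rank_tol(x, atol_tensor, False, False)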
