Commit b1fe09d

0x45f authored and AnnaTrainingG committed

Add matrix_rank Op and its GPU and CPU kernel (PaddlePaddle#34823)
* init matrix_rank op, add matrix_rank CPU code and test
* add GPU kernel, remove svd_eigen.h
* add CPU kernel when tol is tensor
* add cpu and gpu code when tol is tensor
* fix CI-ROCM error
* add matrix_rank API description, fix PR-CI-Py3 error
* fix PR-CI-Windows error, add matrix_rank API test
* delete useless comments
* fix review
* add my code in svd_helper.h
* update doc comments
* remove spaces
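
In short, the new op computes the numerical rank of every matrix in a batch. As the CPU kernel below makes explicit, the rank is the number of singular values (absolute eigenvalues in the hermitian case) that exceed a tolerance; with the default tolerance this follows the usual convention for an $m \times n$ matrix:

$$rank(X) = \#\{\, i : \sigma_i(X) > \max(atol,\; rtol \cdot \sigma_{max}(X)) \,\}, \qquad rtol_{default} = \varepsilon \cdot \max(m, n)$$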
1 parent e3498f9 commit b1fe09d

File tree

11 files changed: +965 -3 lines changed

cmake/operators.cmake

Lines changed: 1 addition & 0 deletions
```diff
@@ -183,6 +183,7 @@ function(op_library TARGET)
     list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
     list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
     list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
+    list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
     list(REMOVE_ITEM hip_srcs "svd_op.cu")
     list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
     list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
```

paddle/fluid/operators/matrix_rank_op.cc

Lines changed: 256 additions & 0 deletions

```cpp
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/matrix_rank_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/svd_helper.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {
using DDim = framework::DDim;

namespace detail {
static DDim GetInputBatchDim(const DDim& dim_x) {
  auto x_vec = framework::vectorize(dim_x);
  if (x_vec.size() == 2) {
    return framework::make_ddim({1});
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return framework::make_ddim(x_vec);
}
}  // namespace detail

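// Illustrative example (not in the original source): GetInputBatchDim maps
// dim_x = [2, 3, 4, 5] to batch dims [2, 3], and a plain 2-D input of shape
// [4, 5] to [1].
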
class MatrixRankeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "MatrixRank");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixRank");
    auto dim_x = ctx->GetInputDim("X");
    PADDLE_ENFORCE_GE(dim_x.size(), 2,
                      platform::errors::InvalidArgument(
                          "The dims of input must be greater than or equal to 2"));

    bool hermitian = ctx->Attrs().Get<bool>("hermitian");
    if (hermitian) {
      int rows = dim_x[dim_x.size() - 2];
      int cols = dim_x[dim_x.size() - 1];
      PADDLE_ENFORCE_EQ(rows, cols,
                        platform::errors::InvalidArgument(
                            "if hermitian == true, matrix should be n*n"));
    }

    DDim dim_x_batch = detail::GetInputBatchDim(dim_x);
    if (ctx->Attrs().Get<bool>(
            "use_default_tol")) {  // user supplied neither TolTensor nor tol
      ctx->SetOutputDim("Out", dim_x_batch);
    } else if (ctx->HasInput("TolTensor")) {
      auto dim_tol = ctx->GetInputDim("TolTensor");
      if (dim_x_batch == dim_tol) {
        ctx->SetOutputDim("Out", dim_x_batch);
      } else {
        int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
        int axis = std::abs(dim_x_batch.size() - dim_tol.size());
        std::vector<int> x_batch_dims_array(max_dim);
        std::vector<int> tol_dims_array(max_dim);
        std::vector<int> out_dims_array(max_dim);
        GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(),
                               tol_dims_array.data(), out_dims_array.data(),
                               max_dim, axis);
        for (auto& it : out_dims_array) {
          VLOG(3) << "out dims: " << it;
        }
        ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
      }
    } else {
      ctx->SetOutputDim("Out", dim_x_batch);
    }
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }
};

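// Illustrative note (not in the original source): when TolTensor's shape
// differs from the batch dims, InferShape broadcasts them; e.g. batch dims
// [2, 3] against a TolTensor of shape [3] (axis = |2 - 1| = 1) give Out the
// shape [2, 3], one rank value per trailing matrix.
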
class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of matrix_rank op.");
    AddInput("TolTensor", "(optional) Tol tensor, shape is same as X batch.")
        .AsDispensable();
    AddOutput("Out", "(Tensor), The output tensor of matrix_rank op.");
    AddAttr<float>("tol", "(float, optional). tol").SetDefault(0.0f);
    AddAttr<bool>("use_default_tol",
                  "represents whether the user provided TolTensor/tol; if "
                  "neither was provided, use_default_tol=true, otherwise "
                  "use_default_tol=false")
        .SetDefault(true);
    AddAttr<bool>("hermitian", "(bool, optional). whether X is a hermitian matrix")
        .SetDefault(false);
    AddComment(R"DOC(MatrixRank Operator.
This operator is used to perform MatrixRank operation for batched matrices.
$$out = matrix_rank(X, tol, hermitian)$$
)DOC");
  }
};

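// Summary (not in the original source) of tolerance selection in the CPU
// kernel below:
//   use_default_tol == true -> atol = 0, rtol = eps * max(rows, cols)
//   TolTensor provided      -> atol comes from TolTensor, rtol = 0
//   otherwise               -> atol = the scalar attr "tol", rtol = 0
// The final per-matrix threshold is max(atol, rtol * sigma_max).
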
template <typename T>
void BatchEigenvalues(const T* x_data, T* eigenvalues_data, int batches,
                      int rows, int cols, int k) {
  // The Eigen::Map API needs a non-const pointer.
  T* input = const_cast<T*>(x_data);
  int stride = rows * cols;
  for (int i = 0; i < batches; i++) {
    auto m = Eigen::Map<
        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        input + i * stride, rows, rows);
    Eigen::SelfAdjointEigenSolver<
        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
        eigen_solver(m);
    auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs();
    for (int j = 0; j < k; j++) {
      *(eigenvalues_data + i * k + j) = eigenvalues[j];
    }
  }
}

template <typename T>
void BatchSVD(const T* x_data, T* eigenvalues_data, int batches, int rows,
              int cols, int k) {
  // The Eigen::Map API needs a non-const pointer.
  T* input = const_cast<T*>(x_data);
  int stride = rows * cols;
  Eigen::BDCSVD<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      svd;
  for (int i = 0; i < batches; i++) {
    auto m = Eigen::Map<
        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        input + i * stride, rows, cols);
    svd.compute(m);
    auto res_s = svd.singularValues();
    for (int j = 0; j < k; j++) {
      eigenvalues_data[i * k + j] = res_s[j];
    }
  }
}

template <typename T>
class MatrixRankCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* x = context.Input<Tensor>("X");
    auto* x_data = x->data<T>();
    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<int64_t>(context.GetPlace());
    bool hermitian = context.Attr<bool>("hermitian");

    auto dim_x = x->dims();
    auto dim_out = out->dims();
    int rows = dim_x[dim_x.size() - 2];
    int cols = dim_x[dim_x.size() - 1];
    int k = std::min(rows, cols);
    auto numel = x->numel();
    int batches = numel / (rows * cols);

    bool use_default_tol = context.Attr<bool>("use_default_tol");
    const Tensor* atol_tensor = nullptr;
    Tensor temp_tensor;
    T rtol_T = 0;
    if (use_default_tol) {
      framework::TensorFromVector<T>(std::vector<T>{0},
                                     context.device_context(), &temp_tensor);
      atol_tensor = &temp_tensor;
      rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
    } else if (context.HasInput("TolTensor")) {
      atol_tensor = context.Input<Tensor>("TolTensor");
    } else {
      framework::TensorFromVector<T>(std::vector<T>{context.Attr<float>("tol")},
                                     context.device_context(), &temp_tensor);
      atol_tensor = &temp_tensor;
    }

    Tensor eigenvalue_tensor;
    auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>(
        detail::GetEigenvalueDim(dim_x, k), context.GetPlace());
    if (hermitian) {
      BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
    } else {
      BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
    }

    // Largest eigenvalue/singular value per matrix, used for the relative
    // tolerance.
    auto dito_T =
        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(
            context);
    std::vector<int> max_eigenvalue_shape = framework::vectorize<int>(
        detail::RemoveLastDim(eigenvalue_tensor.dims()));
    Tensor max_eigenvalue_tensor =
        dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape);

    Tensor temp_rtol_tensor;
    framework::TensorFromVector<T>(std::vector<T>{rtol_T}, &temp_rtol_tensor);
    Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor);
    // Per-matrix threshold: tol = max(atol, rtol * sigma_max).
    Tensor tol_tensor;
    tol_tensor.mutable_data<T>(dim_out, context.GetPlace());
    ElementwiseComputeEx<GreaterElementFunctor<T>, platform::CPUDeviceContext,
                         T, T>(context, atol_tensor, &rtol_tensor, -1,
                               GreaterElementFunctor<T>(), &tol_tensor);

    tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));

    // 0/1 mask marking eigenvalues above the tolerance.
    Tensor compare_result;
    compare_result.mutable_data<int>(detail::NewAxisDim(dim_out, k),
                                     context.GetPlace());

    // Pick the functor so the broadcast comparison still marks eigenvalues
    // above the tolerance when tol_tensor has the larger rank.
    int axis = -1;
    if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
      ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CPUDeviceContext, T,
                           int>(context, &eigenvalue_tensor, &tol_tensor, axis,
                                GreaterThanFunctor<T>(), &compare_result);
    } else {
      ElementwiseComputeEx<LessThanFunctor<T>, platform::CPUDeviceContext, T,
                           int>(context, &eigenvalue_tensor, &tol_tensor, axis,
                                LessThanFunctor<T>(), &compare_result);
    }
    // Sum the 0/1 comparison results over the last axis to get the rank.
    auto dito_int =
        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext,
                                                 int64_t>(context);
    std::vector<int> result_shape = framework::vectorize<int>(dim_out);
    Tensor result = dito_int.ReduceSum(compare_result, result_shape);
    out->ShareDataWith(result);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker);

REGISTER_OP_CPU_KERNEL(matrix_rank, ops::MatrixRankCPUKernel<float>,
                       ops::MatrixRankCPUKernel<double>);
```
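
For readers who want the algorithm without the framework plumbing, here is a minimal self-contained sketch of the same computation in plain Eigen, mirroring the non-hermitian path of the kernel above: BDCSVD for the singular values, then counting those above max(atol, rtol * sigma_max) with the default rtol = eps * max(rows, cols). The helper name MatrixRank and its defaults are illustrative, not part of the Paddle or Eigen API.

```cpp
#include <Eigen/Dense>
#include <algorithm>
#include <iostream>
#include <limits>

// Illustrative helper: numerical rank of a single double-precision matrix.
int MatrixRank(const Eigen::MatrixXd& m, double atol = 0.0) {
  // Singular values only; BDCSVD matches the kernel's non-hermitian branch.
  Eigen::BDCSVD<Eigen::MatrixXd> svd(m);
  Eigen::VectorXd s = svd.singularValues();  // sorted in decreasing order
  // Default rtol, as in the kernel: eps * max(rows, cols).
  double rtol =
      std::numeric_limits<double>::epsilon() * std::max(m.rows(), m.cols());
  double tol = std::max(atol, rtol * s(0));
  int rank = 0;
  for (int i = 0; i < s.size(); ++i) {
    if (s(i) > tol) ++rank;
  }
  return rank;
}

int main() {
  Eigen::MatrixXd m(3, 3);
  m << 1, 2, 3,
       2, 4, 6,  // second row is twice the first, so the rank drops to 2
       0, 0, 1;
  std::cout << MatrixRank(m) << std::endl;  // prints 2
}
```

In the hermitian case the kernel uses SelfAdjointEigenSolver and compares absolute eigenvalues instead; the thresholding step is identical.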
