Commit a04a6bd
[Phi] Move determinant op kernel into phi (#40539)
* add determinant phi kernel
* remove original determinant op kernel
* add determinant grad phi kernel
* fix failing determinant test
* remove original determinant grad op kernel
1 parent 0c0acbd commit a04a6bd

13 files changed: +473 -246 lines
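At a glance, the change replaces fluid's per-device registration macros with phi's backend-parameterized one. A minimal before/after sketch, assembled from the hunks below (not a complete file):

// Before (fluid): one macro per device, with the device context baked into
// the kernel's template arguments (see determinant_op.cc / .cu below).
REGISTER_OP_CPU_KERNEL(determinant,
                       ops::DeterminantKernel<plat::CPUDeviceContext, float>,
                       ops::DeterminantKernel<plat::CPUDeviceContext, double>);

// After (phi): the backend (CPU here) and layout are registration arguments,
// so the same kernel template can be registered for any backend.
PD_REGISTER_KERNEL(
    determinant, CPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}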

paddle/fluid/operators/determinant_op.cc  (-8)

@@ -168,14 +168,6 @@ REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,
 
 REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp)
 
-REGISTER_OP_CPU_KERNEL(determinant,
-                       ops::DeterminantKernel<plat::CPUDeviceContext, float>,
-                       ops::DeterminantKernel<plat::CPUDeviceContext, double>);
-
-REGISTER_OP_CPU_KERNEL(
-    determinant_grad, ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
-
 REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
                   ops::SlogDeterminantOpMaker,
                   ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>,

paddle/fluid/operators/determinant_op.cu  (-8)

@@ -17,14 +17,6 @@ limitations under the License. */
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>,
-    ops::DeterminantKernel<plat::CUDADeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    determinant_grad,
-    ops::DeterminantGradKernel<plat::CUDADeviceContext, float>,
-    ops::DeterminantGradKernel<plat::CUDADeviceContext, double>);
 
 REGISTER_OP_CUDA_KERNEL(
     slogdeterminant, ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,

paddle/fluid/operators/determinant_op.h  (+8 -229)

@@ -23,10 +23,13 @@
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/kernels/complex_kernel.h"
 #include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/diag_functor.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
 #include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
+#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
 #include "paddle/phi/kernels/math_kernel.h"
 #include "paddle/phi/kernels/matmul_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
@@ -40,232 +43,6 @@ T sign(T val) {
   return static_cast<T>(T(0) < val) - (val < T(0));
 }
 
-template <typename T>
-class EigenMatrix {};
-
-template <>
-class EigenMatrix<float> {
- public:
-  using MatrixType = Eigen::MatrixXf;
-};
-
-template <>
-class EigenMatrix<double> {
- public:
-  using MatrixType = Eigen::MatrixXd;
-};
-
-inline int64_t GetBatchCount(const framework::DDim dims) {
-  int64_t batch_count = 1;
-  auto dim_size = dims.size();
-  PADDLE_ENFORCE_GE(
-      dim_size, 2,
-      platform::errors::InvalidArgument(
-          "the input matrix dimension size should greater than 2."));
-
-  // Cumulative multiplying each dimension until the last 2 to get the batch
-  // count,
-  // for example a tensor with shape [3,3,3,3], the batch count of matrices is
-  // 9.
-  for (int64_t i = 0; i < dims.size() - 2; i++) {
-    batch_count *= dims[i];
-  }
-
-  return batch_count;
-}
-
-template <typename T>
-struct DeterminantFunctor {
-  void operator()(const Tensor& input, const framework::ExecutionContext ctx,
-                  int64_t rank, int64_t batch_count, Tensor* output) {
-    std::vector<T> input_vec;
-    std::vector<T> output_vec;
-    framework::TensorToVector(input, ctx.device_context(), &input_vec);
-    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
-      auto begin_iter = input_vec.begin() + i * rank * rank;
-      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
-      std::vector<T> sub_vec(begin_iter,
-                             end_iter);  // get every square matrix data
-      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
-      for (int64_t i = 0; i < rank; ++i) {
-        for (int64_t j = 0; j < rank; ++j) {
-          matrix(i, j) = sub_vec[rank * i + j];
-        }
-      }
-      output_vec.push_back(matrix.determinant());
-    }
-    framework::TensorFromVector(output_vec, output);
-  }
-};
-template <typename DeviceContext, typename T>
-class DeterminantKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<framework::Tensor>("Input");
-    auto input_dim = vectorize(input->dims());
-    auto input_dim_size = input_dim.size();
-    auto* output = context.Output<framework::Tensor>("Out");
-
-    auto batch_count = GetBatchCount(input->dims());
-    VLOG(2) << "input dim:" << input->dims();
-    PADDLE_ENFORCE_GE(
-        input_dim_size, 2,
-        platform::errors::InvalidArgument(
-            "the input matrix dimension size should greater than 2."));
-    PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
-                      input_dim[input_dim_size - 2],
-                      platform::errors::InvalidArgument(
-                          "the input matrix should be square matrix."));
-    auto rank = input_dim[input_dim_size - 1];  // square matrix length
-    DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
-    auto output_dims = phi::slice_ddim(input->dims(), 0, input_dim_size - 2);
-    if (input_dim_size > 2) {
-      output->Resize(output_dims);
-    } else {
-      // when input is a two-dimension matrix, The det value is a number.
-      output->Resize({1});
-    }
-    VLOG(2) << "output dim:" << output->dims();
-  }
-};
-
-template <typename T>
-struct FoundZeroFunctor {
-  FoundZeroFunctor(const T* x, int64_t numel, bool* res)
-      : x_(x), numel_(numel), res_(res) {}
-  HOSTDEVICE void operator()(size_t idx) const {
-    if (*res_ || idx >= static_cast<size_t>(numel_)) {
-      // founded zero number
-      return;
-    }
-    *res_ = (x_[idx] == static_cast<T>(0));
-  }
-  const T* x_;
-  int64_t numel_;
-  bool* res_;
-};
-
-template <typename DeviceContext, typename T>
-inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
-                                  const framework::Tensor* det) {
-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  auto numel = det->numel();
-
-  framework::Tensor dev_tensor;
-  auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());
-
-  // set false
-  phi::funcs::SetConstant<DeviceContext, bool> zero;
-  zero(dev_ctx, &dev_tensor, false);
-
-  // find whether zero
-  platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-  FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
-  for_range(functor);
-
-  // copy to host
-  dev_ctx.Wait();
-  framework::Tensor cpu_tensor;
-  framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);
-
-  // if founded zero, the matrix is not invertible
-  // else the matrix is invertible
-  auto* res = cpu_tensor.data<bool>();
-  return !(*res);
-}
-
-template <typename DeviceContext, typename T>
-class DeterminantGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
-    const auto* input = context.Input<framework::Tensor>("Input");
-    const auto* det = context.Input<framework::Tensor>("Out");
-    const auto* grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* ddet =
-        context.Output<framework::Tensor>(framework::GradVarName("Input"));
-
-    auto input_dims_size = input->dims().size();
-    if (input_dims_size > 2) {
-      PADDLE_ENFORCE_EQ(
-          grad->dims().size() + 2, input_dims_size,
-          platform::errors::InvalidArgument(
-              "The grad tensor of det dims size should 2 less than"
-              " input tensor's, but here differ %d",
-              input_dims_size - grad->dims().size()));
-    } else if (input_dims_size == 2) {
-      // input dims size 2 and grad dims size 1 is possible
-      PADDLE_ENFORCE_EQ(
-          grad->dims().size(), 1,
-          platform::errors::InvalidArgument(
-              "The grad tensor of det dims size should 2 less than"
-              " input tensor's, but here differ %d",
-              input_dims_size - grad->dims().size()));
-    } else {
-      // checked in forward, pass
-    }
-
-    auto& dev_ctx = static_cast<
-        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
-        orig_dev_ctx);
-
-    // Check Whether the matrix is invertible
-    // (matrix A not invertible) == (det(A)=0)
-    if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
-      // The matrix is not invertible
-      VLOG(3) << "The input matrix not invertible!";
-      ddet->Resize(input->dims());
-      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
-                   ddet);
-      return;
-    }
-
-    // The matrix is invertible
-    // let |A| = Determinant(A)
-    // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
-    // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
-    // -1)
-
-    // First: inverse(A)
-    framework::Tensor inverse_A;
-    // A must be square matrices!
-    inverse_A.Resize(input->dims());
-    inverse_A.mutable_data<T>(context.GetPlace());
-
-    phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(orig_dev_ctx, *input, &inverse_A);
-
-    VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
-
-    // Second: inverse(A).transpose(-2, -1)
-    framework::Tensor transpose_inverse_A =
-        phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
-
-    VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
-            << transpose_inverse_A.dims();
-
-    // Third: dA * |A|
-    auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
-    VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
-
-    // Fourth: unsqueeze(dA * |A|, [-1, -2])
-    auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
-    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
-    VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
-
-    // Finally: unsqueeze(dA * |A|) * inverse(A)
-    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
-
-    VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
-
-    framework::TensorCopy(res, context.GetPlace(), ddet);
-
-    ddet->Resize(input->dims());
-    VLOG(3) << "d|A| dims: " << ddet->dims();
-  }
-};
-
 template <typename T>
 struct SlogDeterminantFunctor {
   void operator()(const Tensor& input, const framework::ExecutionContext ctx,
@@ -280,7 +57,7 @@ struct SlogDeterminantFunctor {
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
       std::vector<T> sub_vec(begin_iter,
                              end_iter);  // get every square matrix data
-      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
+      typename phi::detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
       for (int64_t i = 0; i < rank; ++i) {
        for (int64_t j = 0; j < rank; ++j) {
          matrix(i, j) = sub_vec[rank * i + j];
@@ -311,7 +88,7 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
    auto input_dim_size = input_dim.size();
    auto* output = context.Output<framework::Tensor>("Out");
 
-    auto batch_count = GetBatchCount(input->dims());
+    auto batch_count = phi::detail::GetBatchCount(input->dims());
    VLOG(2) << "input dim:" << input->dims();
    PADDLE_ENFORCE_GE(
        input_dim_size, 2,
@@ -370,7 +147,9 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
    // (matrix A not invertible) == (absslogdet(A)=0)
    auto slogdet_vec = slogdet->Split(1, 0);
    auto absslogdet_val = slogdet_vec[0];
-    if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
+    if (!phi::detail::CheckMatrixInvertible<
+            T, typename framework::ConvertToPhiContext<DeviceContext>::TYPE>(
+            dev_ctx, &absslogdet_val)) {
      // The matrix is not invertible
      VLOG(3) << "The input matrix not invertible!";
      dslogdet->Resize(input->dims());
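Aside: the removed DeterminantGradKernel (now carried by paddle/phi/kernels/impl/determinant_grad_kernel_impl.h) implements the standard matrix-calculus identity referenced in the comments above. Writing |A| for det(A) and L for the loss:

\frac{\partial L}{\partial A} \;=\; \frac{\partial L}{\partial |A|} \cdot |A| \cdot A^{-\top}

The code computes this batched: the per-matrix scalar grad * |A| is unsqueezed over the last two axes, then broadcast-multiplied with inverse(A).transpose(-2, -1). When det(A) = 0 the matrix is not invertible, so the kernel instead fills the gradient with zeros.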

paddle/phi/kernels/CMakeLists.txt  (+6 -1)

@@ -27,7 +27,11 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
 # Some kernels depend on some targets that are not commonly used.
 # These targets are not suitable for common dependencies.
 # In this case, you need to manually generate them here.
-set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel math_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel triangular_solve_grad_kernel)
+set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel math_kernel
+  matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
+  put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
+  softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
+  triangular_solve_grad_kernel determinant_grad_kernel)
 kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
 kernel_library(gumbel_softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
 kernel_library(gumbel_softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
@@ -46,6 +50,7 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
 kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
+kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
 
 # 4. auto parse and build kernel targets by cmake
 register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
New file  (+25)

@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/determinant_grad_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(determinant_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DeterminantGradKernel,
+                   float,
+                   double) {}
New file  (+21)

@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/determinant_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    determinant, CPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
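The two registrations above reference phi::DeterminantKernel and phi::DeterminantGradKernel, whose declarations live in the determinant_kernel.h / determinant_grad_kernel.h headers that this excerpt does not show. A hypothetical sketch of their shape, following phi's free-function kernel convention (the parameter names and exact order are assumptions, not taken from this diff):

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Forward: out = det(x), computed per matrix over the batch dimensions.
template <typename T, typename Context>
void DeterminantKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       DenseTensor* out);

// Backward: x_grad = out_grad * det(x) * inverse(x)^T (assumed parameter
// list; see the gradient identity discussed earlier).
template <typename T, typename Context>
void DeterminantGradKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& out,
                           const DenseTensor& out_grad,
                           DenseTensor* x_grad);

}  // namespace phi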
