23
23
#include " paddle/fluid/platform/for_range.h"
24
24
#include " paddle/phi/kernels/complex_kernel.h"
25
25
#include " paddle/phi/kernels/full_kernel.h"
26
+ #include " paddle/phi/kernels/funcs/common_shape.h"
26
27
#include " paddle/phi/kernels/funcs/diag_functor.h"
27
28
#include " paddle/phi/kernels/funcs/math_function.h"
28
29
#include " paddle/phi/kernels/funcs/matrix_inverse.h"
29
30
#include " paddle/phi/kernels/funcs/unsqueeze.h"
31
+ #include " paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
32
+ #include " paddle/phi/kernels/impl/determinant_kernel_impl.h"
30
33
#include " paddle/phi/kernels/math_kernel.h"
31
34
#include " paddle/phi/kernels/matmul_kernel.h"
32
35
#include " paddle/phi/kernels/transpose_kernel.h"
@@ -40,232 +43,6 @@ T sign(T val) {
40
43
return static_cast <T>(T (0 ) < val) - (val < T (0 ));
41
44
}
42
45
43
- template <typename T>
44
- class EigenMatrix {};
45
-
46
- template <>
47
- class EigenMatrix <float > {
48
- public:
49
- using MatrixType = Eigen::MatrixXf;
50
- };
51
-
52
- template <>
53
- class EigenMatrix <double > {
54
- public:
55
- using MatrixType = Eigen::MatrixXd;
56
- };
57
-
58
- inline int64_t GetBatchCount (const framework::DDim dims) {
59
- int64_t batch_count = 1 ;
60
- auto dim_size = dims.size ();
61
- PADDLE_ENFORCE_GE (
62
- dim_size, 2 ,
63
- platform::errors::InvalidArgument (
64
- " the input matrix dimension size should greater than 2." ));
65
-
66
- // Cumulative multiplying each dimension until the last 2 to get the batch
67
- // count,
68
- // for example a tensor with shape [3,3,3,3], the batch count of matrices is
69
- // 9.
70
- for (int64_t i = 0 ; i < dims.size () - 2 ; i++) {
71
- batch_count *= dims[i];
72
- }
73
-
74
- return batch_count;
75
- }
76
-
77
- template <typename T>
78
- struct DeterminantFunctor {
79
- void operator ()(const Tensor& input, const framework::ExecutionContext ctx,
80
- int64_t rank, int64_t batch_count, Tensor* output) {
81
- std::vector<T> input_vec;
82
- std::vector<T> output_vec;
83
- framework::TensorToVector (input, ctx.device_context (), &input_vec);
84
- for (int64_t i = 0 ; i < batch_count; ++i) { // maybe can be parallel
85
- auto begin_iter = input_vec.begin () + i * rank * rank;
86
- auto end_iter = input_vec.begin () + (i + 1 ) * rank * rank;
87
- std::vector<T> sub_vec (begin_iter,
88
- end_iter); // get every square matrix data
89
- typename EigenMatrix<T>::MatrixType matrix (rank, rank);
90
- for (int64_t i = 0 ; i < rank; ++i) {
91
- for (int64_t j = 0 ; j < rank; ++j) {
92
- matrix (i, j) = sub_vec[rank * i + j];
93
- }
94
- }
95
- output_vec.push_back (matrix.determinant ());
96
- }
97
- framework::TensorFromVector (output_vec, output);
98
- }
99
- };
100
- template <typename DeviceContext, typename T>
101
- class DeterminantKernel : public framework ::OpKernel<T> {
102
- public:
103
- void Compute (const framework::ExecutionContext& context) const override {
104
- auto * input = context.Input <framework::Tensor>(" Input" );
105
- auto input_dim = vectorize (input->dims ());
106
- auto input_dim_size = input_dim.size ();
107
- auto * output = context.Output <framework::Tensor>(" Out" );
108
-
109
- auto batch_count = GetBatchCount (input->dims ());
110
- VLOG (2 ) << " input dim:" << input->dims ();
111
- PADDLE_ENFORCE_GE (
112
- input_dim_size, 2 ,
113
- platform::errors::InvalidArgument (
114
- " the input matrix dimension size should greater than 2." ));
115
- PADDLE_ENFORCE_EQ (input_dim[input_dim_size - 1 ],
116
- input_dim[input_dim_size - 2 ],
117
- platform::errors::InvalidArgument (
118
- " the input matrix should be square matrix." ));
119
- auto rank = input_dim[input_dim_size - 1 ]; // square matrix length
120
- DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
121
- auto output_dims = phi::slice_ddim (input->dims (), 0 , input_dim_size - 2 );
122
- if (input_dim_size > 2 ) {
123
- output->Resize (output_dims);
124
- } else {
125
- // when input is a two-dimension matrix, The det value is a number.
126
- output->Resize ({1 });
127
- }
128
- VLOG (2 ) << " output dim:" << output->dims ();
129
- }
130
- };
131
-
132
- template <typename T>
133
- struct FoundZeroFunctor {
134
- FoundZeroFunctor (const T* x, int64_t numel, bool * res)
135
- : x_(x), numel_(numel), res_(res) {}
136
- HOSTDEVICE void operator ()(size_t idx) const {
137
- if (*res_ || idx >= static_cast <size_t >(numel_)) {
138
- // founded zero number
139
- return ;
140
- }
141
- *res_ = (x_[idx] == static_cast <T>(0 ));
142
- }
143
- const T* x_;
144
- int64_t numel_;
145
- bool * res_;
146
- };
147
-
148
- template <typename DeviceContext, typename T>
149
- inline bool CheckMatrixInvertible (const framework::ExecutionContext& ctx,
150
- const framework::Tensor* det) {
151
- auto & dev_ctx = ctx.template device_context <DeviceContext>();
152
- auto numel = det->numel ();
153
-
154
- framework::Tensor dev_tensor;
155
- auto * data = dev_tensor.mutable_data <bool >({1 }, ctx.GetPlace ());
156
-
157
- // set false
158
- phi::funcs::SetConstant<DeviceContext, bool > zero;
159
- zero (dev_ctx, &dev_tensor, false );
160
-
161
- // find whether zero
162
- platform::ForRange<DeviceContext> for_range (dev_ctx, numel);
163
- FoundZeroFunctor<T> functor (det->data <T>(), numel, data);
164
- for_range (functor);
165
-
166
- // copy to host
167
- dev_ctx.Wait ();
168
- framework::Tensor cpu_tensor;
169
- framework::TensorCopy (dev_tensor, platform::CPUPlace (), &cpu_tensor);
170
-
171
- // if founded zero, the matrix is not invertible
172
- // else the matrix is invertible
173
- auto * res = cpu_tensor.data <bool >();
174
- return !(*res);
175
- }
176
-
177
- template <typename DeviceContext, typename T>
178
- class DeterminantGradKernel : public framework ::OpKernel<T> {
179
- public:
180
- void Compute (const framework::ExecutionContext& context) const override {
181
- auto & orig_dev_ctx = context.template device_context <DeviceContext>();
182
- const auto * input = context.Input <framework::Tensor>(" Input" );
183
- const auto * det = context.Input <framework::Tensor>(" Out" );
184
- const auto * grad =
185
- context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
186
- auto * ddet =
187
- context.Output <framework::Tensor>(framework::GradVarName (" Input" ));
188
-
189
- auto input_dims_size = input->dims ().size ();
190
- if (input_dims_size > 2 ) {
191
- PADDLE_ENFORCE_EQ (
192
- grad->dims ().size () + 2 , input_dims_size,
193
- platform::errors::InvalidArgument (
194
- " The grad tensor of det dims size should 2 less than"
195
- " input tensor's, but here differ %d" ,
196
- input_dims_size - grad->dims ().size ()));
197
- } else if (input_dims_size == 2 ) {
198
- // input dims size 2 and grad dims size 1 is possible
199
- PADDLE_ENFORCE_EQ (
200
- grad->dims ().size (), 1 ,
201
- platform::errors::InvalidArgument (
202
- " The grad tensor of det dims size should 2 less than"
203
- " input tensor's, but here differ %d" ,
204
- input_dims_size - grad->dims ().size ()));
205
- } else {
206
- // checked in forward, pass
207
- }
208
-
209
- auto & dev_ctx = static_cast <
210
- const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
211
- orig_dev_ctx);
212
-
213
- // Check Whether the matrix is invertible
214
- // (matrix A not invertible) == (det(A)=0)
215
- if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
216
- // The matrix is not invertible
217
- VLOG (3 ) << " The input matrix not invertible!" ;
218
- ddet->Resize (input->dims ());
219
- phi::Full<T>(dev_ctx, phi::vectorize (input->dims ()), static_cast <T>(0 .0f ),
220
- ddet);
221
- return ;
222
- }
223
-
224
- // The matrix is invertible
225
- // let |A| = Determinant(A)
226
- // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
227
- // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
228
- // -1)
229
-
230
- // First: inverse(A)
231
- framework::Tensor inverse_A;
232
- // A must be square matrices!
233
- inverse_A.Resize (input->dims ());
234
- inverse_A.mutable_data <T>(context.GetPlace ());
235
-
236
- phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
237
- mat_inv (orig_dev_ctx, *input, &inverse_A);
238
-
239
- VLOG (3 ) << " inverse(A) dims: " << inverse_A.dims ();
240
-
241
- // Second: inverse(A).transpose(-2, -1)
242
- framework::Tensor transpose_inverse_A =
243
- phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
244
-
245
- VLOG (3 ) << " (dA * |A|).transpose(-2, -1) dims: "
246
- << transpose_inverse_A.dims ();
247
-
248
- // Third: dA * |A|
249
- auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
250
- VLOG (3 ) << " dA * |A| dims: " << mul_dA_detA.dims ();
251
-
252
- // Fourth: unsqueeze(dA * |A|, [-1, -2])
253
- auto unsqueeze1 = phi::funcs::Unsqueeze (mul_dA_detA, -1 );
254
- auto unsqueeze2 = phi::funcs::Unsqueeze (unsqueeze1, -2 );
255
- VLOG (3 ) << " unsqueezed(dA * |A|) dims: " << unsqueeze2.dims ();
256
-
257
- // Finally: unsqueeze(dA * |A|) * inverse(A)
258
- auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
259
-
260
- VLOG (3 ) << " unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims ();
261
-
262
- framework::TensorCopy (res, context.GetPlace (), ddet);
263
-
264
- ddet->Resize (input->dims ());
265
- VLOG (3 ) << " d|A| dims: " << ddet->dims ();
266
- }
267
- };
268
-
269
46
template <typename T>
270
47
struct SlogDeterminantFunctor {
271
48
void operator ()(const Tensor& input, const framework::ExecutionContext ctx,
@@ -280,7 +57,7 @@ struct SlogDeterminantFunctor {
280
57
auto end_iter = input_vec.begin () + (i + 1 ) * rank * rank;
281
58
std::vector<T> sub_vec (begin_iter,
282
59
end_iter); // get every square matrix data
283
- typename EigenMatrix<T>::MatrixType matrix (rank, rank);
60
+ typename phi::detail:: EigenMatrix<T>::MatrixType matrix (rank, rank);
284
61
for (int64_t i = 0 ; i < rank; ++i) {
285
62
for (int64_t j = 0 ; j < rank; ++j) {
286
63
matrix (i, j) = sub_vec[rank * i + j];
@@ -311,7 +88,7 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
311
88
auto input_dim_size = input_dim.size ();
312
89
auto * output = context.Output <framework::Tensor>(" Out" );
313
90
314
- auto batch_count = GetBatchCount (input->dims ());
91
+ auto batch_count = phi::detail:: GetBatchCount (input->dims ());
315
92
VLOG (2 ) << " input dim:" << input->dims ();
316
93
PADDLE_ENFORCE_GE (
317
94
input_dim_size, 2 ,
@@ -370,7 +147,9 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
370
147
// (matrix A not invertible) == (absslogdet(A)=0)
371
148
auto slogdet_vec = slogdet->Split (1 , 0 );
372
149
auto absslogdet_val = slogdet_vec[0 ];
373
- if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
150
+ if (!phi::detail::CheckMatrixInvertible<
151
+ T, typename framework::ConvertToPhiContext<DeviceContext>::TYPE>(
152
+ dev_ctx, &absslogdet_val)) {
374
153
// The matrix is not invertible
375
154
VLOG (3 ) << " The input matrix not invertible!" ;
376
155
dslogdet->Resize (input->dims ());
0 commit comments