@@ -175,128 +175,118 @@ void MatrixRankTolKernel(const Context& dev_ctx,
175
175
// Rendered diff of MatrixRankAtolRtolKernel. Lines starting with '-' are the
// removed implementation, lines starting with '+' are the added one, unprefixed
// code lines are unchanged context, and bare numbers are old/new line gutters.
//
// Added implementation, in order:
//   1. fills eigenvalue_tensor (MatrixEighFunctor + AbsKernel when `hermitian`,
//      BatchSVD on the transposed input otherwise);
//   2. reduces it over the last axis with MaxKernel;
//   3. builds tol_tensor by applying GreaterElementFunctor to `atol` and an
//      rtol-derived term scaled by that per-matrix maximum (presumably an
//      element-wise max -- confirm against the functor's definition);
//   4. compares eigenvalue_tensor with tol_tensor into an int64 tensor and sums
//      it over the last axis into `out`.
//
// Parameters:
//   dev_ctx   - context used for every allocation and kernel call below.
//   x         - input tensor; its last two dims are read as rows and cols.
//   atol      - absolute tolerance; a plain DenseTensor after this change
//               (it was paddle::optional before).
//   rtol      - optional relative tolerance; when absent a default is derived.
//   hermitian - selects the MatrixEighFunctor path instead of BatchSVD.
//   out       - result tensor, allocated here as int64_t.
// The change drops the `tol` and `use_default_tol` parameters and the
// delegation to MatrixRankTolKernel.
template <typename T, typename Context>
176
176
void MatrixRankAtolRtolKernel (const Context& dev_ctx,
177
177
const DenseTensor& x,
178
- const paddle::optional<DenseTensor>& tol,
179
- const paddle::optional<DenseTensor>& atol,
178
+ const DenseTensor& atol,
180
179
const paddle::optional<DenseTensor>& rtol,
181
- bool use_default_tol,
182
180
bool hermitian,
183
181
DenseTensor* out) {
184
- if (tol) {
185
- MatrixRankTolKernel<T, Context>(
186
- dev_ctx, x, *tol, use_default_tol, hermitian, out);
182
// `out` is allocated as int64_t before anything else is computed.
+ dev_ctx.template Alloc <int64_t >(out);
183
+ auto dim_x = x.dims ();
184
+ auto dim_out = out->dims ();
185
+ int rows = static_cast <int >(dim_x[dim_x.size () - 2 ]);
186
+ int cols = static_cast <int >(dim_x[dim_x.size () - 1 ]);
187
+ int k = std::min (rows, cols);
188
// NOTE(review): the divisor is rows * cols computed in int; a zero-sized
// trailing dimension would make it 0 -- confirm empty inputs are rejected
// before this kernel runs.
+ int batches = static_cast <int >(x.numel () / (rows * cols));
189
+
190
+ DenseTensor eigenvalue_tensor;
191
+ eigenvalue_tensor.Resize (detail::GetEigenvalueDim (dim_x, k));
192
+ auto * eigenvalue_data = dev_ctx.template Alloc <T>(&eigenvalue_tensor);
193
+
194
// Hermitian input: MatrixEighFunctor writes into eigenvalue_tensor, then
// AbsKernel overwrites it in place. Otherwise BatchSVD writes into
// eigenvalue_data from the last-two-dims transpose of x.
+ if (hermitian) {
195
+ phi::funcs::MatrixEighFunctor<Context, T> functor;
196
+ functor (dev_ctx, x, &eigenvalue_tensor, nullptr , true , false );
197
+ phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
198
+ } else {
199
+ DenseTensor trans_x = phi::TransposeLast2Dim<T>(dev_ctx, x);
200
+ auto * x_data = trans_x.data <T>();
201
+ BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols);
202
+ }
203
+
204
// Maximum over the last axis (-1); the output shape is eigenvalue_tensor's
// dims with the last one removed.
+ DenseTensor max_eigenvalue_tensor;
205
+ max_eigenvalue_tensor.Resize (detail::RemoveLastDim (eigenvalue_tensor.dims ()));
206
+ dev_ctx.template Alloc <T>(&max_eigenvalue_tensor);
207
+ phi::MaxKernel<T, Context>(dev_ctx,
208
+ eigenvalue_tensor,
209
+ phi::IntArray ({-1 }),
210
+ false ,
211
+ &max_eigenvalue_tensor);
212
+
213
+ DenseTensor tol_tensor;
214
+ tol_tensor.Resize (dim_out);
215
+ dev_ctx.template Alloc <T>(&tol_tensor);
216
+
217
// rtol supplied: tol_tensor = GreaterElementFunctor(atol, rtol * max value).
+ if (rtol) {
218
+ DenseTensor tmp_rtol_tensor;
219
+ tmp_rtol_tensor = phi::Multiply<T>(dev_ctx, *rtol, max_eigenvalue_tensor);
220
+ funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
221
+ dev_ctx,
222
+ atol,
223
+ tmp_rtol_tensor,
224
+ GreaterElementFunctor<T>(),
225
+ &tol_tensor);
226
+ } else {
227
+ // when `rtol` is specified to be None in py api
228
+ // use rtol=eps*max(m, n) only if `atol` is passed with value 0.0, else use
229
+ // rtol=0.0
230
+ DenseTensor zero_tensor;
231
+ zero_tensor = phi::FullLike<T, Context>(dev_ctx, atol, static_cast <T>(0.0 ));
232
+
233
// Default term: epsilon<T> * max(rows, cols), multiplied by the per-matrix
// maximum computed above.
+ T rtol_T = std::numeric_limits<T>::epsilon () * std::max (rows, cols);
234
+ DenseTensor default_rtol_tensor;
235
+ default_rtol_tensor =
236
+ phi::Full<T, Context>(dev_ctx, {1 }, static_cast <T>(rtol_T));
237
+ default_rtol_tensor =
238
+ phi::Multiply<T>(dev_ctx, default_rtol_tensor, max_eigenvalue_tensor);
239
+
240
// atol_compare_result receives EqualKernel(atol, zero_tensor) and is then
// used as the WhereKernel condition between default_rtol_tensor and
// zero_tensor.
+ DenseTensor atol_compare_result;
241
+ atol_compare_result.Resize (atol.dims ());
242
+ phi::EqualKernel<T, Context>(
243
+ dev_ctx, atol, zero_tensor, &atol_compare_result);
244
+
245
+ DenseTensor selected_rtol_tensor;
246
+ selected_rtol_tensor.Resize (atol.dims ());
247
+ phi::WhereKernel<T, Context>(dev_ctx,
248
+ atol_compare_result,
249
+ default_rtol_tensor,
250
+ zero_tensor,
251
+ &selected_rtol_tensor);
252
+ funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
253
+ dev_ctx,
254
+ atol,
255
+ selected_rtol_tensor,
256
+ GreaterElementFunctor<T>(),
257
+ &tol_tensor);
258
+ }
259
+
260
// Presumably appends a size-1 axis so tol_tensor broadcasts against the k
// values per matrix -- confirm against detail::NewAxisDim.
+ tol_tensor.Resize (detail::NewAxisDim (tol_tensor.dims (), 1 ));
261
+
262
+ DenseTensor compare_result;
263
+ compare_result.Resize (detail::NewAxisDim (dim_out, k));
264
+ dev_ctx.template Alloc <int64_t >(&compare_result);
265
+ int axis = -1 ;
266
// Both branches pass the same operands in the same order; only the functor
// changes with the relative number of dims.
// NOTE(review): presumably LessThanFunctor compensates for operand swapping
// inside ElementwiseCompute when tol_tensor has more dims -- confirm.
+ if (eigenvalue_tensor.dims ().size () >= tol_tensor.dims ().size ()) {
267
+ funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t >, T, int >(
268
+ dev_ctx,
269
+ eigenvalue_tensor,
270
+ tol_tensor,
271
+ funcs::GreaterThanFunctor<T, int64_t >(),
272
+ &compare_result,
273
+ axis);
187
274
} else {
188
// Removed code from here on: the old branch taken when `tol` was not
// supplied, which dereferenced the optional atol / rtol.
- DenseTensor atol_tensor = *atol;
189
- DenseTensor rtol_tensor = *rtol;
190
- dev_ctx.template Alloc <int64_t >(out);
191
- auto dim_x = x.dims ();
192
- auto dim_out = out->dims ();
193
- int rows = static_cast <int >(dim_x[dim_x.size () - 2 ]);
194
- int cols = static_cast <int >(dim_x[dim_x.size () - 1 ]);
195
- int k = std::min (rows, cols);
196
- int batches = static_cast <int >(x.numel () / (rows * cols));
197
-
198
- DenseTensor eigenvalue_tensor;
199
- eigenvalue_tensor.Resize (detail::GetEigenvalueDim (dim_x, k));
200
- auto * eigenvalue_data = dev_ctx.template Alloc <T>(&eigenvalue_tensor);
201
-
202
- if (hermitian) {
203
- phi::funcs::MatrixEighFunctor<Context, T> functor;
204
- functor (dev_ctx, x, &eigenvalue_tensor, nullptr , true , false );
205
- phi::AbsKernel<T, Context>(
206
- dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
207
- } else {
208
- DenseTensor trans_x = phi::TransposeLast2Dim<T>(dev_ctx, x);
209
- auto * x_data = trans_x.data <T>();
210
- BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols);
211
- }
212
-
213
- DenseTensor max_eigenvalue_tensor;
214
- max_eigenvalue_tensor.Resize (
215
- detail::RemoveLastDim (eigenvalue_tensor.dims ()));
216
- dev_ctx.template Alloc <T>(&max_eigenvalue_tensor);
217
- phi::MaxKernel<T, Context>(dev_ctx,
218
- eigenvalue_tensor,
219
- phi::IntArray ({-1 }),
220
- false ,
221
- &max_eigenvalue_tensor);
222
-
223
- DenseTensor tol_tensor;
224
- tol_tensor.Resize (dim_out);
225
- dev_ctx.template Alloc <T>(&tol_tensor);
226
-
227
- if (use_default_tol) {
228
- // when `rtol` is specified to be None in py api
229
- // use tol=eps*max(m, n)*sigma_1 only if `atol` is specified as 0
230
- T rtol_T = std::numeric_limits<T>::epsilon () * std::max (rows, cols);
231
- DenseTensor default_rtol_tensor;
232
- default_rtol_tensor =
233
- phi::Full<T, Context>(dev_ctx, {1 }, static_cast <T>(rtol_T));
234
- default_rtol_tensor =
235
- phi::Multiply<T>(dev_ctx, default_rtol_tensor, max_eigenvalue_tensor);
236
-
237
- DenseTensor tmp_zeros;
238
- tmp_zeros = phi::Full<T, Context>(dev_ctx, {1 }, static_cast <T>(0.0 ));
239
- DenseTensor atol_compare_result;
240
- atol_compare_result.Resize (atol_tensor.dims ());
241
- phi::EqualKernel<T, Context>(
242
- dev_ctx, atol_tensor, tmp_zeros, &atol_compare_result);
243
-
244
- DenseTensor selected_rtol_tensor;
245
- selected_rtol_tensor.Resize (rtol_tensor.dims ());
246
- phi::WhereKernel<T, Context>(dev_ctx,
247
- atol_compare_result,
248
- default_rtol_tensor,
249
- rtol_tensor,
250
- &selected_rtol_tensor);
251
- funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
252
- dev_ctx,
253
- atol_tensor,
254
- selected_rtol_tensor,
255
- GreaterElementFunctor<T>(),
256
- &tol_tensor);
257
- } else {
258
- DenseTensor tmp_rtol_tensor;
259
- tmp_rtol_tensor =
260
- phi::Multiply<T>(dev_ctx, rtol_tensor, max_eigenvalue_tensor);
261
- funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
262
- dev_ctx,
263
- atol_tensor,
264
- tmp_rtol_tensor,
265
- GreaterElementFunctor<T>(),
266
- &tol_tensor);
267
- }
268
-
269
- tol_tensor.Resize (detail::NewAxisDim (tol_tensor.dims (), 1 ));
270
-
271
- DenseTensor compare_result;
272
- compare_result.Resize (detail::NewAxisDim (dim_out, k));
273
- dev_ctx.template Alloc <int64_t >(&compare_result);
274
- int axis = -1 ;
275
- if (eigenvalue_tensor.dims ().size () >= tol_tensor.dims ().size ()) {
276
- funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t >, T, int >(
277
- dev_ctx,
278
- eigenvalue_tensor,
279
- tol_tensor,
280
- funcs::GreaterThanFunctor<T, int64_t >(),
281
- &compare_result,
282
- axis);
283
- } else {
284
- funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t >, T, int >(
285
- dev_ctx,
286
- eigenvalue_tensor,
287
- tol_tensor,
288
- funcs::LessThanFunctor<T, int64_t >(),
289
- &compare_result,
290
- axis);
291
- }
292
-
293
- phi::SumKernel<int64_t >(dev_ctx,
294
- compare_result,
295
- std::vector<int64_t >{-1 },
296
- compare_result.dtype (),
297
- false ,
298
- out);
275
// Added code resumes: the else-branch of the dims-size check above.
+ funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t >, T, int >(
276
+ dev_ctx,
277
+ eigenvalue_tensor,
278
+ tol_tensor,
279
+ funcs::LessThanFunctor<T, int64_t >(),
280
+ &compare_result,
281
+ axis);
299
282
}
283
+
284
// Sum the int64 comparison results over the last axis (-1) into `out`.
+ phi::SumKernel<int64_t >(dev_ctx,
285
+ compare_result,
286
+ std::vector<int64_t >{-1 },
287
+ compare_result.dtype (),
288
+ false ,
289
+ out);
300
290
}
301
291
} // namespace phi
302
292
0 commit comments