Skip to content

Commit 0b2a730

Browse files
committed
update kernel
1 parent 3248f18 commit 0b2a730

File tree

11 files changed

+365
-505
lines changed

11 files changed

+365
-505
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3845,20 +3845,6 @@ void LogspaceInferMeta(const MetaTensor& start,
38453845
out->set_dtype(dtype);
38463846
}
38473847

3848-
void MatrixRankAtolRtolInferMeta(const MetaTensor& x,
3849-
const MetaTensor& tol,
3850-
const MetaTensor& atol,
3851-
const MetaTensor& rtol,
3852-
bool use_default_tol,
3853-
bool hermitian,
3854-
MetaTensor* out) {
3855-
if (tol) {
3856-
MatrixRankTolInferMeta(x, tol, use_default_tol, hermitian, out);
3857-
} else {
3858-
MatrixRankTolInferMeta(x, atol, use_default_tol, hermitian, out);
3859-
}
3860-
}
3861-
38623848
void MergedAdamInferMeta(
38633849
const std::vector<const MetaTensor*>& param,
38643850
const std::vector<const MetaTensor*>& grad,

paddle/phi/infermeta/multiary.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -704,14 +704,6 @@ void LogspaceInferMeta(const MetaTensor& start,
704704
DataType dtype,
705705
MetaTensor* out);
706706

707-
void MatrixRankAtolRtolInferMeta(const MetaTensor& x,
708-
const MetaTensor& tol,
709-
const MetaTensor& atol,
710-
const MetaTensor& rtol,
711-
bool use_default_tol,
712-
bool hermitian,
713-
MetaTensor* out);
714-
715707
void MergedAdamInferMeta(
716708
const std::vector<const MetaTensor*>& param,
717709
const std::vector<const MetaTensor*>& grad,

paddle/phi/infermeta/ternary.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/common/layout.h"
2222
#include "paddle/phi/core/ddim.h"
2323
#include "paddle/phi/core/enforce.h"
24+
#include "paddle/phi/infermeta/binary.h"
2425
#include "paddle/phi/kernels/funcs/common_shape.h"
2526
#include "paddle/phi/kernels/impl/box_coder.h"
2627

@@ -1305,6 +1306,14 @@ void MatchMatrixTensorInferMeta(const MetaTensor& x,
13051306
}
13061307
}
13071308

1309+
void MatrixRankAtolRtolInferMeta(const MetaTensor& x,
1310+
const MetaTensor& atol,
1311+
const MetaTensor& rtol,
1312+
bool hermitian,
1313+
MetaTensor* out) {
1314+
MatrixRankTolInferMeta(x, atol, true, hermitian, out);
1315+
}
1316+
13081317
void MultiClassNMSInferMeta(const MetaTensor& bboxes,
13091318
const MetaTensor& scores,
13101319
const MetaTensor& rois_num,

paddle/phi/infermeta/ternary.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,12 @@ void MatchMatrixTensorInferMeta(const MetaTensor& x,
231231
MetaTensor* tmp,
232232
MetaConfig config = MetaConfig());
233233

234+
void MatrixRankAtolRtolInferMeta(const MetaTensor& x,
235+
const MetaTensor& atol,
236+
const MetaTensor& rtol,
237+
bool hermitian,
238+
MetaTensor* out);
239+
234240
void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x,
235241
const MetaTensor& in_accum,
236242
const MetaTensor& in_state,

paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc

Lines changed: 107 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -175,128 +175,118 @@ void MatrixRankTolKernel(const Context& dev_ctx,
175175
template <typename T, typename Context>
176176
void MatrixRankAtolRtolKernel(const Context& dev_ctx,
177177
const DenseTensor& x,
178-
const paddle::optional<DenseTensor>& tol,
179-
const paddle::optional<DenseTensor>& atol,
178+
const DenseTensor& atol,
180179
const paddle::optional<DenseTensor>& rtol,
181-
bool use_default_tol,
182180
bool hermitian,
183181
DenseTensor* out) {
184-
if (tol) {
185-
MatrixRankTolKernel<T, Context>(
186-
dev_ctx, x, *tol, use_default_tol, hermitian, out);
182+
dev_ctx.template Alloc<int64_t>(out);
183+
auto dim_x = x.dims();
184+
auto dim_out = out->dims();
185+
int rows = static_cast<int>(dim_x[dim_x.size() - 2]);
186+
int cols = static_cast<int>(dim_x[dim_x.size() - 1]);
187+
int k = std::min(rows, cols);
188+
int batches = static_cast<int>(x.numel() / (rows * cols));
189+
190+
DenseTensor eigenvalue_tensor;
191+
eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k));
192+
auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
193+
194+
if (hermitian) {
195+
phi::funcs::MatrixEighFunctor<Context, T> functor;
196+
functor(dev_ctx, x, &eigenvalue_tensor, nullptr, true, false);
197+
phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
198+
} else {
199+
DenseTensor trans_x = phi::TransposeLast2Dim<T>(dev_ctx, x);
200+
auto* x_data = trans_x.data<T>();
201+
BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols);
202+
}
203+
204+
DenseTensor max_eigenvalue_tensor;
205+
max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims()));
206+
dev_ctx.template Alloc<T>(&max_eigenvalue_tensor);
207+
phi::MaxKernel<T, Context>(dev_ctx,
208+
eigenvalue_tensor,
209+
phi::IntArray({-1}),
210+
false,
211+
&max_eigenvalue_tensor);
212+
213+
DenseTensor tol_tensor;
214+
tol_tensor.Resize(dim_out);
215+
dev_ctx.template Alloc<T>(&tol_tensor);
216+
217+
if (rtol) {
218+
DenseTensor tmp_rtol_tensor;
219+
tmp_rtol_tensor = phi::Multiply<T>(dev_ctx, *rtol, max_eigenvalue_tensor);
220+
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
221+
dev_ctx,
222+
atol,
223+
tmp_rtol_tensor,
224+
GreaterElementFunctor<T>(),
225+
&tol_tensor);
226+
} else {
227+
// when `rtol` is specified to be None in py api
228+
// use rtol=eps*max(m, n) only if `atol` is passed with value 0.0, else use
229+
// rtol=0.0
230+
DenseTensor zero_tensor;
231+
zero_tensor = phi::FullLike<T, Context>(dev_ctx, atol, static_cast<T>(0.0));
232+
233+
T rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
234+
DenseTensor default_rtol_tensor;
235+
default_rtol_tensor =
236+
phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T));
237+
default_rtol_tensor =
238+
phi::Multiply<T>(dev_ctx, default_rtol_tensor, max_eigenvalue_tensor);
239+
240+
DenseTensor atol_compare_result;
241+
atol_compare_result.Resize(atol.dims());
242+
phi::EqualKernel<T, Context>(
243+
dev_ctx, atol, zero_tensor, &atol_compare_result);
244+
245+
DenseTensor selected_rtol_tensor;
246+
selected_rtol_tensor.Resize(atol.dims());
247+
phi::WhereKernel<T, Context>(dev_ctx,
248+
atol_compare_result,
249+
default_rtol_tensor,
250+
zero_tensor,
251+
&selected_rtol_tensor);
252+
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
253+
dev_ctx,
254+
atol,
255+
selected_rtol_tensor,
256+
GreaterElementFunctor<T>(),
257+
&tol_tensor);
258+
}
259+
260+
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
261+
262+
DenseTensor compare_result;
263+
compare_result.Resize(detail::NewAxisDim(dim_out, k));
264+
dev_ctx.template Alloc<int64_t>(&compare_result);
265+
int axis = -1;
266+
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
267+
funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int>(
268+
dev_ctx,
269+
eigenvalue_tensor,
270+
tol_tensor,
271+
funcs::GreaterThanFunctor<T, int64_t>(),
272+
&compare_result,
273+
axis);
187274
} else {
188-
DenseTensor atol_tensor = *atol;
189-
DenseTensor rtol_tensor = *rtol;
190-
dev_ctx.template Alloc<int64_t>(out);
191-
auto dim_x = x.dims();
192-
auto dim_out = out->dims();
193-
int rows = static_cast<int>(dim_x[dim_x.size() - 2]);
194-
int cols = static_cast<int>(dim_x[dim_x.size() - 1]);
195-
int k = std::min(rows, cols);
196-
int batches = static_cast<int>(x.numel() / (rows * cols));
197-
198-
DenseTensor eigenvalue_tensor;
199-
eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k));
200-
auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
201-
202-
if (hermitian) {
203-
phi::funcs::MatrixEighFunctor<Context, T> functor;
204-
functor(dev_ctx, x, &eigenvalue_tensor, nullptr, true, false);
205-
phi::AbsKernel<T, Context>(
206-
dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
207-
} else {
208-
DenseTensor trans_x = phi::TransposeLast2Dim<T>(dev_ctx, x);
209-
auto* x_data = trans_x.data<T>();
210-
BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols);
211-
}
212-
213-
DenseTensor max_eigenvalue_tensor;
214-
max_eigenvalue_tensor.Resize(
215-
detail::RemoveLastDim(eigenvalue_tensor.dims()));
216-
dev_ctx.template Alloc<T>(&max_eigenvalue_tensor);
217-
phi::MaxKernel<T, Context>(dev_ctx,
218-
eigenvalue_tensor,
219-
phi::IntArray({-1}),
220-
false,
221-
&max_eigenvalue_tensor);
222-
223-
DenseTensor tol_tensor;
224-
tol_tensor.Resize(dim_out);
225-
dev_ctx.template Alloc<T>(&tol_tensor);
226-
227-
if (use_default_tol) {
228-
// when `rtol` is specified to be None in py api
229-
// use tol=eps*max(m, n)*sigma_1 only if `atol` is specified as 0
230-
T rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
231-
DenseTensor default_rtol_tensor;
232-
default_rtol_tensor =
233-
phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T));
234-
default_rtol_tensor =
235-
phi::Multiply<T>(dev_ctx, default_rtol_tensor, max_eigenvalue_tensor);
236-
237-
DenseTensor tmp_zeros;
238-
tmp_zeros = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0.0));
239-
DenseTensor atol_compare_result;
240-
atol_compare_result.Resize(atol_tensor.dims());
241-
phi::EqualKernel<T, Context>(
242-
dev_ctx, atol_tensor, tmp_zeros, &atol_compare_result);
243-
244-
DenseTensor selected_rtol_tensor;
245-
selected_rtol_tensor.Resize(rtol_tensor.dims());
246-
phi::WhereKernel<T, Context>(dev_ctx,
247-
atol_compare_result,
248-
default_rtol_tensor,
249-
rtol_tensor,
250-
&selected_rtol_tensor);
251-
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
252-
dev_ctx,
253-
atol_tensor,
254-
selected_rtol_tensor,
255-
GreaterElementFunctor<T>(),
256-
&tol_tensor);
257-
} else {
258-
DenseTensor tmp_rtol_tensor;
259-
tmp_rtol_tensor =
260-
phi::Multiply<T>(dev_ctx, rtol_tensor, max_eigenvalue_tensor);
261-
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
262-
dev_ctx,
263-
atol_tensor,
264-
tmp_rtol_tensor,
265-
GreaterElementFunctor<T>(),
266-
&tol_tensor);
267-
}
268-
269-
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
270-
271-
DenseTensor compare_result;
272-
compare_result.Resize(detail::NewAxisDim(dim_out, k));
273-
dev_ctx.template Alloc<int64_t>(&compare_result);
274-
int axis = -1;
275-
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
276-
funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int>(
277-
dev_ctx,
278-
eigenvalue_tensor,
279-
tol_tensor,
280-
funcs::GreaterThanFunctor<T, int64_t>(),
281-
&compare_result,
282-
axis);
283-
} else {
284-
funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t>, T, int>(
285-
dev_ctx,
286-
eigenvalue_tensor,
287-
tol_tensor,
288-
funcs::LessThanFunctor<T, int64_t>(),
289-
&compare_result,
290-
axis);
291-
}
292-
293-
phi::SumKernel<int64_t>(dev_ctx,
294-
compare_result,
295-
std::vector<int64_t>{-1},
296-
compare_result.dtype(),
297-
false,
298-
out);
275+
funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t>, T, int>(
276+
dev_ctx,
277+
eigenvalue_tensor,
278+
tol_tensor,
279+
funcs::LessThanFunctor<T, int64_t>(),
280+
&compare_result,
281+
axis);
299282
}
283+
284+
phi::SumKernel<int64_t>(dev_ctx,
285+
compare_result,
286+
std::vector<int64_t>{-1},
287+
compare_result.dtype(),
288+
false,
289+
out);
300290
}
301291
} // namespace phi
302292

0 commit comments

Comments
 (0)