Skip to content
Merged
Changes from all commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
fd28881
[Metax_change_ut]
duqimeng Jul 23, 2025
a9d2aa7
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 24, 2025
1695f36
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 31, 2025
b931d38
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 1, 2025
bef21bf
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 8, 2025
f4e5004
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 13, 2025
55422eb
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 18, 2025
815a63a
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 19, 2025
1739a15
fix sum&collect_fpn_proposals op register
StareAtYou Aug 19, 2025
af0bae5
fix sum&collect_fpn_proposals op register
metax666 Aug 19, 2025
be61f06
modify profile
jxwangmetax Aug 20, 2025
0fc2dd1
modify profile
metax666 Aug 20, 2025
1ad95c5
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 20, 2025
f12b3e4
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 21, 2025
789c9fc
[Metax] fix paddle bug replace 'MoeGradDispatchKernel' to 'MoeGateDis…
StareAtYou Aug 21, 2025
a0116fb
[Metax] fix paddle bug
metax666 Aug 21, 2025
a2da5e0
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 22, 2025
f9e6d2c
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
StareAtYou Aug 22, 2025
4b4f562
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Aug 22, 2025
662e22e
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
3e8d6ce
Merge branch 'metax666:develop' into develop
StareAtYou Aug 25, 2025
9dae9b7
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
metax666 Aug 25, 2025
47fef62
blas handle support
jxwangmetax Aug 25, 2025
266c0df
blas handle support
metax666 Aug 25, 2025
a0b340b
[Metax] register some kernels & update CMakeLists
StareAtYou Aug 25, 2025
aa9bd35
Merge branch 'metax666:develop' into develop
StareAtYou Aug 26, 2025
8c6ac05
[Metax] register some kernels & update CMakeLists
metax666 Aug 26, 2025
9510f7d
Merge branch 'metax666:develop' into develop
duqimeng Aug 26, 2025
fa7cc1a
[Metax] fix metax unittest fail
StareAtYou Aug 26, 2025
a907545
[Metax] fix metax unittest fail
metax666 Aug 26, 2025
7a6312e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
StareAtYou Aug 26, 2025
90bb94e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
metax666 Aug 27, 2025
9f130fe
[Metax] fix rmsprop kernel register and add meshgrid & meshgrid_grad …
StareAtYou Aug 27, 2025
ca38fb5
Merge branch 'metax666:develop' into develop
StareAtYou Aug 27, 2025
f0cc1e0
add test
zhang-chenyi Aug 27, 2025
8e8b732
add test
zhang-chenyi Aug 27, 2025
8d7efbd
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 27, 2025
28c992b
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
d3470bb
[test] chang the logic of workspace_host in cholesky_kernel_register
zhang-chenyi Aug 27, 2025
db17ebf
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
83bc87f
[Metax] fix compile fail
StareAtYou Aug 27, 2025
f1e8d0c
Revert "[Metax] fix compile fail"
StareAtYou Aug 27, 2025
a13daa8
[Metax] fix compile fail by 'conv_transpose_grad_kernel_impl.h'
StareAtYou Aug 27, 2025
95a179b
[Metax] fix bug & add some kernel register
metax666 Aug 28, 2025
4576ef4
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
ca51a1e
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
7789e9b
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
afd0863
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
6da0f0d
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
e1e07ba
[Metax] change_patch
duqimeng Aug 28, 2025
046637c
[Metax] change_patch
metax666 Aug 28, 2025
c27b492
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 28, 2025
05ecd9d
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
b1bf7e8
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
f90d585
Merge branch 'metax666:develop' into develop
StareAtYou Aug 28, 2025
874d9b6
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 28, 2025
0ca02b9
[feature] add unique_consecutive kernel
zhang-chenyi Aug 28, 2025
40d8f21
[metax-feature] add kernel for test_math_op_patch_var_base
metax666 Aug 28, 2025
3e9b526
[metax] add some kernel
duqimeng Aug 28, 2025
8911576
[metax] add some kernel
duqimeng Aug 28, 2025
8471597
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
0758887
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
61be33d
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
2fe962e
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
531fedb
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
c0dcfff
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
bd65451
[feature] add add unique_consecutive kernel.cu
zhang-chenyi Aug 29, 2025
0def63d
[fix] fix some test case due to missing op register
zhang-chenyi Aug 29, 2025
e503c9e
[fix] fix some fail text
zhang-chenyi Aug 29, 2025
9844878
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
70b86e7
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
1e90757
add and fix some kernels
1184319564 Aug 30, 2025
f93307d
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
c4b0eb9
[Metax] fix conflict
StareAtYou Sep 1, 2025
06dda18
[Metax] fix conflict
StareAtYou Sep 1, 2025
dae6ce8
[Metax] adapt to paddle-cpu-20250901 & resolve the issue of 'test_ele…
StareAtYou Sep 1, 2025
b4a5c62
[Metax] update repeat_interleave kernel & ignore max op test
StareAtYou Sep 2, 2025
7cf4405
Merge branch 'metax666:develop' into develop
StareAtYou Sep 2, 2025
0015f2e
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
metax666 Sep 2, 2025
fc2c0f5
Merge branch 'metax666:develop' into develop
duqimeng Sep 2, 2025
829c3b6
Merge dev
duqimeng Sep 2, 2025
3104a9c
【metax】add and fix some kernels
metax666 Sep 2, 2025
175cca6
[metax]fix lu eigvalshsqueeze rnn kernel
metax666 Sep 2, 2025
c7db810
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
f5813ed
[metax] chang patch fix copy
duqimeng Sep 2, 2025
6f0b705
[metax] chang patch fix copy
duqimeng Sep 2, 2025
8f47f0e
[metax] chang patch fix copy
metax666 Sep 2, 2025
b420f97
[Metax] update metax_gpu unit test
StareAtYou Sep 2, 2025
c08533e
[Metax] update metax_gpu unit test
metax666 Sep 2, 2025
414715f
[Metax] fix test CMakeList.txt
StareAtYou Sep 2, 2025
aa6b5bf
[Metax] fix test CMakeList.txt
metax666 Sep 2, 2025
0bfc6e7
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
cb93f6a
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
2e99f62
[metax]change_patch
duqimeng Sep 9, 2025
026551a
[metax]change_patch
duqimeng Sep 9, 2025
b09babb
Merge branch 'metax666:develop' into develop
duqimeng Sep 9, 2025
31594f8
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
4fb467c
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 98 additions & 109 deletions backends/metax_gpu/kernels/metax_kernel/qr_kernel_register.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
#include <algorithm>
#include <vector>

#include "kernels/impl/values_vectors_functor.h"
#include "kernels/metax_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
Expand All @@ -39,7 +38,6 @@
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/kernels/tril_triu_kernel.h"

namespace phi {

template <class T, class Context>
Expand Down Expand Up @@ -358,47 +356,47 @@ void QrKernel(const Context& dev_ctx,

#ifdef PADDLE_WITH_HIP
#define FUNC_WITH_TYPES(m) m(float, s) m(double, d)
#define GEQRF_BATCH_INSTANCE(T, C) \
template <> \
void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx, \
int batch_size, \
int m, \
int n, \
T* a, \
int lda, \
T* tau, \
int a_stride, \
int tau_stride) { \
auto handle = dev_ctx.cusolver_dn_handle(); \
for (int i = 0; i < batch_size; ++i) { \
T* a_working_ptr = &a[i * a_stride]; \
T* tau_working_ptr = &tau[i * tau_stride]; \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf( \
handle, m, n, a_working_ptr, lda, tau_working_ptr)); \
} \
#define GEQRF_BATCH_INSTANCE(T, C) \
template <> \
void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx, \
int batch_size, \
int m, \
int n, \
T* a, \
int lda, \
T* tau, \
int a_stride, \
int tau_stride) { \
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
for (int i = 0; i < batch_size; ++i) { \
T* a_working_ptr = &a[i * a_stride]; \
T* tau_working_ptr = &tau[i * tau_stride]; \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf( \
handle, m, n, a_working_ptr, lda, tau_working_ptr)); \
} \
}

FUNC_WITH_TYPES(GEQRF_BATCH_INSTANCE);

#define ORGQR_BATCH_INSTANCE(T, C) \
template <> \
void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx, \
int batch_size, \
int m, \
int n, \
int k, \
T* a, \
int lda, \
T* tau, \
int a_stride, \
int tau_stride) { \
auto handle = dev_ctx.cusolver_dn_handle(); \
for (int i = 0; i < batch_size; ++i) { \
T* a_working_ptr = &a[i * a_stride]; \
T* tau_working_ptr = &tau[i * tau_stride]; \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr( \
handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
} \
#define ORGQR_BATCH_INSTANCE(T, C) \
template <> \
void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx, \
int batch_size, \
int m, \
int n, \
int k, \
T* a, \
int lda, \
T* tau, \
int a_stride, \
int tau_stride) { \
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
for (int i = 0; i < batch_size; ++i) { \
T* a_working_ptr = &a[i * a_stride]; \
T* tau_working_ptr = &tau[i * tau_stride]; \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr( \
handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
} \
}

FUNC_WITH_TYPES(ORGQR_BATCH_INSTANCE);
Expand All @@ -421,7 +419,6 @@ void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
const int64_t a_stride_64 = static_cast<int64_t>(a_stride);
const int64_t tau_stride_64 = static_cast<int64_t>(tau_stride);

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());

size_t workspace_in_bytes_on_device = 0;
Expand Down Expand Up @@ -499,7 +496,6 @@ void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
} else {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf_bufferSize(
handle, m, n, a, lda, &lwork));
Expand Down Expand Up @@ -555,7 +551,6 @@ void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
Expand Down Expand Up @@ -599,35 +594,33 @@ void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
}

template <>
void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
phi::dtype::complex<float>* a,
int lda,
phi::dtype::complex<float>* tau,
int a_stride,
int tau_stride) {
void BatchedGeqrf<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
phi::complex64* a,
int lda,
phi::complex64* tau,
int a_stride,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf_bufferSize(
handle, m, n, reinterpret_cast<cuComplex*>(a), lda, &lwork));

DenseTensor workspace = DenseTensor();
workspace.Resize(common::make_ddim({lwork}));
phi::dtype::complex<float>* workspace_ptr =
dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
phi::complex64* workspace_ptr =
dev_ctx.template Alloc<phi::complex64>(&workspace);

DenseTensor info = DenseTensor();
info.Resize(common::make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);

for (int i = 0; i < batch_size; ++i) {
phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
phi::complex64* a_working_ptr = &a[i * a_stride];
phi::complex64* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf(
handle,
Expand Down Expand Up @@ -657,35 +650,33 @@ void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
}

template <>
void BatchedGeqrf<GPUContext, phi::dtype::complex<double>>(
const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
phi::dtype::complex<double>* a,
int lda,
phi::dtype::complex<double>* tau,
int a_stride,
int tau_stride) {
void BatchedGeqrf<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
phi::complex128* a,
int lda,
phi::complex128* tau,
int a_stride,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf_bufferSize(
handle, m, n, reinterpret_cast<cuDoubleComplex*>(a), lda, &lwork));

DenseTensor workspace = DenseTensor();
workspace.Resize(common::make_ddim({lwork}));
phi::dtype::complex<double>* workspace_ptr =
dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
phi::complex128* workspace_ptr =
dev_ctx.template Alloc<phi::complex128>(&workspace);

DenseTensor info = DenseTensor();
info.Resize(common::make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);

for (int i = 0; i < batch_size; ++i) {
phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
phi::complex128* a_working_ptr = &a[i * a_stride];
phi::complex128* tau_working_ptr = &tau[i * tau_stride];
// compute geqrf
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf(
handle,
Expand Down Expand Up @@ -727,7 +718,6 @@ void BatchedOrgqr<GPUContext, float>(const GPUContext& dev_ctx,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
Expand Down Expand Up @@ -784,7 +774,6 @@ void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
handle, m, n, k, a, lda, tau, &lwork));
Expand Down Expand Up @@ -829,20 +818,18 @@ void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
}

template <>
void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
phi::dtype::complex<float>* a,
int lda,
phi::dtype::complex<float>* tau,
int a_stride,
int tau_stride) {
void BatchedOrgqr<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
phi::complex64* a,
int lda,
phi::complex64* tau,
int a_stride,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr_bufferSize(
handle,
Expand All @@ -856,16 +843,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(

DenseTensor workspace = DenseTensor();
workspace.Resize(common::make_ddim({lwork}));
phi::dtype::complex<float>* workspace_ptr =
dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
phi::complex64* workspace_ptr =
dev_ctx.template Alloc<phi::complex64>(&workspace);

DenseTensor info = DenseTensor();
info.Resize(common::make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);

for (int i = 0; i < batch_size; ++i) {
phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
phi::complex64* a_working_ptr = &a[i * a_stride];
phi::complex64* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr(
handle,
Expand Down Expand Up @@ -896,20 +883,18 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
}

template <>
void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(
const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
phi::dtype::complex<double>* a,
int lda,
phi::dtype::complex<double>* tau,
int a_stride,
int tau_stride) {
void BatchedOrgqr<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
int batch_size,
int m,
int n,
int k,
phi::complex128* a,
int lda,
phi::complex128* tau,
int a_stride,
int tau_stride) {
int lwork = 0;

// auto handle = dev_ctx.cusolver_dn_handle();
auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr_bufferSize(
handle,
Expand All @@ -923,16 +908,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(

DenseTensor workspace = DenseTensor();
workspace.Resize(common::make_ddim({lwork}));
phi::dtype::complex<double>* workspace_ptr =
dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
phi::complex128* workspace_ptr =
dev_ctx.template Alloc<phi::complex128>(&workspace);

DenseTensor info = DenseTensor();
info.Resize(common::make_ddim({1}));
int* info_d = dev_ctx.template Alloc<int>(&info);

for (int i = 0; i < batch_size; ++i) {
phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
phi::complex128* a_working_ptr = &a[i * a_stride];
phi::complex128* tau_working_ptr = &tau[i * tau_stride];
// compute orggr
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr(
handle,
Expand Down Expand Up @@ -965,11 +950,15 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(

} // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(qr, GPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
#else
PD_REGISTER_PLUGIN_KERNEL(qr,
metax_gpu,
ALL_LAYOUT,
phi::QrKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::complex64,
phi::complex128) {}
#endif