Commit a02d1aa

[Deep alignment] dot (#75717)
* fix
* fix
* fix dcu
Parent: 1990bcc

4 files changed (+139, -3 lines)

paddle/phi/backends/dynload/cublas.h

Lines changed: 8 additions & 1 deletion
@@ -106,7 +106,14 @@ extern void *cublas_dso_handle;
   __macro(cublasCmatinvBatched); \
   __macro(cublasZmatinvBatched); \
   __macro(cublasSgetrsBatched); \
-  __macro(cublasDgetrsBatched);
+  __macro(cublasDgetrsBatched); \
+  __macro(cublasSdot_v2); \
+  __macro(cublasDdot_v2); \
+  __macro(cublasCdotc_v2); \
+  __macro(cublasZdotc_v2); \
+  __macro(cublasCdotu_v2); \
+  __macro(cublasZdotu_v2); \
+  __macro(cublasDotEx);
 
 CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
 
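For context: each __macro(name) entry above is expanded by DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP into a small functor that resolves the symbol from the cuBLAS shared library on first use, so adding the seven dot routines here is what makes them callable through phi::dynload. A minimal sketch of that wrapper pattern, assuming the usual dlopen/dlsym mechanics (the real macro in paddle/phi/backends/dynload adds once-only initialization and error reporting):

  // Simplified sketch of the dynamic-load wrapper pattern; illustrative
  // only, not the exact Paddle macro.
  #include <dlfcn.h>

  extern void *cublas_dso_handle;  // opened elsewhere via dlopen()

  #define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)             \
    struct DynLoad__##__name {                                 \
      template <typename... Args>                              \
      auto operator()(Args... args) {                          \
        using FuncType = decltype(&::__name);                  \
        /* resolve the symbol once, then cache it */           \
        static void *sym = dlsym(cublas_dso_handle, #__name);  \
        return reinterpret_cast<FuncType>(sym)(args...);       \
      }                                                        \
    };                                                         \
    extern DynLoad__##__name __name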

paddle/phi/kernels/funcs/blas/blas.h

Lines changed: 9 additions & 0 deletions
@@ -283,6 +283,10 @@ class Blas {
   template <typename T>
   T DOT(int n, const T* x, const T* y) const;
 
+  template <typename T>
+  void CUDOT(
+      int n, const T* x, int incx, const T* y, int incy, T* result) const;
+
   template <typename T>
   void SCAL(int n, const T a, T* x) const;
 
@@ -543,6 +547,11 @@ class BlasT : private Blas<DeviceContext> {
     return Base()->template DOT<T>(args...);
   }
 
+  template <typename... ARGS>
+  void CUDOT(ARGS... args) const {
+    Base()->template CUDOT<T>(args...);
+  }
+
   template <typename... ARGS>
   void SCAL(ARGS... args) const {
     Base()->template SCAL<T>(args...);
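In BLAS level-1 terms, the incx and incy parameters of the new CUDOT are element strides: the routine computes result = sum over i of x[i*incx] * y[i*incy]. A plain reference sketch of those semantics (illustrative, not part of the commit):

  // Reference semantics of a strided BLAS dot product, i.e. what the
  // cublas*dot_v2 routines compute on the device. Illustrative only.
  template <typename T>
  T DotReference(int n, const T *x, int incx, const T *y, int incy) {
    T acc = T(0);
    for (int i = 0; i < n; ++i) {
      acc += x[i * incx] * y[i * incy];
    }
    return acc;
  }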

paddle/phi/kernels/funcs/blas/blas_impl.cu.h

Lines changed: 96 additions & 0 deletions
@@ -211,6 +211,11 @@ struct CUBlas<float> {
   static void TRSM_BATCH(ARGS... args) {
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...));
   }
+
+  template <typename... ARGS>
+  static void DOT(ARGS... args) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSdot_v2(args...));
+  }
 };
 
 template <>
@@ -302,6 +307,11 @@ struct CUBlas<double> {
   static void TRSM_BATCH(ARGS... args) {
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...));
   }
+
+  template <typename... ARGS>
+  static void DOT(ARGS... args) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDdot_v2(args...));
+  }
 };
 
 template <>
@@ -559,6 +569,26 @@ struct CUBlas<phi::float16> {
         "cublasGemmEx_64 is not supported on cuda < 12.3"));
 #endif
   }
+
+  static void DOT(cublasHandle_t handle,
+                  int n,
+                  const phi::float16 *x,
+                  const int incx,
+                  const phi::float16 *y,
+                  const int incy,
+                  phi::float16 *result) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDotEx(handle,
+                                                         n,
+                                                         x,
+                                                         CUDA_R_16F,
+                                                         incx,
+                                                         y,
+                                                         CUDA_R_16F,
+                                                         incy,
+                                                         result,
+                                                         CUDA_R_16F,
+                                                         CUDA_R_32F));
+  }
 };
 
 template <>
@@ -908,6 +938,23 @@ struct CUBlas<phi::complex64> {
                                                       info,
                                                       batch_size));
   }
+
+  static void DOT(cublasHandle_t handle,
+                  int n,
+                  const phi::complex64 *x,
+                  const int incx,
+                  const phi::complex64 *y,
+                  const int incy,
+                  phi::complex64 *result) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCdotu_v2(
+        handle,
+        n,
+        reinterpret_cast<const cuFloatComplex *>(x),
+        incx,
+        reinterpret_cast<const cuFloatComplex *>(y),
+        incy,
+        reinterpret_cast<cuFloatComplex *>(result)));
+  }
 };
 
 template <>
@@ -1257,6 +1304,23 @@ struct CUBlas<phi::complex128> {
                                                       info,
                                                       batch_size));
   }
+
+  static void DOT(cublasHandle_t handle,
+                  int n,
+                  const phi::complex128 *x,
+                  const int incx,
+                  const phi::complex128 *y,
+                  const int incy,
+                  phi::complex128 *result) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZdotu_v2(
+        handle,
+        n,
+        reinterpret_cast<const cuDoubleComplex *>(x),
+        incx,
+        reinterpret_cast<const cuDoubleComplex *>(y),
+        incy,
+        reinterpret_cast<cuDoubleComplex *>(result)));
+  }
 };
 
 inline void CheckGEMMNSize(int64_t N) {
@@ -2289,6 +2353,38 @@ void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
   });
 }
 
+template <>
+template <typename T>
+void Blas<phi::GPUContext>::CUDOT(
+    int n, const T *x, int incx, const T *y, int incy, T *result) const {
+  dev_ctx_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::DOT(handle, n, x, incx, y, incy, result);
+  });
+}
+
+template <>
+template <>
+inline void Blas<phi::GPUContext>::CUDOT(int n,
+                                         const phi::bfloat16 *x,
+                                         int incx,
+                                         const phi::bfloat16 *y,
+                                         int incy,
+                                         phi::bfloat16 *result) const {
+  dev_ctx_.CublasCall([&](cublasHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDotEx(handle,
+                                                         n,
+                                                         x,
+                                                         CUDA_R_16BF,
+                                                         incx,
+                                                         y,
+                                                         CUDA_R_16BF,
+                                                         incy,
+                                                         result,
+                                                         CUDA_R_16BF,
+                                                         CUDA_R_32F));
+  });
+}
+
 template <>
 template <typename T>
 void Blas<phi::GPUContext>::SCAL(int n, const T alpha, T *x) const {
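Both 16-bit paths above go through cublasDotEx rather than a typed dot routine; its trailing executionType parameter lets inputs and result stay in 16-bit storage while cuBLAS accumulates in float. A standalone sketch against the raw API, with a hypothetical helper name and device pointers assumed valid:

  // Minimal sketch: half-precision dot with float accumulation via
  // cublasDotEx. HalfDot is a hypothetical helper; d_x, d_y, d_result
  // are device pointers.
  #include <cublas_v2.h>
  #include <cuda_fp16.h>

  cublasStatus_t HalfDot(cublasHandle_t handle, int n,
                         const __half *d_x, const __half *d_y,
                         __half *d_result) {
    return cublasDotEx(handle, n,
                       d_x, CUDA_R_16F, /*incx=*/1,
                       d_y, CUDA_R_16F, /*incy=*/1,
                       d_result, CUDA_R_16F,
                       /*executionType=*/CUDA_R_32F);  // accumulate in float
  }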

paddle/phi/kernels/gpu/dot_kernel.cu

Lines changed: 26 additions & 2 deletions
@@ -13,9 +13,9 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/dot_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 
 #include "paddle/phi/kernels/full_kernel.h"
@@ -36,14 +36,39 @@ void DotKernel(const Context& dev_ctx,
   if (out->numel() <= 0) {
     return;
   }
+  auto x_data = x.data<T>();
+  auto y_data = y.data<T>();
   dev_ctx.template Alloc<T>(out);
+  auto out_data = out->data<T>();
   if (out->dims().size() == 0) {
+#ifdef PADDLE_WITH_CUDA
+    if constexpr (std::is_same_v<T, int> || std::is_same_v<T, int64_t>) {
+      auto eigen_out = phi::EigenScalar<T>::From(*out);
+      auto eigen_x = phi::EigenVector<T>::Flatten(x);
+      auto eigen_y = phi::EigenVector<T>::Flatten(y);
+
+      auto& dev = *dev_ctx.eigen_device();
+      eigen_out.device(dev) = (eigen_x * eigen_y).sum();
+    } else {
+      const int n = static_cast<int>(x.numel());
+      int incx = static_cast<int>(x.strides()[0]);
+      int incy = static_cast<int>(x.strides()[0]);
+      if (n == 1) {
+        incx = 1;
+        incy = 1;
+      }
+
+      auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
+      blas.CUDOT(n, x_data, incx, y_data, incy, out_data);
+    }
+#else
     auto eigen_out = phi::EigenScalar<T>::From(*out);
     auto eigen_x = phi::EigenVector<T>::Flatten(x);
     auto eigen_y = phi::EigenVector<T>::Flatten(y);
 
     auto& dev = *dev_ctx.eigen_device();
     eigen_out.device(dev) = (eigen_x * eigen_y).sum();
+#endif
   } else {
     auto eigen_out = phi::EigenVector<T>::From(*out);
     auto eigen_x = phi::EigenMatrix<T>::From(x);
@@ -53,7 +78,6 @@ void DotKernel(const Context& dev_ctx,
     eigen_out.device(dev) = (eigen_x * eigen_y).sum(Eigen::DSizes<int, 1>(1));
   }
 }
-
 }  // namespace phi
 
 using complex64 = phi::complex64;
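cuBLAS exposes no integer dot routine, which is what the if constexpr branch above encodes: int and int64_t instantiations keep the Eigen reduction, while floating-point and complex types route through the new CUDOT. The same split written as a compile-time trait (hypothetical name, not in the commit):

  // Hypothetical trait mirroring the DotKernel dispatch: true for types
  // the cuBLAS path handles, false for the integer fallback.
  #include <cstdint>
  #include <type_traits>

  template <typename T>
  constexpr bool kUseCublasDot =
      !(std::is_same_v<T, int> || std::is_same_v<T, int64_t>);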
