@@ -1128,6 +1128,108 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
11281128 });
11291129}
11301130
1131+ // note(wangran16): unknown bug. parameters dislocation when calling
1132+ // GEMM_STRIDED_BATCH<float> and GEMM_STRIDED_BATCH<double>
1133+ template <>
1134+ template <>
1135+ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
1136+ CBLAS_TRANSPOSE transB,
1137+ int M,
1138+ int N,
1139+ int K,
1140+ float alpha,
1141+ const float *A,
1142+ const float *B,
1143+ float beta,
1144+ float *C,
1145+ int batchCount,
1146+ int64_t strideA,
1147+ int64_t strideB) const {
1148+ // Note that cublas follows fortran order, so the order is different from
1149+ // the cblas convention.
1150+ int lda = (transA == CblasNoTrans) ? K : M;
1151+ int ldb = (transB == CblasNoTrans) ? N : K;
1152+ int ldc = N;
1153+ rocblas_operation cuTransA = (transA == CblasNoTrans)
1154+ ? rocblas_operation_none
1155+ : rocblas_operation_transpose;
1156+ rocblas_operation cuTransB = (transB == CblasNoTrans)
1157+ ? rocblas_operation_none
1158+ : rocblas_operation_transpose;
1159+ const int64_t strideC = M * N;
1160+ context_.CublasCall ([&](rocblas_handle handle) {
1161+ PADDLE_ENFORCE_GPU_SUCCESS (
1162+ paddle::platform::dynload::rocblas_sgemm_strided_batched (handle,
1163+ cuTransB,
1164+ cuTransA,
1165+ N,
1166+ M,
1167+ K,
1168+ &alpha,
1169+ B,
1170+ ldb,
1171+ strideB,
1172+ A,
1173+ lda,
1174+ strideA,
1175+ &beta,
1176+ C,
1177+ ldc,
1178+ strideC,
1179+ batchCount));
1180+ });
1181+ }
1182+
1183+ template <>
1184+ template <>
1185+ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
1186+ CBLAS_TRANSPOSE transB,
1187+ int M,
1188+ int N,
1189+ int K,
1190+ double alpha,
1191+ const double *A,
1192+ const double *B,
1193+ double beta,
1194+ double *C,
1195+ int batchCount,
1196+ int64_t strideA,
1197+ int64_t strideB) const {
1198+ // Note that cublas follows fortran order, so the order is different from
1199+ // the cblas convention.
1200+ int lda = (transA == CblasNoTrans) ? K : M;
1201+ int ldb = (transB == CblasNoTrans) ? N : K;
1202+ int ldc = N;
1203+ rocblas_operation cuTransA = (transA == CblasNoTrans)
1204+ ? rocblas_operation_none
1205+ : rocblas_operation_transpose;
1206+ rocblas_operation cuTransB = (transB == CblasNoTrans)
1207+ ? rocblas_operation_none
1208+ : rocblas_operation_transpose;
1209+ const int64_t strideC = M * N;
1210+ context_.CublasCall ([&](rocblas_handle handle) {
1211+ PADDLE_ENFORCE_GPU_SUCCESS (
1212+ paddle::platform::dynload::rocblas_dgemm_strided_batched (handle,
1213+ cuTransB,
1214+ cuTransA,
1215+ N,
1216+ M,
1217+ K,
1218+ &alpha,
1219+ B,
1220+ ldb,
1221+ strideB,
1222+ A,
1223+ lda,
1224+ strideA,
1225+ &beta,
1226+ C,
1227+ ldc,
1228+ strideC,
1229+ batchCount));
1230+ });
1231+ }
1232+
11311233template <>
11321234template <>
11331235inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
0 commit comments