Skip to content

Commit 4478389

Browse files
authored
[ROCM] fix bmm_kernel (#45530)
1 parent 4a25b60 commit 4478389

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

paddle/phi/kernels/funcs/blas/blas_impl.hip.h

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,108 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
11281128
});
11291129
}
11301130

1131+
// note(wangran16): unknown bug. parameters dislocation when calling
1132+
// GEMM_STRIDED_BATCH<float> and GEMM_STRIDED_BATCH<double>
1133+
template <>
1134+
template <>
1135+
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
1136+
CBLAS_TRANSPOSE transB,
1137+
int M,
1138+
int N,
1139+
int K,
1140+
float alpha,
1141+
const float *A,
1142+
const float *B,
1143+
float beta,
1144+
float *C,
1145+
int batchCount,
1146+
int64_t strideA,
1147+
int64_t strideB) const {
1148+
// Note that cublas follows fortran order, so the order is different from
1149+
// the cblas convention.
1150+
int lda = (transA == CblasNoTrans) ? K : M;
1151+
int ldb = (transB == CblasNoTrans) ? N : K;
1152+
int ldc = N;
1153+
rocblas_operation cuTransA = (transA == CblasNoTrans)
1154+
? rocblas_operation_none
1155+
: rocblas_operation_transpose;
1156+
rocblas_operation cuTransB = (transB == CblasNoTrans)
1157+
? rocblas_operation_none
1158+
: rocblas_operation_transpose;
1159+
const int64_t strideC = M * N;
1160+
context_.CublasCall([&](rocblas_handle handle) {
1161+
PADDLE_ENFORCE_GPU_SUCCESS(
1162+
paddle::platform::dynload::rocblas_sgemm_strided_batched(handle,
1163+
cuTransB,
1164+
cuTransA,
1165+
N,
1166+
M,
1167+
K,
1168+
&alpha,
1169+
B,
1170+
ldb,
1171+
strideB,
1172+
A,
1173+
lda,
1174+
strideA,
1175+
&beta,
1176+
C,
1177+
ldc,
1178+
strideC,
1179+
batchCount));
1180+
});
1181+
}
1182+
1183+
template <>
1184+
template <>
1185+
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
1186+
CBLAS_TRANSPOSE transB,
1187+
int M,
1188+
int N,
1189+
int K,
1190+
double alpha,
1191+
const double *A,
1192+
const double *B,
1193+
double beta,
1194+
double *C,
1195+
int batchCount,
1196+
int64_t strideA,
1197+
int64_t strideB) const {
1198+
// Note that cublas follows fortran order, so the order is different from
1199+
// the cblas convention.
1200+
int lda = (transA == CblasNoTrans) ? K : M;
1201+
int ldb = (transB == CblasNoTrans) ? N : K;
1202+
int ldc = N;
1203+
rocblas_operation cuTransA = (transA == CblasNoTrans)
1204+
? rocblas_operation_none
1205+
: rocblas_operation_transpose;
1206+
rocblas_operation cuTransB = (transB == CblasNoTrans)
1207+
? rocblas_operation_none
1208+
: rocblas_operation_transpose;
1209+
const int64_t strideC = M * N;
1210+
context_.CublasCall([&](rocblas_handle handle) {
1211+
PADDLE_ENFORCE_GPU_SUCCESS(
1212+
paddle::platform::dynload::rocblas_dgemm_strided_batched(handle,
1213+
cuTransB,
1214+
cuTransA,
1215+
N,
1216+
M,
1217+
K,
1218+
&alpha,
1219+
B,
1220+
ldb,
1221+
strideB,
1222+
A,
1223+
lda,
1224+
strideA,
1225+
&beta,
1226+
C,
1227+
ldc,
1228+
strideC,
1229+
batchCount));
1230+
});
1231+
}
1232+
11311233
template <>
11321234
template <>
11331235
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,

0 commit comments

Comments
 (0)