From ff61a49dd1a728a96e9a8434ed408a2a52d73119 Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 3 Oct 2023 09:40:28 +0900
Subject: [PATCH] Allow changing epsilon parameter in RMS norm kernel (#1112)

---
 test/unit/util/rms_norm.cu                       |  9 +++++----
 tools/util/include/cutlass/util/device_rmsnorm.h | 15 ++++++++-------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/test/unit/util/rms_norm.cu b/test/unit/util/rms_norm.cu
index a3e6595dae..7897de5104 100644
--- a/test/unit/util/rms_norm.cu
+++ b/test/unit/util/rms_norm.cu
@@ -43,7 +43,8 @@ using Layout = cutlass::layout::RowMajor;
 void rmsnorm_host(cutlass::MatrixCoord tensor_size,
                   cutlass::TensorRef<ElementType, Layout> output,
                   cutlass::TensorRef<ElementType, Layout> input,
-                  cutlass::TensorRef<ElementType, Layout> weight) {
+                  cutlass::TensorRef<ElementType, Layout> weight,
+                  float epsilon) {
   const int M = tensor_size.row();
   const int N = tensor_size.column();
 
@@ -56,7 +57,7 @@
     }
 
     float sq_mean = square_sum / (float)N;
-    float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);
+    float sqrt_var = cutlass::fast_sqrt(sq_mean + epsilon);
 
     for (int n = 0; n < N; ++n) {
       float inp = static_cast<float>(input.at({m, n}));
@@ -91,9 +92,9 @@ void run_test(int M, int N) {
   input.sync_device();
   weight.sync_device();
 
-  rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
+  rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
   cutlass::rmsnorm({M, N}, output.device_ref(),
-                   input.device_ref(), weight.device_ref(), NULL);
+                   input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
 
   output.sync_host();
 
diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h b/tools/util/include/cutlass/util/device_rmsnorm.h
index a401db00bf..18b5c33f5a 100644
--- a/tools/util/include/cutlass/util/device_rmsnorm.h
+++ b/tools/util/include/cutlass/util/device_rmsnorm.h
@@ -43,7 +43,7 @@ namespace cutlass {
 
 __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
                                        const float4 *weight,
-                                       const int m, const int n) {
+                                       const int m, const int n, float epsilon) {
   const int m_idx = blockIdx.x;
   const int tid = threadIdx.x;
   const int bdimx = blockDim.x;
@@ -76,7 +76,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
     blockReduceSum<float, 1>(local_sums);
   }
   if (threadIdx.x == 0) {
-    s_mean = rsqrtf(local_sums[0] / n + 1e-6);
+    s_mean = rsqrtf(local_sums[0] / n + epsilon);
   }
   __syncthreads();
 
@@ -117,7 +117,8 @@ template <typename T>
 __global__ void rmsnorm_twoPassAlgo_e1(T* output,
                                        const T* input,
                                        const T* weight,
-                                       const int m, const int n)
+                                       const int m, const int n,
+                                       float epsilon)
 {
   const int m_idx = blockIdx.x;
   const int tid = threadIdx.x;
@@ -139,7 +140,7 @@ __global__ void rmsnorm_twoPassAlgo_e1(T* output,
     blockReduceSum<float, 1>(local_sums);
   }
   if (threadIdx.x == 0) {
-    s_mean = rsqrtf(local_sums[0] / n + 1e-6);
+    s_mean = rsqrtf(local_sums[0] / n + epsilon);
   }
   __syncthreads();
 
@@ -155,7 +156,7 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
              TensorRef<T, layout::RowMajor> ref_output,
              TensorRef<T, layout::RowMajor> ref_input,
              TensorRef<T, layout::RowMajor> ref_weight,
-             cudaStream_t stream){
+             cudaStream_t stream, float epsilon = 1e-5){
   const int m = tensor_size.row();
   const int n = tensor_size.column();
   T* output = ref_output.data();
@@ -167,12 +168,12 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
     dim3 block(min(1024, (n / 8 + 31) / 32 * 32));
 
     rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
-      (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n);
+      (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon);
   } else {
     dim3 block(min(1024, ((n + 31)/32 + 31)/32*32));
 
     rmsnorm_twoPassAlgo_e1<T><<<grid, block, 0, stream>>>(
-      output, input, weight, m, n);
+      output, input, weight, m, n, epsilon);
   }
 
   auto result = cudaGetLastError();
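
Usage note (not part of the patch): a minimal sketch of how the changed host-side
entry point is called, mirroring the unit test above. The shapes, element type,
and the explicit 1e-5 value are illustrative assumptions, not mandated by the API.

    // Illustrative sketch only: exercises the new epsilon parameter of
    // cutlass::rmsnorm. Shapes and the epsilon value are arbitrary choices.
    #include "cutlass/util/device_rmsnorm.h"
    #include "cutlass/util/host_tensor.h"

    void run_rmsnorm_example() {
      int M = 128, N = 1024;
      using Element = cutlass::half_t;
      using Layout  = cutlass::layout::RowMajor;

      cutlass::HostTensor<Element, Layout> input({M, N});
      cutlass::HostTensor<Element, Layout> weight({1, N});
      cutlass::HostTensor<Element, Layout> output({M, N});

      // ... fill input/weight on the host ...
      input.sync_device();
      weight.sync_device();

      // epsilon is now caller-controlled (defaulting to 1e-5) instead of the
      // 1e-6 previously hard-coded inside the kernels.
      cutlass::rmsnorm({M, N}, output.device_ref(), input.device_ref(),
                       weight.device_ref(), /*stream=*/NULL, /*epsilon=*/1e-5f);

      output.sync_host();
    }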