Allow changing epsilon parameter in RMS norm kernel (NVIDIA#1112)
masahi authored Oct 3, 2023
1 parent: 26986bb · commit: ff61a49
Showing 2 changed files with 13 additions and 11 deletions.
test/unit/util/rms_norm.cu (5 additions, 4 deletions)
@@ -43,7 +43,8 @@ using Layout = cutlass::layout::RowMajor;
 void rmsnorm_host(cutlass::MatrixCoord tensor_size,
                   cutlass::TensorRef<ElementType, Layout> output,
                   cutlass::TensorRef<ElementType, Layout> input,
-                  cutlass::TensorRef<ElementType, Layout> weight) {
+                  cutlass::TensorRef<ElementType, Layout> weight,
+                  float epsilon) {
   const int M = tensor_size.row();
   const int N = tensor_size.column();
 
@@ -56,7 +57,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size,
     }
 
     float sq_mean = square_sum / (float)N;
-    float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);
+    float sqrt_var = cutlass::fast_sqrt(sq_mean + epsilon);
 
     for (int n = 0; n < N; ++n) {
       float inp = static_cast<float>(input.at({m, n}));
@@ -91,9 +92,9 @@ void run_test(int M, int N) {
   input.sync_device();
   weight.sync_device();
 
-  rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
+  rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5);
   cutlass::rmsnorm({M, N}, output.device_ref(),
-                   input.device_ref(), weight.device_ref(), NULL);
+                   input.device_ref(), weight.device_ref(), NULL, (float)1e-5);
 
   output.sync_host();
 
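For context, the host reference being updated computes a row-wise RMS norm: each row is divided by the square root of its mean squared value (stabilized by epsilon) and then scaled by a per-column weight. The sketch below restates that computation as a standalone plain-C++ helper under those assumptions; the function name, float buffers, and loop structure are illustrative and are not the CUTLASS test code itself.

#include <cmath>
#include <vector>

// Standalone sketch of the row-wise RMS norm reference (illustrative, not the
// CUTLASS test code):
//   out[m][n] = in[m][n] / sqrt(mean(in[m][:]^2) + epsilon) * w[n]
void rmsnorm_reference(std::vector<float> &output,        // M x N, row-major
                       const std::vector<float> &input,   // M x N, row-major
                       const std::vector<float> &weight,  // N
                       int M, int N, float epsilon) {
  for (int m = 0; m < M; ++m) {
    // Accumulate the sum of squares of the row.
    float square_sum = 0.0f;
    for (int n = 0; n < N; ++n) {
      float x = input[m * N + n];
      square_sum += x * x;
    }
    float sq_mean = square_sum / static_cast<float>(N);
    // epsilon is now a parameter instead of a hard-coded literal.
    float sqrt_var = std::sqrt(sq_mean + epsilon);
    for (int n = 0; n < N; ++n) {
      output[m * N + n] = input[m * N + n] / sqrt_var * weight[n];
    }
  }
}

With the change above, the test exercises this path with epsilon = 1e-5 on both the host reference and the device kernel.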
tools/util/include/cutlass/util/device_rmsnorm.h (8 additions, 7 deletions)
@@ -43,7 +43,7 @@ namespace cutlass {
 
 __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
                                        const float4 *weight,
-                                       const int m, const int n) {
+                                       const int m, const int n, float epsilon) {
   const int m_idx = blockIdx.x;
   const int tid = threadIdx.x;
   const int bdimx = blockDim.x;
@@ -76,7 +76,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
     blockReduceSum<float, 1>(local_sums);
   }
   if (threadIdx.x == 0) {
-    s_mean = rsqrtf(local_sums[0] / n + 1e-6);
+    s_mean = rsqrtf(local_sums[0] / n + epsilon);
   }
   __syncthreads();
 
@@ -117,7 +117,8 @@ template<typename T>
 __global__ void rmsnorm_twoPassAlgo_e1(T* output,
                                        const T* input,
                                        const T* weight,
-                                       const int m, const int n)
+                                       const int m, const int n,
+                                       float epsilon)
 {
   const int m_idx = blockIdx.x;
   const int tid = threadIdx.x;
@@ -139,7 +140,7 @@ __global__ void rmsnorm_twoPassAlgo_e1(T* output,
     blockReduceSum<float, 1>(local_sums);
   }
   if (threadIdx.x == 0) {
-    s_mean = rsqrtf(local_sums[0] / n + 1e-6);
+    s_mean = rsqrtf(local_sums[0] / n + epsilon);
   }
   __syncthreads();
 
@@ -155,7 +156,7 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
              TensorRef<T, layout::RowMajor> ref_output,
              TensorRef<T, layout::RowMajor> ref_input,
              TensorRef<T, layout::RowMajor> ref_weight,
-             cudaStream_t stream){
+             cudaStream_t stream, float epsilon = 1e-5){
   const int m = tensor_size.row();
   const int n = tensor_size.column();
   T* output = ref_output.data();
@@ -167,12 +168,12 @@ void rmsnorm(cutlass::MatrixCoord tensor_size,
     dim3 block(min(1024, (n / 8 + 31) / 32 * 32));
 
     rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
-      (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n);
+      (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon);
   } else {
     dim3 block(min(1024, ((n + 31)/32 + 31)/32*32));
 
     rmsnorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
-      output, input, weight, m, n);
+      output, input, weight, m, n, epsilon);
   }
 
   auto result = cudaGetLastError();
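With the new trailing parameter, callers of cutlass::rmsnorm can choose the stabilizer instead of relying on the previously hard-coded 1e-6; omitting the argument uses the 1e-5 default declared in the signature above. The call-site sketch below assumes half-precision row-major tensors and a 1xN weight, as in the unit test; the helper function and extents are illustrative only, not part of the commit.

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_rmsnorm.h"
#include "cutlass/util/host_tensor.h"

// Illustrative call site for the updated API: epsilon travels as the last
// argument of cutlass::rmsnorm.
void rmsnorm_example(int M, int N, cudaStream_t stream) {
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> input({M, N});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> weight({1, N});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> output({M, N});

  // ... fill input and weight on the host ...
  input.sync_device();
  weight.sync_device();

  // Explicit epsilon, mirroring the updated unit test:
  cutlass::rmsnorm({M, N}, output.device_ref(),
                   input.device_ref(), weight.device_ref(), stream,
                   /*epsilon=*/1e-5f);

  // Or rely on the default (epsilon = 1e-5):
  // cutlass::rmsnorm({M, N}, output.device_ref(),
  //                  input.device_ref(), weight.device_ref(), stream);

  output.sync_host();
}

Because the kernels previously baked in 1e-6, code that depended on that exact value should now pass it explicitly.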
