@@ -6,7 +6,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
66
77 const int nrows = item_ct1.get_group_range (2 );
88 const int nchannels = item_ct1.get_group_range (1 );
9- const int nsamples = item_ct1.get_group_range (0 );
109
1110 const int nthreads = item_ct1.get_local_range (2 );
1211 const int sample = item_ct1.get_group (0 );
@@ -16,8 +15,8 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
1615 const int tid = item_ct1.get_local_id (2 );
1716 const int nwarps = nthreads / WARP_SIZE;
1817
19- const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { stride_sample, stride_channel, stride_row}, {sample, channel, row});
20- const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
18+ const auto strided_offset = calculate_offset<3 >({stride_sample, stride_channel, stride_row}, {sample, channel, row});
19+ const auto packed_offset = calculate_offset<3 >({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
2120
2221 x += strided_offset;
2322 dst += packed_offset;
@@ -150,7 +149,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
150149
151150 const int nrows = item_ct1.get_group_range (2 );
152151 const int nchannels = item_ct1.get_group_range (1 );
153- const int nsamples = item_ct1.get_group_range (0 );
154152
155153 const int sample = item_ct1.get_group (0 );
156154 const int channel = item_ct1.get_group (1 );
@@ -161,8 +159,8 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
161159 const int tid = item_ct1.get_local_id (2 );
162160 const int nwarps = nthreads / WARP_SIZE;
163161
164- const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { stride_sample, stride_channel, stride_row}, {sample, channel, row});
165- const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
162+ const auto strided_offset = calculate_offset<3 >({stride_sample, stride_channel, stride_row}, {sample, channel, row});
163+ const auto packed_offset = calculate_offset<3 >({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
166164
167165 x += strided_offset;
168166 dst += packed_offset;
0 commit comments