11#include " norm.hpp"
2+ #include " ggml-sycl/common.hpp"
23
34static void norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
45 const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
56
67 const int nrows = item_ct1.get_group_range (2 );
78 const int nchannels = item_ct1.get_group_range (1 );
9+ const int nsamples = item_ct1.get_group_range (0 );
10+
811 const int nthreads = item_ct1.get_local_range (2 );
912 const int sample = item_ct1.get_group (0 );
1013 const int channel = item_ct1.get_group (1 );
@@ -13,8 +16,11 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
1316 const int tid = item_ct1.get_local_id (2 );
1417 const int nwarps = nthreads / WARP_SIZE;
1518
16- x += sample * stride_sample + channel * stride_channel + row * stride_row;
17- dst += ((sample * nchannels + channel) * nrows + row) * ncols;
19+ const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
20+ const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
21+
22+ x += strided_offset;
23+ dst += packed_offset;
1824
1925 sycl::float2 mean_var = sycl::float2 (0 .f , 0 .f );
2026
@@ -144,16 +150,22 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
144150
145151 const int nrows = item_ct1.get_group_range (2 );
146152 const int nchannels = item_ct1.get_group_range (1 );
153+ const int nsamples = item_ct1.get_group_range (0 );
154+
147155 const int sample = item_ct1.get_group (0 );
148156 const int channel = item_ct1.get_group (1 );
149157 const int row = item_ct1.get_group (2 );
158+
150159 const int nthreads = item_ct1.get_local_range (2 );
151160
152161 const int tid = item_ct1.get_local_id (2 );
153162 const int nwarps = nthreads / WARP_SIZE;
154163
155- x += sample*stride_sample + channel*stride_channel + row*stride_row;
156- dst += ((sample*nchannels + channel)*nrows + row)*ncols;
164+ const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
165+ const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
166+
167+ x += strided_offset;
168+ dst += packed_offset;
157169
158170
159171 float tmp = 0 .0f ; // partial sum for thread in warp
0 commit comments