Skip to content

Commit 8cba478

Browse files
ikawrakow and Iwan Kawrakow authored
iqk_mul_mat: better strategy when nrc_y not divisible by ny (#71)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent fd20638 commit 8cba478

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -107,16 +107,39 @@ struct MulMat {
107107
while (!funcs[ny-1] && ny > 0) --ny;
108108
int n_step = (nrc_y - info.cur_y)/ny;
109109
if (n_step > 0) {
110-
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
111-
auto this_info = info;
112-
this_info.s += ix;
113-
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
114-
for (int iy = 0; iy < n_step; ++iy) {
115-
funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
116-
this_info.cur_y += ny;
110+
if (n_step*ny != nrc_y) {
111+
++n_step;
112+
int ny1 = nrc_y/n_step;
113+
int ny2 = ny1 + 1;
114+
int my1 = n_step*ny2 - nrc_y;
115+
int my2 = n_step - my1;
116+
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
117+
auto this_info = info;
118+
this_info.s += ix;
119+
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
120+
for (int iy = 0; iy < my1; ++iy) {
121+
funcs[ny1-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
122+
this_info.cur_y += ny1;
123+
}
124+
for (int iy = 0; iy < my2; ++iy) {
125+
funcs[ny2-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
126+
this_info.cur_y += ny2;
127+
}
128+
}
129+
info.cur_y += nrc_y;
130+
}
131+
else {
132+
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
133+
auto this_info = info;
134+
this_info.s += ix;
135+
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
136+
for (int iy = 0; iy < n_step; ++iy) {
137+
funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
138+
this_info.cur_y += ny;
139+
}
117140
}
141+
info.cur_y += ny * n_step;
118142
}
119-
info.cur_y += ny * n_step;
120143
}
121144
int n_left = nrc_y - info.cur_y;
122145
if (n_left > 0) {

0 commit comments

Comments
 (0)