Skip to content

Commit 77d586f

Browse files
committed
sycl : try to fix after IQ1_S changes
1 parent caa106d commit 77d586f

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

ggml-sycl.cpp

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3514,8 +3514,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
35143514
#define QI1_S (QK_K / (4*QR1_S))
35153515
typedef struct {
35163516
sycl::half d;
3517-
uint8_t qs[QK_K/8];
3518-
uint8_t scales[QK_K/16];
3517+
uint8_t qs[QK_K/8];
3518+
uint16_t qh[QK_K/32];
35193519
} block_iq1_s;
35203520
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
35213521

@@ -4894,7 +4894,6 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
48944894
const uint64_t *iq1s_grid,
48954895
const uint8_t *ksigns_iq2xs,
48964896
const uint8_t *kmask_iq2xs) {
4897-
48984897
const int i = item_ct1.get_group(2);
48994898
const block_iq1_s * x = (const block_iq1_s *) vx;
49004899

@@ -4903,11 +4902,15 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
49034902
const int il = tid/8; // 0...3
49044903
const int ib = tid%8; // 0...7
49054904
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4906-
const int i8 = 4*ib+il;
4907-
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
4908-
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
4909-
const float d = (float)x[i].d * (2*(h & 7) + 1);
4910-
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
4905+
const uint8_t * qs = x[i].qs + 8*ib;
4906+
const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
4907+
const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
4908+
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
4909+
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
4910+
for (int j = 0; j < 4; ++j) {
4911+
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4912+
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4913+
}
49114914
#else
49124915
assert(false);
49134916
#endif
@@ -7808,23 +7811,22 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
78087811
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
78097812

78107813
const int ib32 = iqs;
7811-
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
7812-
const uint8_t h1 = bq1->scales[2*ib32+0];
7813-
const uint8_t h2 = bq1->scales[2*ib32+1];
7814-
const int * q8 = (const int *)bq8_1[ib32].qs;
7815-
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
7816-
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
7817-
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
7818-
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
7819-
for (int j = 0; j < 2; ++j) {
7820-
sumi1 = dpct::dp4a(q8[j+0], grid1[j], sumi1);
7821-
sumi2 = dpct::dp4a(q8[j+2], grid2[j], sumi2);
7822-
sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3);
7823-
sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4);
7824-
}
7825-
const float d = (float)bq1->d * bq8_1[ib32].ds[0];
7826-
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
7827-
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
7814+
const uint8_t * qs = bq1->qs + 4*ib32;
7815+
const int8_t * q8 = bq8_1[ib32].qs;
7816+
int sumi = 0;
7817+
for (int l = 0; l < 4; ++l) {
7818+
const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
7819+
const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
7820+
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7821+
grid[0] ^ signs[0], signs[0], std::minus<>());
7822+
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7823+
grid[1] ^ signs[1], signs[1], std::minus<>());
7824+
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
7825+
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
7826+
q8 += 8;
7827+
}
7828+
const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
7829+
return d * sumi;
78287830
#else
78297831
assert(false);
78307832
return 0.f;

0 commit comments

Comments
 (0)