Skip to content

Commit d518bdb

Browse files
rgerganov authored and ggerganov committed
ggml : add ggml_set_rows (ggml-org#14274)
* ggml : add ggml_set_rows

  Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'.

  ref: ggml-org#8366

* use I64 for indices
* ggml : add repeat impl for i64
* ggml : add ggml_is_contiguous_rows
* ggml : ggml_set_rows support broadcast
* ggml : ggml_set_rows support quantized dst

  ggml-ci

* ggml : support GGML_TYPE_F32 ".from_float" trait
* ggml : ggml_set_rows update comment + better index name
* tests : add ggml_set_rows
* metal : add ggml_set_rows implementation

  ggml-ci

* ggml : simplify forward_dup_f32
* ggml : fix supports_op
* tests : add comment to set_rows
* ggml : leave the repeat_i64 for a separate PR

  ggml-ci

* ggml : set_rows use std::min instead of MIN
* ggml : better error message for set_rows unsupported type
* metal : perform op->type check only once
* tests : more consistent implementation + more tests

  ggml-ci

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent b3c74fe commit d518bdb

File tree

2 files changed

+3
-47
lines changed

2 files changed

+3
-47
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
138138
}
139139

140140
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
141-
#pragma METAL fp math_mode(safe)
142141
float min = FLT_MAX;
143142
float max = -FLT_MAX;
144143

@@ -204,7 +203,6 @@ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
204203
}
205204

206205
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
207-
#pragma METAL fp math_mode(safe)
208206
float max = src[0];
209207
float min = src[0];
210208

@@ -241,7 +239,6 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
241239
}
242240

243241
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
244-
#pragma METAL fp math_mode(safe)
245242
float amax = 0.0f; // absolute max
246243
float max = 0.0f;
247244

@@ -4736,49 +4733,8 @@ kernel void kernel_cpy_f32_q5_1(
47364733
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
47374734
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
47384735

4739-
float max = src[0];
4740-
float min = src[0];
4741-
4742-
for (int j = 1; j < QK5_1; j++) {
4743-
const float v = src[j];
4744-
min = v < min ? v : min;
4745-
max = v > max ? v : max;
4746-
}
4747-
4748-
const float d = (max - min) / 31;
4749-
const float id = d ? 1.0f/d : 0.0f;
4750-
4751-
dst_data[i00/QK5_1].d = d;
4752-
dst_data[i00/QK5_1].m = min;
4753-
4754-
uint32_t qh = 0;
4755-
for (int j = 0; j < QK5_1/2; ++j) {
4756-
const float x0 = (src[0 + j] - min)*id;
4757-
const float x1 = (src[QK5_1/2 + j] - min)*id;
4758-
4759-
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
4760-
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
4761-
4762-
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4763-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4764-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
4765-
}
4766-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4767-
for (int j = 0; j < 4; ++j) {
4768-
dst_data[i00/QK5_1].qh[j] = qh8[j];
4769-
}
4770-
}
4771-
}
4772-
4773-
static inline int best_index_int8(int n, constant float * val, float x) {
4774-
if (x <= val[0]) return 0;
4775-
if (x >= val[n-1]) return n-1;
4776-
int ml = 0, mu = n-1;
4777-
while (mu-ml > 1) {
4778-
int mav = (ml+mu)/2;
4779-
if (x < val[mav]) mu = mav; else ml = mav;
4736+
quantize_q5_1(src, dst_data[i00/QK5_1]);
47804737
}
4781-
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
47824738
}
47834739

47844740
kernel void kernel_cpy_f32_iq4_nl(

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
10031003
"GLU",
10041004
};
10051005

1006-
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
1006+
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
10071007

10081008
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10091009
"none",
@@ -1103,7 +1103,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
11031103
"glu(x)",
11041104
};
11051105

1106-
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
1106+
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
11071107

11081108
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
11091109

0 commit comments

Comments
 (0)