@@ -251,6 +251,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -494,6 +496,10 @@ struct vk_op_rope_push_constants {
     float corr_dims[2];
     float theta_scale;
     uint32_t has_ff;
+    uint32_t ne02;
+    uint32_t s1;
+    uint32_t s2;
+    int32_t sections[4];
 };
 
 struct vk_op_soft_max_push_constants {
@@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
@@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
             if (is_neox) {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
                     return ctx->device->pipeline_rope_neox_f16;
                 }
+            } else if (is_mrope && !is_vision) {
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_rope_multi_f32;
+                }
+                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_rope_multi_f16;
+                }
+            } else if (is_vision) {
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_rope_vision_f32;
+                }
+                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_rope_vision_f16;
+                }
             } else {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
                     return ctx->device->pipeline_rope_norm_f32;
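Aside on the dispatch above (not part of the diff): is_vision is compared with == rather than tested as a bit because, in ggml.h, GGML_ROPE_TYPE_VISION also carries the GGML_ROPE_TYPE_MROPE bit, so the mrope branch has to exclude vision explicitly. The flag values in the following minimal sketch are assumptions taken from ggml.h rather than from this excerpt:

    #include <cstdio>

    // Assumed flag values, taken from ggml.h (not part of this excerpt).
    constexpr int GGML_ROPE_TYPE_NEOX   = 2;
    constexpr int GGML_ROPE_TYPE_MROPE  = 8;
    constexpr int GGML_ROPE_TYPE_VISION = GGML_ROPE_TYPE_MROPE | 16; // 24: vision also sets the mrope bit

    int main() {
        const int modes[] = {0, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION};
        for (int mode : modes) {
            const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
            const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
            // Same ordering as the dispatch above: the mrope branch must exclude
            // vision explicitly, because the vision mode also carries the mrope bit.
            const char * kind = is_neox                  ? "rope_neox"
                              : (is_mrope && !is_vision) ? "rope_multi"
                              : is_vision                ? "rope_vision"
                              :                            "rope_norm";
            std::printf("mode=%2d -> %s\n", mode, kind);
        }
        return 0;
    }

Built with any C++11 compiler, this prints rope_norm, rope_neox, rope_multi and rope_vision for modes 0, 2, 8 and 24, matching the pipelines selected above.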
@@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
    case GGML_OP_CLAMP:
    case GGML_OP_PAD:
    case GGML_OP_REPEAT:
+   case GGML_OP_ROPE:
        return true;
    default:
        return false;
@@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
-    // const int mode = ((int32_t *) dst->op_params)[2];
+    const int mode = ((int32_t *) dst->op_params)[2];
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
     const float freq_base = ((float *) dst->op_params)[5];
@@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     const float attn_factor = ((float *) dst->op_params)[8];
     const float beta_fast = ((float *) dst->op_params)[9];
     const float beta_slow = ((float *) dst->op_params)[10];
+    int sections[4] {};
+    if (mode & GGML_ROPE_TYPE_MROPE) {
+        memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    }
 
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
+    uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
+    uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
+
     ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        src2 != nullptr,
+        src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
+        sections[0], sections[1], sections[2], sections[3],
     }, dryrun);
 }
 
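Aside (not part of the diff): combined with the push-constant hunk earlier, the initializer above fills a block that looks roughly like the sketch below. The names of the leading members are assumptions inferred from the argument order, since they are not shown in this excerpt; only the last four members are new in this change. The s1/s2 element strides are what let the shader address non-contiguous inputs, which is also why GGML_OP_ROPE is added in the ggml_vk_op_supports_incontiguous hunk above.

    #include <cstdint>

    // Sketch only: the leading field names are assumed from the initializer order above.
    struct vk_op_rope_push_constants {
        uint32_t ncols;          // (uint32_t)src0->ne[0]
        uint32_t n_dims;
        float    freq_scale;
        uint32_t p_delta_rows;   // (uint32_t)src0->ne[1]
        float    freq_base;
        float    ext_factor;
        float    attn_factor;
        float    corr_dims[2];
        float    theta_scale;
        uint32_t has_ff;         // src2 != nullptr (frequency-factors tensor present)
        uint32_t ne02;           // (uint32_t)src0->ne[2]
        uint32_t s1;             // src0->nb[1] / ggml_type_size(src0->type), stride in elements
        uint32_t s2;             // src0->nb[2] / ggml_type_size(src0->type), stride in elements
        int32_t  sections[4];    // mrope section sizes; left zeroed for non-mrope modes
    };

The sections array is only populated for mrope (the memcpy above); for the other modes it stays zeroed, so the norm and neox shaders can ignore it.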
@@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_REPEAT:
            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
        case GGML_OP_ROPE:
-           {
-               const int mode = ((const int32_t *) op->op_params)[2];
-               if (mode & GGML_ROPE_TYPE_MROPE) {
-                   return false;
-               }
-               if (mode & GGML_ROPE_TYPE_VISION) {
-                   return false;
-               }
-               return ggml_is_contiguous(op->src[0]);
-           }
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
@@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float attn_factor = ((float *) tensor->op_params)[8];
         const float beta_fast = ((float *) tensor->op_params)[9];
         const float beta_slow = ((float *) tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        if (mode & GGML_ROPE_TYPE_MROPE) {
+            int32_t *sections = ((int32_t *) tensor->op_params) + 11;
+            tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        } else {
+            tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
             case GGML_UNARY_OP_SILU: