@@ -321,6 +321,7 @@ struct ggml_backend_opencl_context {
     cl_program program_upscale;
     cl_program program_concat;
     cl_program program_tsembd;
+    cl_program program_mul_mv_id_q4_0_f32_8x_flat;

     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
@@ -366,6 +367,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;

 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // Transpose kernels
@@ -1112,7 +1114,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }

-    // repeat
+    // repeat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
         const std::string kernel_src {
@@ -1256,6 +1258,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         }
     }

+    // mul_mv_id_q4_0_f32_8x_flat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_id_q4_0_f32_8x_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_id_q4_0_f32_8x_flat.cl");
+#endif
+        backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, "kernel_mul_mv_id_q4_0_f32_8x_flat", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // Adreno kernels
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // transpose
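Note on the embed path: with GGML_OPENCL_EMBED_KERNELS defined, the `#include "mul_mv_id_q4_0_f32_8x_flat.cl.h"` line must expand to a single string literal that brace-initializes `kernel_src`, so the kernel source ships inside the binary and no file read happens at runtime. The generated header is not part of this diff; as a rough sketch only (assuming the build wraps the .cl source in a raw string literal, which is an assumption here, not something this PR shows):

    // Hypothetical contents of mul_mv_id_q4_0_f32_8x_flat.cl.h (illustration only;
    // the real header is generated by the build from mul_mv_id_q4_0_f32_8x_flat.cl)
    R"(
    kernel void kernel_mul_mv_id_q4_0_f32_8x_flat(/* ... kernel parameters ... */) {
        // ... kernel body ...
    }
    )"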
@@ -2178,6 +2196,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             }
             return false;
+        case GGML_OP_MUL_MAT_ID:
+            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+                if (op->src[1]->type == GGML_TYPE_F32) {
+                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
+                }
+            }
+            return false;
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
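GGML_OP_MUL_MAT_ID is the indirect (mixture-of-experts) matmul: src0 holds one weight matrix per expert, src1 the activations, and a third source of I32 indices selects which expert multiplies each row; the backend below reads those indices from dst->src[2]. A minimal graph-side sketch, assuming the current four-argument ggml_mul_mat_id API and an existing `ggml_context * ctx`; names and sizes are illustrative, not taken from this PR:

    // Sketch only: building a MUL_MAT_ID node that this backend path would execute.
    const int n_embd = 2048, n_ff = 8192, n_expert = 8, n_expert_used = 2, n_tokens = 4;
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_Q4_0, n_embd, n_ff, n_expert);  // expert weights
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  n_embd, 1, n_tokens);     // activations
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,  n_expert_used, n_tokens); // expert ids per token
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids); // ids becomes dst->src[2] in the backend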
@@ -5536,6 +5561,136 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     }
 }

+static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const ggml_tensor * src2 = dst->src[2];
+    GGML_ASSERT(src2);
+    GGML_ASSERT(src2->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+#ifdef GGML_OPENCL_SOA_Q
+    ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
+#endif
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb02 = src0->nb[2];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+
+    const int ne20 = src2->ne[0];
+    const int ne21 = src2->ne[1];
+
+    const cl_ulong nb21 = src2->nb[1];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+
+    const int r2 = ne12/ne02;
+    const int r3 = ne13/ne03;
+    const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows
+
+    GGML_ASSERT(ne00 == ne10);
+
+    int sgs   = 32; // subgroup size
+    int nsg   = 1;  // number of subgroups
+    int nrows = 1;  // number of rows in src1
+    int ndst  = 4;  // number of values produced by each subgroup
+
+    cl_kernel kernel;
+
+    // subgroup mat vec
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0: {
+            kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                sgs = 16;
+                nsg = 1;
+                ndst = 8;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                sgs = 64;
+                nsg = 1;
+                ndst = 8;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne20));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne21));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &r3));
+
+            break;
+        }
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    int _ne1 = 1;
+    int ne123 = dst_rows;
+
+    size_t global_work_size[] = {(size_t)(ne01+ndst*nsg-1)/(ndst*nsg)*sgs, (size_t)(_ne1+nrows-1)/nrows*nsg, (size_t)ne123};
+    size_t local_work_size[] = {(size_t)sgs, (size_t)nsg, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
 static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
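A quick worked example of the launch geometry above, using illustrative numbers rather than anything from this PR: take the Adreno settings (sgs = 64, nsg = 1, ndst = 8), ne01 = 4096 output rows per expert, and an ids tensor with ne20 = 2 experts used for each of ne21 = 8 rows of src1:

    // global_work_size[0] = ceil(ne01 / (ndst*nsg)) * sgs = ceil(4096 / 8) * 64 = 32768
    // global_work_size[1] = ceil(_ne1 / nrows) * nsg      = 1
    // global_work_size[2] = ne20 * ne21                   = 2 * 8 = 16
    // local_work_size     = {64, 1, 1}

So each work-group is a single 64-wide subgroup producing ndst = 8 output values along ne01, and the third NDRange dimension enumerates the dst_rows = ne20*ne21 (expert, row) pairs selected by src2.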
@@ -6444,6 +6599,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_mul_mat;
             break;
+        case GGML_OP_MUL_MAT_ID:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_mul_mat_id;
+            break;
         case GGML_OP_SCALE:
             if (!any_on_device) {
                 return false;
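For orientation, the clSetKernelArg calls in ggml_cl_mul_mat_id above imply a device-side parameter list along the following lines. This is a reconstruction for readability only; the actual kernel lives in mul_mv_id_q4_0_f32_8x_flat.cl, and its parameter names, pointer element types, and qualifiers may differ.

    // Inferred from the host-side argument setup; names and element types are illustrative.
    kernel void kernel_mul_mv_id_q4_0_f32_8x_flat(
            global uchar * src0_q,          // arg 0:  flattened Q4_0 quants (SoA)
            global half  * src0_d,          // arg 1:  flattened Q4_0 scales (SoA)
            global char  * src1,            // arg 2:  F32 activations
            ulong          offset1,         // arg 3
            global char  * src2,            // arg 4:  I32 expert ids
            ulong          offset2,         // arg 5
            global char  * dst,             // arg 6
            ulong          offsetd,         // arg 7
            int ne00, int ne01, int ne02,   // args 8-10
            ulong nb00, ulong nb02,         // args 11-12
            int ne10, int ne11, int ne12,   // args 13-15
            ulong nb11, ulong nb12,         // args 16-17
            int ne20, int ne21,             // args 18-19
            ulong nb21,                     // arg 20
            int ne0, int ne1,               // args 21-22
            int r2, int r3);                // args 23-24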