@@ -880,7 +880,7 @@ struct vk_op_conv2d_push_constants {
     uint32_t Cout;
     uint32_t Cin;
     uint32_t N;
-
+
     uint32_t KW;
     uint32_t KH;
     uint32_t W;
@@ -1041,7 +1041,7 @@ class vk_perf_logger {
         }
 
         timings.clear();
-        flops.clear();
+        flops.clear();
     }
 
     void log_timing(const ggml_tensor * node, uint64_t time) {
@@ -1082,7 +1082,7 @@ class vk_perf_logger {
             flops[name].push_back(n_flops);
             timings[name].push_back(time);
             return;
-        }
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -2190,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             }
             compile_count++;
         }
+
         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
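Note: the hunk above sits inside the pipeline-compilation helper, which caps how many shader compiles run concurrently and collects the std::async futures in `compiles` so they can be joined later. A minimal, self-contained sketch of that bounded-concurrency pattern (names, the limit of 8, and the condition-variable throttle are illustrative assumptions, not the actual ggml-vulkan internals):

#include <condition_variable>
#include <cstdint>
#include <future>
#include <mutex>
#include <vector>

static std::mutex              compile_count_mutex;
static std::condition_variable compile_count_cond;
static uint32_t                compile_count = 0;

// Stand-in for the expensive vkCreateComputePipelines() call.
static void compile_one(int id) { (void) id; }

int main() {
    const uint32_t max_parallel = 8;  // assumed cap on in-flight compiles
    std::vector<std::future<void>> compiles;
    for (int i = 0; i < 64; ++i) {
        {
            // Block until a compile slot is free, then claim it.
            std::unique_lock<std::mutex> lock(compile_count_mutex);
            compile_count_cond.wait(lock, [&] { return compile_count < max_parallel; });
            compile_count++;
        }
        compiles.push_back(std::async(std::launch::async, [i] {
            compile_one(i);
            {
                std::lock_guard<std::mutex> lock(compile_count_mutex);
                compile_count--;
            }
            compile_count_cond.notify_one();
        }));
    }
    for (auto & f : compiles) {
        f.get();  // join and propagate any exception from the worker
    }
}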
@@ -3037,14 +3038,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t conv2d_WG_SIZE = 256;
     uint32_t conv2d_BS_K = 128;
     uint32_t conv2d_BS_CRS = 16;
+    // Enables subgroup ops to prevent re-calculation of indices.
+    uint32_t use_collectives = 0;
+    // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+    if (device->subgroup_shuffle) {
+        use_collectives = 1;
+        conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+    }
     uint32_t conv2d_BS_NPQ = 128;
     uint32_t conv2d_TS_K = 8;
     uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float);
     if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){
         conv2d_BS_CRS = 8;
-        conv2d_TS_K = 8;
-    }
-    ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K}, 1);
+        if (device->subgroup_shuffle) {
+            conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+        }
+    }
+    if (device->subgroup_shuffle) {
+        ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true);
+    }
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
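Note on the tile sizing above: with the defaults BS_K = 128, BS_CRS = 16, BS_NPQ = 128, the two staging tiles need (128*17 + 16*129)*4 = 16960 bytes of shared memory; a device whose maxComputeSharedMemorySize is smaller falls back to BS_CRS = 8 (8736 bytes), and the subgroup-size cap is re-applied so a CRS block never exceeds one subgroup when shuffles are used. A small sketch of that arithmetic (the subgroup size of 32 is an assumed example value, and reading the +1 as bank-conflict padding is an assumption):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Shared-memory bytes for one workgroup: a BS_K x (BS_CRS+1) kernel tile plus a
// BS_CRS x (BS_NPQ+1) input tile, both of floats (the +1 is presumably padding).
static uint32_t shmem_bytes(uint32_t bs_k, uint32_t bs_crs, uint32_t bs_npq) {
    return (bs_k * (bs_crs + 1) + bs_crs * (bs_npq + 1)) * sizeof(float);
}

int main() {
    uint32_t bs_k = 128, bs_crs = 16, bs_npq = 128;
    uint32_t subgroup_size = 32;               // assumed device value
    bs_crs = std::min(subgroup_size, bs_crs);  // cap at subgroup size when shuffle is used
    std::printf("default : %u bytes\n", shmem_bytes(bs_k, bs_crs, bs_npq));  // 16960
    std::printf("fallback: %u bytes\n", shmem_bytes(bs_k, 8, bs_npq));       // 8736
}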
@@ -6895,11 +6909,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
-        if (src0->type == GGML_TYPE_F32 &&
-            src1->type == GGML_TYPE_F32 &&
-            dst->type == GGML_TYPE_F32 &&
-            ggml_is_contiguous(src0) &&
-            ggml_is_contiguous(src1) &&
+        if (src0->type == GGML_TYPE_F32 &&
+            src1->type == GGML_TYPE_F32 &&
+            dst->type == GGML_TYPE_F32 &&
+            ggml_is_contiguous(src0) &&
+            ggml_is_contiguous(src1) &&
             ggml_is_contiguous(dst)) {
             return ctx->device->pipeline_conv2d_f32;
         }
@@ -7231,7 +7245,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             // src0 - kernel: [KW, KH, Cin, Cout]
             // src1 - input:  [W, H, Cin, N]
             // dst - result:  [OW, OH, Cout, N]
-
+
             // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
             auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
                 return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
@@ -7246,9 +7260,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
             int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
             int64_t NPQ = N*OW*OH;
-
+
             // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
-            elements = {static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1};
+            elements = {static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1};
         } break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
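Note: the output extent follows the standard convolution formula OW = (W + 2*p0 - d0*(KW - 1) - 1) / s0 + 1, and the Cout x NPQ output matrix is then presumably divided by the {BS_K, BS_NPQ, 1} workgroup denominators set at pipeline-creation time, matching the "Tile output matrix" comment above. A worked example under assumed shapes (all values illustrative, not taken from the PR):

#include <cstdint>
#include <cstdio>

// Same formula as the lambda in the hunk above.
static int64_t calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

int main() {
    const int64_t N = 1, Cout = 256, W = 64, H = 64, KW = 3, KH = 3;
    const int s = 1, p = 1, d = 1;  // stride, padding, dilation
    const int64_t OW  = calc_conv_output_size(W, KW, s, p, d);  // (64 + 2 - 2 - 1) / 1 + 1 = 64
    const int64_t OH  = calc_conv_output_size(H, KH, s, p, d);  // 64
    const int64_t NPQ = N * OW * OH;                            // 4096 output pixels
    // One workgroup per BS_K x BS_NPQ tile of the Cout x NPQ output matrix.
    const int64_t BS_K = 128, BS_NPQ = 128;
    std::printf("grid: %lld x %lld workgroups\n",
                (long long) ((Cout + BS_K - 1) / BS_K),
                (long long) ((NPQ + BS_NPQ - 1) / BS_NPQ));     // 2 x 32
}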
@@ -8131,14 +8145,14 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     p.Cout = static_cast<uint32_t>(ne03);
     p.Cin = static_cast<uint32_t>(ne02);
     p.N = static_cast<uint32_t>(ne13);
-
+
     p.KW = static_cast<uint32_t>(ne00);
     p.KH = static_cast<uint32_t>(ne01);
     p.W = static_cast<uint32_t>(ne10);
     p.H = static_cast<uint32_t>(ne11);
     p.OW = static_cast<uint32_t>(ne0);
     p.OH = static_cast<uint32_t>(ne1);
-
+
     p.s0 = static_cast<uint32_t>(dst->op_params[0]);
     p.s1 = static_cast<uint32_t>(dst->op_params[1]);
     p.p0 = static_cast<uint32_t>(dst->op_params[2]);
@@ -8162,7 +8176,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     GGML_ASSERT(ne02 == ne12);
 
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
-
+
 }
 
 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -10805,11 +10819,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
     case GGML_OP_CONV_2D:
         // Channel-contiguous format is not supported yet.
-        return (op->src[0]->type == GGML_TYPE_F32 &&
-                op->src[1]->type == GGML_TYPE_F32 &&
-                op->type == GGML_TYPE_F32 &&
-                ggml_is_contiguous(op->src[0]) &&
-                ggml_is_contiguous(op->src[1]) &&
+        return (op->src[0]->type == GGML_TYPE_F32 &&
+                op->src[1]->type == GGML_TYPE_F32 &&
+                op->type == GGML_TYPE_F32 &&
+                ggml_is_contiguous(op->src[0]) &&
+                ggml_is_contiguous(op->src[1]) &&
                 ggml_is_contiguous(op));
     default:
         return false;