
Commit 7f9b659

* Apple/Win32 compile errors fixed
* Subgroup size used to determine tile size -> fixes llvmpipe errors.
1 parent 0715985 commit 7f9b659

5 files changed: +120 −77 lines changed

ggml/src/ggml-backend.cpp

Lines changed: 5 additions & 5 deletions
```diff
@@ -1883,12 +1883,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 }

 bool ggml_backend_compare_graph_backend_node(
-    ggml_backend_t backend1,
-    ggml_backend_t backend2,
-    struct ggml_cgraph * graph1,
-    struct ggml_cgraph * graph2,
+    ggml_backend_t backend1,
+    ggml_backend_t backend2,
+    struct ggml_cgraph * graph1,
+    struct ggml_cgraph * graph2,
     ggml_backend_eval_callback callback, void * user_data, char* op_name_out_1, char* op_name_out_2) {
-
+
     ggml_tensor * out1 = NULL;
     ggml_tensor * out2 = NULL;
```

(The removed and added lines here are textually identical, so this hunk is evidently a whitespace-only cleanup; the same applies to several hunks below.)
ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 36 additions & 22 deletions
```diff
@@ -880,7 +880,7 @@ struct vk_op_conv2d_push_constants {
     uint32_t Cout;
     uint32_t Cin;
     uint32_t N;
-
+
     uint32_t KW;
     uint32_t KH;
     uint32_t W;
@@ -1041,7 +1041,7 @@ class vk_perf_logger {
         }

         timings.clear();
-        flops.clear();
+        flops.clear();
     }

     void log_timing(const ggml_tensor * node, uint64_t time) {
@@ -1082,7 +1082,7 @@ class vk_perf_logger {
             flops[name].push_back(n_flops);
             timings[name].push_back(time);
             return;
-        }
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -2190,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             }
             compile_count++;
         }
+
         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
@@ -3037,14 +3038,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t conv2d_WG_SIZE = 256;
     uint32_t conv2d_BS_K = 128;
     uint32_t conv2d_BS_CRS = 16;
+    // Enables subgroup ops for preventing the re-calculation of indices.
+    uint32_t use_collectives = 0;
+    // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+    if (device->subgroup_shuffle) {
+        use_collectives = 1;
+        conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+    }
     uint32_t conv2d_BS_NPQ = 128;
     uint32_t conv2d_TS_K = 8;
     uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float);
     if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
         conv2d_BS_CRS = 8;
-        conv2d_TS_K = 8;
-    }
-    ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K}, 1);
+        if (device->subgroup_shuffle) {
+            conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+        }
+    }
+    if (device->subgroup_shuffle) {
+        ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true);
+    }

     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
```
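The hunk above is the core of the llvmpipe fix: `subgroupShuffle` can only exchange data between lanes of a single subgroup, so a CRS block of 16 is only valid on devices whose subgroup has at least 16 lanes. llvmpipe commonly reports a small subgroup size (8 on typical builds), so the previous fixed `BS_CRS = 16` ended up addressing lanes that do not exist. Below is a minimal standalone sketch of the same selection logic; the device limits (`subgroup_size`, `max_shmem`) are hypothetical stand-ins for the values that `ggml_vk_load_shaders` reads from `vk_device`:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t subgroup_size    = 8;          // e.g. llvmpipe; 32 on NVIDIA, 64 on AMD wave64
    const uint32_t max_shmem        = 48 * 1024;  // stand-in for maxComputeSharedMemorySize
    const bool     subgroup_shuffle = true;       // device supports GL_KHR_shader_subgroup_shuffle

    uint32_t BS_K = 128, BS_CRS = 16, BS_NPQ = 128;
    uint32_t use_collectives = 0;

    // Shuffles only work within one subgroup, so the CRS block must not
    // span more lanes than the subgroup has.
    if (subgroup_shuffle) {
        use_collectives = 1;
        BS_CRS = std::min(subgroup_size, BS_CRS);
    }

    // Shared memory holds one padded tile of A (BS_K x BS_CRS) and one of B
    // (BS_CRS x BS_NPQ); shrink the CRS block if the device cannot fit them.
    uint32_t shmem_req = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
    if (max_shmem < shmem_req) {
        BS_CRS = 8;
        if (subgroup_shuffle) {
            BS_CRS = std::min(subgroup_size, BS_CRS);
        }
    }

    printf("BS_CRS=%u use_collectives=%u\n", BS_CRS, use_collectives);
}
```

Judging by the parameter list of `ggml_vk_create_pipeline_func` in the earlier hunk (`..., disable_robustness, require_full_subgroups, required_subgroup_size`), the extra trailing `true` in the shuffle path plausibly maps to `require_full_subgroups`, which guarantees that every lane the shuffle consults actually exists.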
```diff
@@ -6895,11 +6909,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
-        if (src0->type == GGML_TYPE_F32 &&
-            src1->type == GGML_TYPE_F32 &&
-            dst->type == GGML_TYPE_F32 &&
-            ggml_is_contiguous(src0) &&
-            ggml_is_contiguous(src1) &&
+        if (src0->type == GGML_TYPE_F32 &&
+            src1->type == GGML_TYPE_F32 &&
+            dst->type == GGML_TYPE_F32 &&
+            ggml_is_contiguous(src0) &&
+            ggml_is_contiguous(src1) &&
             ggml_is_contiguous(dst)) {
             return ctx->device->pipeline_conv2d_f32;
         }
@@ -7231,7 +7245,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         // src0 - kernel: [KW, KH, Cin, Cout]
         // src1 - input:  [W, H, Cin, N]
         // dst  - result: [OW, OH, Cout, N]
-
+
         // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
         auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
             return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
@@ -7246,9 +7260,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
         int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
         int64_t NPQ = N*OW*OH;
-
+
         // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
-        elements = {static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1};
+        elements = {static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1};
     } break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
```
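The `calc_conv_output_size` lambda is the standard convolution output-extent formula, mirroring `ggml_calc_conv_output_size` from ggml.c. A quick worked example (the sizes here are hypothetical) shows how `OW`/`OH`, and from them the `NPQ = N*OW*OH` dimension of the implicit-GEMM output matrix, are obtained:

```cpp
#include <cstdint>
#include <cstdio>

// Same formula as the lambda above: output extent of a convolution given
// input size, kernel size, stride, padding, and dilation.
static int64_t calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

int main() {
    // 224x224 input, 3x3 kernel, stride 1, padding 1, dilation 1:
    // (224 + 2 - 2 - 1)/1 + 1 = 224, i.e. a "same"-size output.
    printf("%lld\n", (long long) calc_conv_output_size(224, 3, 1, 1, 1)); // 224

    // Stride 2 halves the spatial extent: (224 + 2 - 2 - 1)/2 + 1 = 112.
    printf("%lld\n", (long long) calc_conv_output_size(224, 3, 2, 1, 1)); // 112
}
```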
```diff
@@ -8131,14 +8145,14 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     p.Cout = static_cast<uint32_t>(ne03);
     p.Cin = static_cast<uint32_t>(ne02);
     p.N = static_cast<uint32_t>(ne13);
-
+
     p.KW = static_cast<uint32_t>(ne00);
     p.KH = static_cast<uint32_t>(ne01);
     p.W = static_cast<uint32_t>(ne10);
     p.H = static_cast<uint32_t>(ne11);
     p.OW = static_cast<uint32_t>(ne0);
     p.OH = static_cast<uint32_t>(ne1);
-
+
     p.s0 = static_cast<uint32_t>(dst->op_params[0]);
     p.s1 = static_cast<uint32_t>(dst->op_params[1]);
     p.p0 = static_cast<uint32_t>(dst->op_params[2]);
@@ -8162,7 +8176,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     GGML_ASSERT(ne02 == ne12);

     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
-
+
 }

 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -10805,11 +10819,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D:
             // Channel-contiguous format is not supported yet.
-            return (op->src[0]->type == GGML_TYPE_F32 &&
-                    op->src[1]->type == GGML_TYPE_F32 &&
-                    op->type == GGML_TYPE_F32 &&
-                    ggml_is_contiguous(op->src[0]) &&
-                    ggml_is_contiguous(op->src[1]) &&
+            return (op->src[0]->type == GGML_TYPE_F32 &&
+                    op->src[1]->type == GGML_TYPE_F32 &&
+                    op->type == GGML_TYPE_F32 &&
+                    ggml_is_contiguous(op->src[0]) &&
+                    ggml_is_contiguous(op->src[1]) &&
                     ggml_is_contiguous(op));
         default:
             return false;
```

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 60 additions & 32 deletions
```diff
@@ -1,7 +1,5 @@
 #version 450

-#define USE_COLLECTIVES
-
 #ifdef USE_COLLECTIVES
 #extension GL_KHR_shader_subgroup_shuffle: enable
 #endif
@@ -12,7 +10,7 @@
 #define SHMEM_PAD 0

 // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
-layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout]
+layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout]
 layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format
 layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; // dst - result: [OW, OH, Cout, N]
@@ -21,7 +19,7 @@ layout (push_constant) uniform parameter {
     uint32_t Cout;
     uint32_t Cin;
     uint32_t N;
-
+
     // Tensor spatial sizes: kernel, input, output
     uint32_t KW;
     uint32_t KH;
@@ -59,6 +57,7 @@ layout(constant_id = 2) const uint BS_CRS = 16;
 layout(constant_id = 3) const uint BS_NPQ = 128;
 // Thread-tile sizes
 layout(constant_id = 4) const uint TS_K = 8;
+layout(constant_id = 5) const uint use_collectives = 1;

 uint32_t tid = gl_LocalInvocationID.x;
 const uint32_t WG_SIZE = gl_WorkGroupSize.x;
@@ -122,31 +121,48 @@ uint32_t Br = tid / BS_NPQ;
 uint32_t Bc = tid % BS_NPQ;
 const uint32_t BrpWg = WG_SIZE / BS_NPQ;

-void main(){\
+void main(){
     for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
         for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
             regC[T_ly][T_lx] = 0.0;
         }
     }
-    /* Advance block in CRS dim */\
+    /* Advance block in CRS dim */
     for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){
+        uint32_t CRS_idx_a;
+        uint32_t Cin_idx_a;
+        uint32_t KH_idx_a;
+        uint32_t KW_idx_a;
+
 #ifdef USE_COLLECTIVES
-        uint32_t cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID;
-        uint32_t cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH);
-        uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH);
-        uint32_t cached_KH_idx = cached_CRS_remainder / p.KW;
-        uint32_t cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW;
-
-        uint32_t CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
-        uint32_t Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
-        uint32_t KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
-        uint32_t KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
+        uint32_t cached_CRS_idx;
+        uint32_t cached_Cin_idx;
+        uint32_t cached_KH_idx;
+        uint32_t cached_KW_idx;
+        if(use_collectives == 1){
+            cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID;
+            cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH);
+            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH);
+            cached_KH_idx = cached_CRS_remainder / p.KW;
+            cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW;
+
+            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
+            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
+            KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
+            KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
+        }else{
+            CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+            Cin_idx_a = CRS_idx_a / (p.KW*p.KH);
+            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH;
+            KH_idx_a = CRS_remainder / p.KW;
+            KW_idx_a = CRS_remainder - KH_idx_a*p.KW;
+        }
 #else
-        uint32_t CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A)
-        uint32_t Cin_idx_a = CRS_idx_a / (p.KW*p.KH);
-        uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH;
-        uint32_t KH_idx_a = CRS_remainder / p.KW;
-        uint32_t KW_idx_a = CRS_remainder - KH_idx_a*p.KW;
+        CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+        Cin_idx_a = CRS_idx_a / (p.KW*p.KH);
+        uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH;
+        KH_idx_a = CRS_remainder / p.KW;
+        KW_idx_a = CRS_remainder - KH_idx_a*p.KW;
 #endif

         /* Load kernel to A_block: (BS_K x BS_CRS)*/
@@ -170,20 +186,32 @@ void main(){
             uint32_t NPQ_remainder = NPQ_idx - N_idx*p.OH*p.OW;
             uint32_t OH_idx = NPQ_remainder / p.OW;
             uint32_t OW_idx = NPQ_remainder - OH_idx*p.OW;
-
+
+            uint32_t CRS_idx_b;
+            uint32_t Cin_idx_b;
+            uint32_t KH_idx_b;
+            uint32_t KW_idx_b;
 #ifdef USE_COLLECTIVES
-            uint32_t CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
-            uint32_t Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
-            uint32_t KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
-            uint32_t KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
+            if(use_collectives == 1){
+                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
+                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
+                KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
+                KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
+            }else{
+                CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */
+                Cin_idx_b = CRS_idx_b / (p.KW*p.KH);
+                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH;
+                KH_idx_b = CRS_remainder / p.KW;
+                KW_idx_b = CRS_remainder - KH_idx_b*p.KW;
+            }
 #else
-            uint32_t CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */
-            uint32_t Cin_idx_b = CRS_idx_b / (p.KW*p.KH);
+            CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */
+            Cin_idx_b = CRS_idx_b / (p.KW*p.KH);
             uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH;
-            uint32_t KH_idx_b = CRS_remainder / p.KW;
-            uint32_t KW_idx_b = CRS_remainder - KH_idx_b*p.KW;
+            KH_idx_b = CRS_remainder / p.KW;
+            KW_idx_b = CRS_remainder - KH_idx_b*p.KW;
 #endif
-
+
             uint32_t H_idx = OH_idx*p.s1 + KH_idx_b*p.d1 - p.p1;
             uint32_t W_idx = OW_idx*p.s0 + KW_idx_b*p.d0 - p.p0;
             uint32_t src_idx = min(max(W_idx + H_idx*p.nb11 + Cin_idx_b*p.nb12 + N_idx*p.nb13, 0), p.Cin*p.N*p.W*p.H-1);
@@ -223,4 +251,4 @@ void main(){
             }
         }
     }
-}
+}
```
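The shader treats the convolution as an implicit GEMM: the kernel's `(Cin, KH, KW)` axes are flattened into a single CRS index, and each thread that loads a tile element must split that index back apart with a chain of integer divisions. With collectives enabled, every subgroup lane performs the division chain for exactly one CRS offset, and the other lanes pick up the precomputed results via `subgroupShuffle` instead of redoing the divisions. This is also why `BS_CRS` must be capped at the subgroup size: the shuffle's source lane ranges over `0..BS_CRS-1`, and any lane index beyond the subgroup yields undefined results, which is what broke on llvmpipe. A scalar C++ model of that equivalence follows; sizes and names here are illustrative, not taken from the shader:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of the index math in conv2d_mm.comp: a flattened CRS index
// (CRS = Cin*KH*KW, KW innermost) is split back into (Cin, KH, KW).
struct KnlIdx { uint32_t cin, kh, kw; };

static KnlIdx decompose(uint32_t crs_idx, uint32_t KW, uint32_t KH) {
    uint32_t cin = crs_idx / (KW * KH);
    uint32_t rem = crs_idx - cin * KW * KH;
    return { cin, rem / KW, rem - (rem / KW) * KW };
}

int main() {
    const uint32_t KW = 3, KH = 3, subgroup_size = 8; // hypothetical sizes

    // Without collectives, every thread runs decompose() itself. With
    // collectives, lane i decomposes index block_base + i once, and other
    // lanes read it with subgroupShuffle(cached_..., i), modeled here as
    // indexing into a per-lane cache.
    uint32_t block_base = 16; // B_idx_CRS * BS_CRS for some CRS block
    std::vector<KnlIdx> cache(subgroup_size);
    for (uint32_t lane = 0; lane < subgroup_size; lane++) {
        cache[lane] = decompose(block_base + lane, KW, KH); // one division chain per lane
    }
    for (uint32_t Ac = 0; Ac < subgroup_size; Ac++) {
        KnlIdx direct   = decompose(block_base + Ac, KW, KH);
        KnlIdx shuffled = cache[Ac]; // stands in for subgroupShuffle(cached_idx, Ac)
        assert(direct.cin == shuffled.cin && direct.kh == shuffled.kh && direct.kw == shuffled.kw);
    }
    return 0;
}
```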

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -650,7 +650,7 @@ void process_shaders() {

     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

-    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});

     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
```

With the hard-coded `#define USE_COLLECTIVES` removed from the shader (see above), the define is now injected here at shader-compile time, while the actual on/off decision is made per device at pipeline-creation time through the `use_collectives` specialization constant.
