Skip to content

Commit 5aa1105

Browse files
authored
vulkan: fix build when using glslang that does not support coopmat2 (ggml-org#15062)
1 parent d31192b commit 5aa1105

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3096,9 +3096,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
30963096
uint32_t conv2d_SHMEM_PAD = 4;
30973097
bool conv2d_UNROLL = true;
30983098

3099+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
30993100
if (device->coopmat2) {
31003101
conv2d_SHMEM_PAD = 8; // 8 float16_t
31013102
}
3103+
#endif
31023104

31033105
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
31043106
conv2d_SHMEM_PAD = 0;
@@ -3158,14 +3160,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
31583160
std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
31593161
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
31603162

3163+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
31613164
if (device->coopmat2) {
31623165
ggml_vk_create_pipeline(
31633166
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
31643167
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
31653168
ggml_vk_create_pipeline(
31663169
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
31673170
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3168-
} else if (conv2d_UNROLL) {
3171+
} else
3172+
#endif
3173+
if (conv2d_UNROLL) {
31693174
ggml_vk_create_pipeline(
31703175
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
31713176
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,8 +661,10 @@ void process_shaders() {
661661
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
662662
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
663663

664+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
664665
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
665666
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
667+
#endif
666668

667669
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
668670
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

0 commit comments

Comments
 (0)