@@ -507,6 +507,7 @@ struct vk_device_struct {
507
507
vk_pipeline pipeline_rwkv_wkv6_f32;
508
508
vk_pipeline pipeline_rwkv_wkv7_f32;
509
509
vk_pipeline pipeline_opt_step_adamw_f32;
510
+ vk_pipeline pipeline_opt_step_sgd_f32;
510
511
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
511
512
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
512
513
vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3085,6 +3086,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
3085
3086
3086
3087
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3087
3088
3089
+ ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3090
+
3088
3091
// conv2d
3089
3092
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
3090
3093
uint32_t conv2d_WG_SIZE = 256;
@@ -7120,7 +7123,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
7120
7123
return nullptr;
7121
7124
case GGML_OP_OPT_STEP_SGD:
7122
7125
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7123
- // TODO
7126
+ return ctx->device->pipeline_opt_step_sgd_f32;
7124
7127
}
7125
7128
return nullptr;
7126
7129
case GGML_OP_LEAKY_RELU:
@@ -7599,6 +7602,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7599
7602
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
7600
7603
ggml_vk_sync_buffers(subctx);
7601
7604
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
7605
+ } else if (op == GGML_OP_OPT_STEP_SGD) {
7606
+ // OPT_STEP_SGD works on src0, it does not need dst
7607
+ ggml_vk_sync_buffers(subctx);
7608
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
7602
7609
} else if (use_src2) {
7603
7610
ggml_vk_sync_buffers(subctx);
7604
7611
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -7937,18 +7944,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
7937
7944
);
7938
7945
}
7939
7946
7940
- static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
7941
- GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
7942
- }
7943
-
7944
- static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
7947
+ static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
7945
7948
const size_t n = ggml_nelements(dst->src[0]);
7946
7949
7947
- ggml_vk_op_f32_opt_step_sgd(
7948
- ctx, subctx, dst,
7949
- { (uint32_t)n, 0, 0.0f, 0.0f },
7950
- dryrun
7951
- );
7950
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
7952
7951
}
7953
7952
7954
7953
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9489,6 +9488,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9489
9488
case GGML_OP_LEAKY_RELU:
9490
9489
case GGML_OP_FLASH_ATTN_EXT:
9491
9490
case GGML_OP_OPT_STEP_ADAMW:
9491
+ case GGML_OP_OPT_STEP_SGD:
9492
9492
break;
9493
9493
default:
9494
9494
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9553,6 +9553,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9553
9553
case GGML_OP_CONV_2D:
9554
9554
case GGML_OP_CONV_2D_DW:
9555
9555
case GGML_OP_LEAKY_RELU:
9556
+ case GGML_OP_OPT_STEP_SGD:
9556
9557
{
9557
9558
// These operations all go through ggml_vk_op_f32, so short-circuit and
9558
9559
// do the only thing needed for the dryrun.
@@ -9800,8 +9801,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9800
9801
break;
9801
9802
9802
9803
case GGML_OP_OPT_STEP_SGD:
9803
- return false; // TODO
9804
- ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
9804
+ ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
9805
9805
9806
9806
break;
9807
9807
default:
@@ -9905,10 +9905,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
9905
9905
case GGML_OP_REPEAT:
9906
9906
case GGML_OP_REPEAT_BACK:
9907
9907
case GGML_OP_OPT_STEP_ADAMW:
9908
+ case GGML_OP_OPT_STEP_SGD:
9908
9909
buf = tensor->buffer;
9909
9910
break;
9910
- case GGML_OP_OPT_STEP_SGD:
9911
- return false;
9912
9911
case GGML_OP_UNARY:
9913
9912
switch (ggml_get_unary_op(tensor)) {
9914
9913
case GGML_UNARY_OP_SILU:
@@ -11036,6 +11035,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11036
11035
case GGML_OP_SIN:
11037
11036
case GGML_OP_COS:
11038
11037
case GGML_OP_CLAMP:
11038
+ case GGML_OP_LEAKY_RELU:
11039
+ case GGML_OP_OPT_STEP_ADAMW:
11040
+ case GGML_OP_OPT_STEP_SGD:
11039
11041
return op->src[0]->type == GGML_TYPE_F32;
11040
11042
case GGML_OP_UPSCALE:
11041
11043
case GGML_OP_ACC:
@@ -11057,11 +11059,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11057
11059
case GGML_OP_POOL_2D:
11058
11060
case GGML_OP_RWKV_WKV6:
11059
11061
case GGML_OP_RWKV_WKV7:
11060
- case GGML_OP_LEAKY_RELU:
11061
- case GGML_OP_OPT_STEP_ADAMW:
11062
11062
return true;
11063
- case GGML_OP_OPT_STEP_SGD:
11064
- return false;
11065
11063
case GGML_OP_CONV_TRANSPOSE_1D:
11066
11064
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
11067
11065
case GGML_OP_CONV_2D:
0 commit comments