@@ -1185,6 +1185,14 @@ struct vk_staging_memcpy {
     size_t n;
 };
 
+struct vk_staging_memset {
+    vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
+
+    void * dst;
+    uint32_t val;
+    size_t n;
+};
+
 struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
@@ -1193,6 +1201,7 @@ struct vk_context_struct {
 
     std::vector<vk_staging_memcpy> in_memcpys;
     std::vector<vk_staging_memcpy> out_memcpys;
+    std::vector<vk_staging_memset> memsets;
 
     vk_command_pool * p {};
 };
@@ -5196,6 +5205,14 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
     }
 }
 
+static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
+    if (memsets == nullptr) {
+        memset(dst, val, size);
+    } else {
+        memsets->emplace_back(dst, val, size);
+    }
+}
+
 static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
     if (device->sync_staging == nullptr || device->sync_staging->size < size) {
        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
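
Taken together, vk_staging_memset, the per-context memsets vector, and deferred_memset mirror the existing vk_staging_memcpy / deferred_memcpy machinery: a host-side fill is either executed immediately or recorded into the context and replayed later, just before the recorded command buffer is submitted. Below is a minimal stand-alone sketch of that record-then-replay pattern; staged_memset, record_or_memset and flush_staged_memsets are hypothetical names used for illustration, not symbols from ggml-vulkan.cpp.

// Illustrative sketch only, assuming the caller replays recorded fills on the
// host before (or instead of) a GPU-side fill.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

struct staged_memset {
    void *   dst;
    uint32_t val;
    size_t   n;
};

// Record a fill for later, or perform it immediately when no staging list is given.
static void record_or_memset(void * dst, uint32_t val, size_t n,
                             std::vector<staged_memset> * staged = nullptr) {
    if (staged == nullptr) {
        memset(dst, val, n);
    } else {
        staged->push_back({dst, val, n});
    }
}

// Replay every recorded fill on the host, then drop the records.
static void flush_staged_memsets(std::vector<staged_memset> & staged) {
    for (const auto & m : staged) {
        memset(m.dst, m.val, m.n);
    }
    staged.clear();
}

int main() {
    std::vector<staged_memset> staged;
    std::vector<uint8_t> buf(64, 0xFF);

    record_or_memset(buf.data(), 0, buf.size(), &staged); // deferred: buf still 0xFF here
    flush_staged_memsets(staged);                         // now buf is zeroed
    return buf[0];                                        // 0
}
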
@@ -5391,6 +5408,10 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
+        for (auto& mset : subctx->memsets) {
+            memset(mset.dst, mset.val, mset.n);
+        }
+
         ggml_vk_submit(subctx, dst->device->fence);
         VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
         dst->device->device.resetFences({ dst->device->fence });
@@ -5530,12 +5551,25 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
 static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
 
+    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
+        dst->device->uma) {
+        deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
+        return;
+    }
+
+    // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
     ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
 }
 
 static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
 
+    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
+        dst->device->uma) {
+        memset((uint8_t*)dst->ptr + offset, c, size);
+        return;
+    }
+
     std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
     vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
     ggml_vk_ctx_begin(dst->device, subctx);
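
The intent of these two hunks: when the destination buffer is host-visible and the device reports unified memory, the fill is done directly through the mapped pointer with a CPU memset (immediately in the synchronous path, deferred via ctx->memsets in the async path), instead of recording a GPU fillBuffer and waiting on the transfer queue. A hedged sketch of that decision follows; fake_buffer, its flags and try_host_fill are placeholders for illustration, not the real ggml_vk_buffer layout or Vulkan-Hpp calls.

// Stand-alone sketch of the host-fill fast path; all names here are
// placeholders, not ggml or Vulkan identifiers.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

struct fake_buffer {
    void * mapped_ptr;    // host mapping, only meaningful when host_visible
    bool   host_visible;  // stands in for vk::MemoryPropertyFlagBits::eHostVisible
    bool   uma;           // stands in for dst->device->uma
};

// Returns true if the fill was completed on the CPU; otherwise the caller
// falls back to a GPU fill (fillBuffer in the real code).
static bool try_host_fill(fake_buffer & dst, size_t offset, uint32_t c, size_t size) {
    if (dst.host_visible && dst.uma) {
        memset((uint8_t *) dst.mapped_ptr + offset, c, size);
        return true;
    }
    return false;
}

int main() {
    std::vector<uint8_t> storage(32, 0xAA);
    fake_buffer buf { storage.data(), /*host_visible=*/true, /*uma=*/true };
    bool handled = try_host_fill(buf, 8, 0, 16);  // zero bytes [8, 24)
    printf("handled on host: %d, byte[8] = %u\n", handled ? 1 : 0, (unsigned) storage[8]);
    return 0;
}
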
@@ -11170,6 +11204,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
+        for (auto& mset : subctx->memsets) {
+            memset(mset.dst, mset.val, mset.n);
+        }
+
         if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
             ggml_vk_submit(subctx, ctx->almost_ready_fence);
             ctx->almost_ready_fence_pending = true;
@@ -11192,6 +11230,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         }
         subctx->in_memcpys.clear();
         subctx->out_memcpys.clear();
+        subctx->memsets.clear();
     }
 
     return true;