Skip to content

Commit d3e88bc

Browse files
committed
Implement repeat_back SYCL operation and minor fixes
1 parent 8d51e18 commit d3e88bc

File tree

1 file changed

+4 -9 lines changed

1 file changed

+4 -9 lines changed

ggml/src/ggml-sycl/repeat_back.cpp

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -4,20 +4,16 @@
44
#include "common.hpp"
55

66
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
7+
78
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
89
GGML_ASSERT(dst->type == GGML_TYPE_F32);
910

1011
const float * src0_dd = (const float *) dst->src[0]->data;
1112
float * dst_dd = (float *) dst->data;
1213

13-
const int64_t ne0 = dst->ne[0];
14-
const int64_t ne1 = dst->ne[1];
15-
const int64_t ne2 = dst->ne[2];
16-
const int64_t ne3 = dst->ne[3];
17-
const int64_t ne00 = dst->src[0]->ne[0];
18-
const int64_t ne01 = dst->src[0]->ne[1];
19-
const int64_t ne02 = dst->src[0]->ne[2];
20-
const int64_t ne03 = dst->src[0]->ne[3];
14+
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
15+
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
16+
ne03 = dst->src[0]->ne[3];
2117

2218
const int nr0 = (int) (ne00 / ne0);
2319
const int nr1 = (int) (ne01 / ne1);
@@ -29,7 +25,6 @@ void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst
2925
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
3026

3127
queue_ptr stream = ctx.stream();
32-
stream->memset(dst_dd, 0, total * sizeof(float));
3328

3429
stream->parallel_for(
3530
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),

0 commit comments

Comments (0)