rwkv6: add wkv6 support for Vulkan backend #10829

Merged: 7 commits, Dec 16, 2024
Changes from 1 commit
rwkv_wkv6 vulkan shader
uniartisan authored and MollySophia committed Dec 13, 2024
commit 4651f5e2f29b32e24c69c511d0bacb14d29e6008
165 changes: 164 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -240,6 +240,7 @@ struct vk_device_struct {
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;

// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -523,6 +524,15 @@ struct vk_op_pool2d_push_constants {
int32_t p0; int32_t p1;
};


struct vk_op_rwkv_wkv6_push_constants {
uint32_t B; // Batch size (n_seqs in the ggml op)
uint32_t T; // Sequence length
uint32_t C; // Total number of channels
uint32_t H; // Number of heads (HEADS in the ggml op)
};
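// Mapping from the ggml op's tensors, as extracted in ggml_vk_rwkv_wkv6 below:
// B = src[5]->ne[1] (n_seqs), T = src[0]->ne[3], C = dst->ne[0], H = src[0]->ne[2].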


// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1942,6 +1952,20 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(
device,
device->pipeline_rwkv_wkv6_f32,
"rwkv_wkv6_f32",
rwkv_wkv6_f32_len,
rwkv_wkv6_f32_data,
"main",
7,
sizeof(vk_op_rwkv_wkv6_push_constants),
{64, 1, 1}, // work-group size (matches local_size_x = 64 in the shader)
{device->subgroup_size},
1
);
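// Note: {device->subgroup_size} is passed as a specialization constant, but the
// rwkv_wkv6.comp shader in this commit does not appear to declare a matching
// constant_id, so the value seems unused here; head_size is hard-coded to 64.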

for (auto &c : compiles) {
c.wait();
}
@@ -4917,6 +4941,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_pool2d_f32;
}
return nullptr;
case GGML_OP_RWKV_WKV6:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
@@ -5319,6 +5348,127 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}



template<typename PC>
static void ggml_vk_op_f32_rwkv6(
ggml_backend_vk_context * ctx,
vk_context& subctx,
ggml_tensor * dst,
const PC&& pc,
bool dryrun = false) {

// Get source tensors
const ggml_tensor * k = dst->src[0]; // keys
const ggml_tensor * v = dst->src[1]; // values
const ggml_tensor * r = dst->src[2]; // receptance
const ggml_tensor * tf = dst->src[3]; // time first
const ggml_tensor * td = dst->src[4]; // time decay
const ggml_tensor * state = dst->src[5]; // states

VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", "
<< tf << ", " << td << ", " << state << ", " << dst << ")");

// Verify input types
GGML_ASSERT(!ggml_is_quantized(k->type));
GGML_ASSERT(!ggml_is_quantized(v->type));
GGML_ASSERT(!ggml_is_quantized(r->type));
GGML_ASSERT(!ggml_is_quantized(tf->type));
GGML_ASSERT(!ggml_is_quantized(td->type));
GGML_ASSERT(!ggml_is_quantized(state->type));
GGML_ASSERT(dst->buffer != nullptr);

// Get pipeline
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
GGML_ASSERT(pipeline != nullptr);

if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
return;
}

// Get buffer contexts
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;

// Get device buffers
vk_buffer d_D = dst_buf_ctx->dev_buffer;
vk_buffer d_K = k_buf_ctx->dev_buffer;
vk_buffer d_V = v_buf_ctx->dev_buffer;
vk_buffer d_R = r_buf_ctx->dev_buffer;
vk_buffer d_TF = tf_buf_ctx->dev_buffer;
vk_buffer d_TD = td_buf_ctx->dev_buffer;
vk_buffer d_State = state_buf_ctx->dev_buffer;

// Calculate buffer offsets
const uint64_t k_offset = vk_tensor_offset(k);
const uint64_t v_offset = vk_tensor_offset(v);
const uint64_t r_offset = vk_tensor_offset(r);
const uint64_t tf_offset = vk_tensor_offset(tf);
const uint64_t td_offset = vk_tensor_offset(td);
const uint64_t state_offset = vk_tensor_offset(state);
const uint64_t dst_offset = vk_tensor_offset(dst);

// Calculate buffer sizes
const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
const uint64_t r_size = ggml_nbytes(r);
const uint64_t tf_size = ggml_nbytes(tf);
const uint64_t td_size = ggml_nbytes(td);
const uint64_t state_size = ggml_nbytes(state);
const uint64_t dst_size = ggml_nbytes(dst);

// Set work elements based on tensor dimensions
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B*pc.H), // B * H workgroups
1, // each workgroup runs 64 threads (shader local_size_x)
1
};

// Synchronize buffers and dispatch compute pipeline
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size },
vk_subbuffer{ d_V, v_offset, v_size },
vk_subbuffer{ d_R, r_offset, r_size },
vk_subbuffer{ d_TF, tf_offset, tf_size },
vk_subbuffer{ d_TD, td_offset, td_size },
vk_subbuffer{ d_State, state_offset, state_size },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(PC), &pc, elements);
}

static void ggml_vk_rwkv_wkv6(
ggml_backend_vk_context * ctx,
vk_context& subctx,
ggml_tensor * dst,
bool dryrun = false) {

// Extract dimensions from tensors
const size_t T = dst->src[0]->ne[3]; // Sequence length
const size_t C = dst->ne[0]; // Channel dimension
const size_t HEADS = dst->src[0]->ne[2]; // Number of heads
const size_t n_seqs = dst->src[5]->ne[1]; // Batch size

// Call implementation with push constants
ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
ctx, subctx, dst,
{
(uint32_t)n_seqs, // B
(uint32_t)T, // T
(uint32_t)C, // C
(uint32_t)HEADS, // H
},
dryrun
);
}
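// Worked example with illustrative numbers (hypothetical configuration, not from
// any specific model): H = 32 heads of size 64 gives C = 32 * 64 = 2048; two
// sequences (n_seqs = 2) of 61 tokens each give T = 2 * 61 = 122, so the shader
// computes n_seq_tokens = T / B = 61 and dispatches B * H = 2 * 32 = 64 workgroups.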


static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
int * op_params = (int *)dst->op_params;

@@ -6464,6 +6614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
break;
@@ -6663,6 +6814,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_FLASH_ATTN_EXT:
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);

break;

case GGML_OP_RWKV_WKV6:
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);

break;
default:
return false;
@@ -6743,6 +6899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
buf = tensor->buffer;
@@ -7610,6 +7767,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
return true;
default:
@@ -8186,7 +8344,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
} else {
}
// else if (tensor->op == GGML_OP_RWKV_WKV6) {
// tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
// tensor->src[4], tensor->src[5]);
// }
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
}
96 changes: 96 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp
@@ -0,0 +1,96 @@
#version 450


layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
uint B; // Batch size
uint T; // Sequence length
uint C; // Total number of channels
uint H; // Number of heads
};

layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; };
layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; };
layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; };
layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; };
layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; };
layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; };
layout(set = 0, binding = 6) buffer DstBuf { float dst[]; };

shared float _k[64], _r[64], _tf[64], _td[64];

void main() {
const uint head_size = 64;
const uint batch_id = gl_WorkGroupID.x / H;
const uint head_id = gl_WorkGroupID.x % H;
const uint tid = gl_LocalInvocationID.x;

const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;

if (tid >= head_size || batch_id >= B || head_id >= H) {
return;
}

// Load state
float state[64]; // Use fixed size matching head_size
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
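
// Each workgroup handles one (batch, head) pair; thread tid keeps column tid of
// the head's 64x64 state matrix in registers (state[i] = S[i][tid]). state_in is
// laid out as [n_seqs][H][head_size][head_size].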

_k[tid] = 0.0;
_r[tid] = 0.0;
_td[tid] = 0.0;
barrier();
_tf[tid] = tf[head_id * head_size + tid];
barrier();


// Main loop over this sequence's tokens; stepping t by C advances one token
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

for (uint t = start_t; t < end_t; t += C) {
barrier();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
barrier();

const float v_val = v[t];
float y = 0.0;

for (uint j = 0; j < head_size; j += 4) {
// Load values in blocks of 4
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);

// Compute kv products
vec4 kv = k_vec * v_val;

// Accumulate results
vec4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);

// Update state
s_vec = s_vec * td_vec + kv;
state[j] = s_vec.x;
state[j+1] = s_vec.y;
state[j+2] = s_vec.z;
state[j+3] = s_vec.w;
}

dst[t] = y;
}

// Write back state
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
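
// For reference, the per-head recurrence the main loop implements, as read from
// the code above (t = token, i = this thread's channel tid, j = summed channel):
//
//   y[t][i]  = sum_j r[t][j] * (tf[j] * k[t][j] * v[t][i] + S[j][i])
//   S[j][i] <-  td[t][j] * S[j][i] + k[t][j] * v[t][i]
//
// i.e. the vec4 block computes kv = k*v, accumulates r . (tf*kv + S), then
// decays the state by td and adds kv.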