forked from LostRuins/koboldcpp
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #324 from A3shTnT/faster_ssm_scan
Faster ssm scan
- Loading branch information
Showing
5 changed files
with
322 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
#include "ssm_conv.cuh" | ||
|
||
template <size_t split_d_inner, size_t d_conv> | ||
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, | ||
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, | ||
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, | ||
const int nc, const int ncs, const int nr, const int n_t, const int n_s) { | ||
const int tid = threadIdx.x; | ||
const int bidx = blockIdx.x; | ||
const int bidy = blockIdx.y; | ||
|
||
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); | ||
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1); | ||
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0); | ||
|
||
const int stride_x = src0_nb1 / sizeof(float); | ||
const int stride_w = src1_nb1 / sizeof(float); | ||
const int stride_y = dst_nb1 / sizeof(float); | ||
|
||
float x[d_conv] = { 0.0f }; | ||
float w[d_conv] = { 0.0f }; | ||
|
||
#pragma unroll | ||
for (int j = 0; j < d_conv; j++) { | ||
w[j] = w_block[tid * stride_w + j]; | ||
} | ||
|
||
for (int i = 0; i < n_t; i++) { | ||
float sumf = 0.0f; | ||
|
||
if (i == 0) { | ||
for (int j = 0; j < d_conv; j++) { | ||
x[j] = x_block[tid * stride_x + j]; | ||
} | ||
} else { | ||
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; | ||
} | ||
|
||
#pragma unroll | ||
for (int j = 0; j < d_conv; j++) { | ||
sumf += x[(i + j) % d_conv] * w[j]; | ||
} | ||
y_block[i * stride_y + tid] = sumf; | ||
} | ||
} | ||
|
||
template <size_t split_d_inner, size_t d_conv, size_t split_n_t> | ||
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, | ||
const int src0_nb0, const int src0_nb1, const int src0_nb2, | ||
const int src1_nb1, float * __restrict__ dst, const int dst_nb0, | ||
const int dst_nb1, const int dst_nb2, const int nc, const int ncs, | ||
const int nr, const int n_t, const int n_s) { | ||
const int tid = threadIdx.x; | ||
const int bidx = blockIdx.x; | ||
const int bidy = blockIdx.y; | ||
const int bidz = blockIdx.z; | ||
|
||
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 + | ||
bidz * split_n_t * src0_nb0); | ||
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1); | ||
float * y_block = | ||
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0); | ||
|
||
const int stride_x = src0_nb1 / sizeof(float); | ||
const int stride_w = src1_nb1 / sizeof(float); | ||
const int stride_y = dst_nb1 / sizeof(float); | ||
|
||
float x[d_conv] = { 0.0f }; | ||
float w[d_conv] = { 0.0f }; | ||
|
||
#pragma unroll | ||
for (int j = 0; j < d_conv; j++) { | ||
w[j] = w_block[tid * stride_w + j]; | ||
} | ||
|
||
#pragma unroll | ||
for (int i = 0; i < split_n_t; i++) { | ||
if (bidz * split_n_t + i < n_t) { | ||
float sumf = 0.0f; | ||
|
||
if (i == 0) { | ||
for (int j = 0; j < d_conv; j++) { | ||
x[j] = x_block[tid * stride_x + j]; | ||
} | ||
} else { | ||
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; | ||
} | ||
|
||
#pragma unroll | ||
for (int j = 0; j < d_conv; j++) { | ||
sumf += x[(i + j) % d_conv] * w[j]; | ||
} | ||
y_block[i * stride_y + tid] = sumf; | ||
} | ||
} | ||
} | ||
|
||
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, | ||
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, | ||
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t, | ||
const int n_s, cudaStream_t stream) { | ||
const int threads = 128; | ||
GGML_ASSERT(nr % threads == 0); | ||
|
||
if (n_t <= 32) { | ||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); | ||
if (nc == 4) { | ||
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, | ||
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, | ||
n_s); | ||
} else { | ||
GGML_ABORT("Only support kernel size = 4 now."); | ||
} | ||
} else { | ||
if (nc == 4) { | ||
const int split_n_t = 32; | ||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); | ||
ssm_conv_long_token_f32<threads, 4, split_n_t> | ||
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, | ||
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s); | ||
} else { | ||
GGML_ABORT("Only support kernel size = 4 right now."); | ||
} | ||
} | ||
} | ||
|
||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const struct ggml_tensor * src0 = dst->src[0]; // conv_x | ||
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight | ||
|
||
const int nc = src1->ne[0]; // d_conv | ||
const int ncs = src0->ne[0]; // d_conv - 1 + n_t | ||
const int nr = src0->ne[1]; // d_inner | ||
const int n_t = dst->ne[1]; // tokens per sequence | ||
const int n_s = dst->ne[2]; // number of sequences in the batch | ||
|
||
GGML_ASSERT(dst->ne[0] == nr); | ||
GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src1->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); | ||
|
||
const float * src0_d = (const float *) src0->data; | ||
const float * src1_d = (const float *) src1->data; | ||
float * dst_d = (float *) dst->data; | ||
cudaStream_t stream = ctx.stream(); | ||
|
||
GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], | ||
dst->nb[2], nc, ncs, nr, n_t, n_s, stream); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#include "common.cuh" | ||
|
||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
#include "ssm_scan.cuh" | ||
|
||
// #include <cuda_runtime.h> | ||
// static __device__ void global_to_shared(const float *src, float *dst) { | ||
// asm volatile("cp.async."); | ||
// } | ||
|
||
template <size_t splitD, size_t N> | ||
__global__ void __launch_bounds__(splitD, 2) | ||
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, | ||
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, | ||
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, | ||
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, | ||
float * __restrict__ dst, const int D, const int L, const int B) { | ||
const int bidx = blockIdx.x; // split along B | ||
const int bidy = blockIdx.y; // split along D | ||
const int tid = threadIdx.x; | ||
const int wid = tid / 32; | ||
const int wtid = tid % 32; | ||
|
||
extern __shared__ float smem[]; | ||
const int stride_sA = N + 1; | ||
const int stride_ss0 = N + 1; | ||
float * smem_A = smem; | ||
float * smem_s0 = smem_A + splitD * stride_sA; | ||
|
||
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); | ||
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); | ||
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); | ||
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1); | ||
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2)); | ||
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2)); | ||
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); | ||
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); | ||
|
||
const int stride_s0 = src0_nb1 / sizeof(float); | ||
const int stride_x = src1_nb1 / sizeof(float); | ||
const int stride_dt = src2_nb1 / sizeof(float); | ||
const int stride_A = src3_nb1 / sizeof(float); | ||
const int stride_B = src4_nb1 / sizeof(float); | ||
const int stride_C = src5_nb1 / sizeof(float); | ||
const int stride_s = stride_s0; | ||
const int stride_y = stride_x; | ||
|
||
// can N not be 16? for example 32? | ||
if (N == 16) { | ||
#pragma unroll | ||
for (int i = 0; i < splitD / 4; i += 2) { | ||
float value = A_block[(wid * warpSize + i) * stride_A + wtid]; | ||
// todo: bank conflict | ||
// I am always confused with how to use the swizzling method to solve | ||
// bank conflit. Hoping somebody can tell me. | ||
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; | ||
} | ||
#pragma unroll | ||
for (int i = 0; i < splitD / 4; i += 2) { | ||
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; | ||
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; | ||
} | ||
} | ||
|
||
__syncthreads(); | ||
|
||
for (int i = 0; i < L; i++) { | ||
float dt_soft_plus = dt_block[i * stride_dt + tid]; | ||
if (dt_soft_plus <= 20.0f) { | ||
dt_soft_plus = log1pf(exp(dt_soft_plus)); | ||
} | ||
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; | ||
float sumf = 0.0f; | ||
#pragma unroll | ||
for (int j = 0; j < N; j++) { | ||
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + | ||
(B_block[i * stride_B + j] * x_dt); | ||
sumf += state * C_block[i * stride_C + j]; | ||
if (i == L - 1) { | ||
s_block[tid * stride_s + j] = state; | ||
} else { | ||
smem_s0[tid * stride_ss0 + j] = state; | ||
} | ||
} | ||
__syncthreads(); | ||
y_block[i * stride_y + tid] = sumf; | ||
} | ||
} | ||
|
||
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, | ||
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, | ||
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, | ||
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, | ||
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) { | ||
const int threads = 128; | ||
// todo: consider D cannot be divided,does this situation exist? | ||
GGML_ASSERT(D % threads == 0); | ||
const dim3 blocks(B, (D + threads - 1) / threads, 1); | ||
const int smem_size = (threads * (N + 1) * 2) * sizeof(float); | ||
if (N == 16) { | ||
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( | ||
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, | ||
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B); | ||
} else { | ||
GGML_ABORT("doesn't support N!=16."); | ||
} | ||
} | ||
|
||
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const struct ggml_tensor * src0 = dst->src[0]; // s | ||
const struct ggml_tensor * src1 = dst->src[1]; // x | ||
const struct ggml_tensor * src2 = dst->src[2]; // dt | ||
const struct ggml_tensor * src3 = dst->src[3]; // A | ||
const struct ggml_tensor * src4 = dst->src[4]; // B | ||
const struct ggml_tensor * src5 = dst->src[5]; // C | ||
|
||
// const int64_t d_state = src0->ne[0]; | ||
// const int64_t d_inner = src0->ne[1]; | ||
// const int64_t l = src1->ne[1]; | ||
// const int64_t b = src0->ne[2]; | ||
|
||
const int64_t nc = src0->ne[0]; // d_state | ||
const int64_t nr = src0->ne[1]; // d_inner | ||
const int64_t n_t = src1->ne[1]; // number of tokens per sequence | ||
const int64_t n_s = src0->ne[2]; // number of sequences in the batch | ||
|
||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); | ||
GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src1->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src2->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src3->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src4->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src5->nb[0] == sizeof(float)); | ||
// required for the dot product between s and C | ||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); | ||
// required for per-sequence offsets for states | ||
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); | ||
// required to get correct offset for state destination (i.e. src1->nb[3]) | ||
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); | ||
|
||
const float * src0_d = (const float *) src0->data; | ||
const float * src1_d = (const float *) src1->data; | ||
const float * src2_d = (const float *) src2->data; | ||
const float * src3_d = (const float *) src3->data; | ||
const float * src4_d = (const float *) src4->data; | ||
const float * src5_d = (const float *) src5->data; | ||
float * dst_d = (float *) dst->data; | ||
cudaStream_t stream = ctx.stream(); | ||
|
||
GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|
||
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], | ||
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], | ||
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#include "common.cuh" | ||
|
||
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |