|
| 1 | +#include "ssm-scan.cuh" |
| 2 | + |
| 3 | +// #include <cuda_runtime.h> |
| 4 | +// static __device__ void global_to_shared(const float *src, float *dst) { |
| 5 | +// asm volatile("cp.async."); |
| 6 | +// } |
| 7 | + |
| 8 | +template <size_t splitD, size_t N> |
| 9 | +__global__ void __launch_bounds__(splitD, 2) |
| 10 | + ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, |
| 11 | + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, |
| 12 | + const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, |
| 13 | + const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, |
| 14 | + const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, |
| 15 | + float * __restrict__ dst, const int D, const int L, const int B) { |
| 16 | + const int bidx = blockIdx.x; // split along B |
| 17 | + const int bidy = blockIdx.y; // split along D |
| 18 | + const int tid = threadIdx.x; |
| 19 | + const int wid = tid / 32; |
| 20 | + const int wtid = tid % 32; |
| 21 | + |
| 22 | + extern __shared__ float smem[]; |
| 23 | + const int stride_sA = N + 1; |
| 24 | + const int stride_ss0 = N + 1; |
| 25 | + float * smem_A = smem; |
| 26 | + float * smem_s0 = smem_A + splitD * stride_sA; |
| 27 | + |
| 28 | + const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); |
| 29 | + const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); |
| 30 | + const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); |
| 31 | + const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1); |
| 32 | + const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2)); |
| 33 | + const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2)); |
| 34 | + float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); |
| 35 | + float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); |
| 36 | + |
| 37 | + const int stride_s0 = src0_nb1 / sizeof(float); |
| 38 | + const int stride_x = src1_nb1 / sizeof(float); |
| 39 | + const int stride_dt = src2_nb1 / sizeof(float); |
| 40 | + const int stride_A = src3_nb1 / sizeof(float); |
| 41 | + const int stride_B = src4_nb1 / sizeof(float); |
| 42 | + const int stride_C = src5_nb1 / sizeof(float); |
| 43 | + const int stride_s = stride_s0; |
| 44 | + const int stride_y = stride_x; |
| 45 | + |
| 46 | + // can N not be 16? for example 32? |
| 47 | + if (N == 16) { |
| 48 | +#pragma unroll |
| 49 | + for (int i = 0; i < splitD / 4; i += 2) { |
| 50 | + float value = A_block[(wid * warpSize + i) * stride_A + wtid]; |
| 51 | + // todo: bank conflict |
| 52 | + // I am always confused with how to use the swizzling method to solve |
| 53 | + // bank conflit. Hoping somebody can tell me. |
| 54 | + smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; |
| 55 | + } |
| 56 | +#pragma unroll |
| 57 | + for (int i = 0; i < splitD / 4; i += 2) { |
| 58 | + float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; |
| 59 | + smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + __syncthreads(); |
| 64 | + |
| 65 | + for (int i = 0; i < L; i++) { |
| 66 | + float dt_soft_plus = dt_block[i * stride_dt + tid]; |
| 67 | + if (dt_soft_plus <= 20.0f) { |
| 68 | + dt_soft_plus = log1pf(exp(dt_soft_plus)); |
| 69 | + } |
| 70 | + float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; |
| 71 | + float sumf = 0.0f; |
| 72 | +#pragma unroll |
| 73 | + for (int j = 0; j < N; j++) { |
| 74 | + float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + |
| 75 | + (B_block[i * stride_B + j] * x_dt); |
| 76 | + sumf += state * C_block[i * stride_C + j]; |
| 77 | + if (i == L - 1) { |
| 78 | + s_block[tid * stride_s + j] = state; |
| 79 | + } else { |
| 80 | + smem_s0[tid * stride_ss0 + j] = state; |
| 81 | + } |
| 82 | + } |
| 83 | + __syncthreads(); |
| 84 | + y_block[i * stride_y + tid] = sumf; |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, |
| 89 | + const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, |
| 90 | + const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, |
| 91 | + const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, |
| 92 | + const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, |
| 93 | + float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) { |
| 94 | + const int threads = 128; |
| 95 | + // todo: consider D cannot be divided,does this situation exist? |
| 96 | + GGML_ASSERT(D % threads == 0); |
| 97 | + const dim3 blocks(B, (D + threads - 1) / threads, 1); |
| 98 | + const int smem_size = (threads * (N + 1) * 2) * sizeof(float); |
| 99 | + if (N == 16) { |
| 100 | + ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( |
| 101 | + src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, |
| 102 | + src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B); |
| 103 | + } else { |
| 104 | + GGML_ABORT("doesn't support N!=16."); |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 109 | + const struct ggml_tensor * src0 = dst->src[0]; // s |
| 110 | + const struct ggml_tensor * src1 = dst->src[1]; // x |
| 111 | + const struct ggml_tensor * src2 = dst->src[2]; // dt |
| 112 | + const struct ggml_tensor * src3 = dst->src[3]; // A |
| 113 | + const struct ggml_tensor * src4 = dst->src[4]; // B |
| 114 | + const struct ggml_tensor * src5 = dst->src[5]; // C |
| 115 | + |
| 116 | + // const int64_t d_state = src0->ne[0]; |
| 117 | + // const int64_t d_inner = src0->ne[1]; |
| 118 | + // const int64_t l = src1->ne[1]; |
| 119 | + // const int64_t b = src0->ne[2]; |
| 120 | + |
| 121 | + const int64_t nc = src0->ne[0]; // d_state |
| 122 | + const int64_t nr = src0->ne[1]; // d_inner |
| 123 | + const int64_t n_t = src1->ne[1]; // number of tokens per sequence |
| 124 | + const int64_t n_s = src0->ne[2]; // number of sequences in the batch |
| 125 | + |
| 126 | + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); |
| 127 | + GGML_ASSERT(src0->nb[0] == sizeof(float)); |
| 128 | + GGML_ASSERT(src1->nb[0] == sizeof(float)); |
| 129 | + GGML_ASSERT(src2->nb[0] == sizeof(float)); |
| 130 | + GGML_ASSERT(src3->nb[0] == sizeof(float)); |
| 131 | + GGML_ASSERT(src4->nb[0] == sizeof(float)); |
| 132 | + GGML_ASSERT(src5->nb[0] == sizeof(float)); |
| 133 | + // required for the dot product between s and C |
| 134 | + GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); |
| 135 | + // required for per-sequence offsets for states |
| 136 | + GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); |
| 137 | + // required to get correct offset for state destination (i.e. src1->nb[3]) |
| 138 | + GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); |
| 139 | + |
| 140 | + const float * src0_d = (const float *) src0->data; |
| 141 | + const float * src1_d = (const float *) src1->data; |
| 142 | + const float * src2_d = (const float *) src2->data; |
| 143 | + const float * src3_d = (const float *) src3->data; |
| 144 | + const float * src4_d = (const float *) src4->data; |
| 145 | + const float * src5_d = (const float *) src5->data; |
| 146 | + float * dst_d = (float *) dst->data; |
| 147 | + cudaStream_t stream = ctx.stream(); |
| 148 | + |
| 149 | + GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| 150 | + GGML_ASSERT(dst->type == GGML_TYPE_F32); |
| 151 | + |
| 152 | + ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], |
| 153 | + src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], |
| 154 | + src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); |
| 155 | +} |
0 commit comments