 
 #include <assert.h>
 
-#include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11/numpy.h>
 #include <torch/extension.h>
 
 #include <cuda.h>
 
-__device__ void save_blocks(int *block_offset, int range_start, int range_end,
-                            int block_size, int &block_count) {
-  for (int idx = range_start; idx < range_end; idx += block_size) {
-    block_offset[block_count++] = idx;
-  }
+__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
+    for (int idx = range_start; idx < range_end; idx += block_size) {
+        block_offset[block_count++] = idx;
+    }
 }
 
 __global__ void convert_vertical_slash_indexes_kernel(
-    const int *seqlens, // [BATCH, ]
-    const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
-    const int *slash_indexes, // [BATCH, N_HEADS, NNZ_S]
-    int *block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
-    int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
-    int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
-    int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
-    int N_HEADS, int N_ROWS, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int NNZ_V,
-    int NNZ_S) {
-  const int batch_idx = blockIdx.y;
-  const int head_idx = blockIdx.x;
-  const int group_idx = blockIdx.z;
+    const int* seqlens, // [BATCH, ]
+    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+    const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+    int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
+    int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
+    int N_HEADS,
+    int N_ROWS,
+    int BLOCK_SIZE_M,
+    int BLOCK_SIZE_N,
+    int NNZ_V,
+    int NNZ_S
+) {
+    const int batch_idx = blockIdx.y;
+    const int head_idx = blockIdx.x;
+    const int group_idx = blockIdx.z;
 
-  int seqlen = seqlens[batch_idx];
-  int block_idx_m = group_idx * blockDim.x + threadIdx.x;
-  int start_m = block_idx_m * BLOCK_SIZE_M;
-  if (start_m >= seqlen) {
-    return;
-  }
-  int end_m = start_m + BLOCK_SIZE_M;
-  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
-  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
-  int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
-  block_count += row_offset;
-  block_offset += row_offset * NNZ_S;
-  column_count += row_offset;
-  column_index += row_offset * NNZ_V;
+    int seqlen = seqlens[batch_idx];
+    int block_idx_m = group_idx * blockDim.x + threadIdx.x;
+    int start_m = block_idx_m * BLOCK_SIZE_M;
+    if (start_m >= seqlen) {
+        return;
+    }
+    int end_m = start_m + BLOCK_SIZE_M;
+    vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
+    slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
+    int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
+    block_count += row_offset;
+    block_offset += row_offset * NNZ_S;
+    column_count += row_offset;
+    column_index += row_offset * NNZ_V;
 
-  int tmp_col_cnt = 0, tmp_blk_cnt = 0;
-  int s = 0, v = 0;
-  int v_idx = vertical_indexes[v++];
-  int s_idx = slash_indexes[s++];
-  while (s_idx >= end_m) {
-    s_idx = slash_indexes[s++];
-  }
-  s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
-  int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
-  while (1) {
-    if (v_idx < range_end) {
-      if (v_idx < range_start) {
-        column_index[tmp_col_cnt++] = v_idx;
-      }
-      if (v < NNZ_V) {
-        v_idx = vertical_indexes[v++];
-      } else {
-        v_idx = end_m + BLOCK_SIZE_M;
-      }
-    } else {
-      if (s < NNZ_S) {
-        s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
-      } else {
-        save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N,
-                    tmp_blk_cnt);
-        break;
-      }
-      if (s_idx > range_end + BLOCK_SIZE_M) {
-        save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N,
-                    tmp_blk_cnt);
-        range_start = s_idx - BLOCK_SIZE_M;
-        range_end = s_idx;
-      } else if (s_idx > range_end) {
-        range_end += BLOCK_SIZE_M;
-      }
+    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
+    int s = 0, v = 0;
+    int v_idx = vertical_indexes[v++];
+    int s_idx = slash_indexes[s++];
+    while (s_idx >= end_m) {
+        s_idx = slash_indexes[s++];
+    }
+    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
+    int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
+    while (1) {
+        if (v_idx < range_end) {
+            if (v_idx < range_start) {
+                column_index[tmp_col_cnt++] = v_idx;
+            }
+            if (v < NNZ_V) {
+                v_idx = vertical_indexes[v++];
+            } else {
+                v_idx = end_m + BLOCK_SIZE_M;
+            }
+        } else {
+            if (s < NNZ_S) {
+                s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
+            } else {
+                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+                break;
+            }
+            if (s_idx > range_end + BLOCK_SIZE_M) {
+                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+                range_start = s_idx - BLOCK_SIZE_M;
+                range_end = s_idx;
+            } else if (s_idx > range_end) {
+                range_end += BLOCK_SIZE_M;
+            }
+        }
     }
-  }
 
-  block_count[0] = tmp_blk_cnt;
-  column_count[0] = tmp_col_cnt;
+    block_count[0] = tmp_blk_cnt;
+    column_count[0] = tmp_col_cnt;
 }
 
 void convert_vertical_slash_indexes_64x64(
-    const int *seqlens, // [BATCH, ]
-    const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
-    const int *slash_indexes, // [BATCH, N_HEADS, NNZ_S]
-    int *block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
-    int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
-    int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
-    int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
-    int BATCH_SIZE, int N_HEADS, int N_ROWS, int NNZ_V, int NNZ_S) {
-  const int BLOCK_SIZE_M = 64;
-  const int BLOCK_SIZE_N = 64;
-  const int N_THREADS = 64;
-  const dim3 dimBlock(N_THREADS);
-  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
-  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
-      seqlens, vertical_indexes, slash_indexes, block_count, block_offset,
-      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
-      NNZ_V, NNZ_S);
+    const int* seqlens, // [BATCH, ]
+    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+    const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+    int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
+    int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
+    int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
+    int BATCH_SIZE,
+    int N_HEADS,
+    int N_ROWS,
+    int NNZ_V,
+    int NNZ_S
+) {
+    const int BLOCK_SIZE_M = 64;
+    const int BLOCK_SIZE_N = 64;
+    const int N_THREADS = 64;
+    const dim3 dimBlock(N_THREADS);
+    const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
+    convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
+        seqlens, vertical_indexes, slash_indexes,
+        block_count, block_offset, column_count, column_index,
+        N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
+    );
 }
 
 std::vector<at::Tensor> convert_vertical_slash_indexes(
-    torch::Tensor seqlens, // [BATCH, ]
-    torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
-    torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
-    int context_size, int block_size_M, int block_size_N) {
-  assert(block_size_M == 64);
-  assert(block_size_N == 64);
+    torch::Tensor seqlens, // [BATCH, ]
+    torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
+    torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
+    int context_size,
+    int block_size_M,
+    int block_size_N
+) {
+    assert(block_size_M == 64);
+    assert(block_size_N == 64);
 
-  cudaSetDevice(seqlens.get_device());
+    cudaSetDevice(seqlens.get_device());
 
-  int batch_size = slash_indexes.size(0);
-  int num_heads = slash_indexes.size(1);
-  int nnz_slash = slash_indexes.size(2);
-  int nnz_vertical = vertical_indexes.size(2);
-  int num_rows = (context_size + block_size_M - 1) / block_size_M;
+    int batch_size = slash_indexes.size(0);
+    int num_heads = slash_indexes.size(1);
+    int nnz_slash = slash_indexes.size(2);
+    int nnz_vertical = vertical_indexes.size(2);
+    int num_rows = (context_size + block_size_M - 1) / block_size_M;
 
-  torch::Tensor block_count =
-      torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
-  torch::Tensor block_offset = torch::zeros(
-      {batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
-  torch::Tensor column_count =
-      torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
-  torch::Tensor column_index = torch::zeros(
-      {batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());
+    torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
+    torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
+    torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
+    torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());
 
-  convert_vertical_slash_indexes_64x64(
-      seqlens.data_ptr<int>(), vertical_indexes.data_ptr<int>(),
-      slash_indexes.data_ptr<int>(), block_count.data_ptr<int>(),
-      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
-      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
-      nnz_vertical, nnz_slash);
+    convert_vertical_slash_indexes_64x64(
+        seqlens.data_ptr<int>(),
+        vertical_indexes.data_ptr<int>(),
+        slash_indexes.data_ptr<int>(),
+        block_count.data_ptr<int>(),
+        block_offset.data_ptr<int>(),
+        column_count.data_ptr<int>(),
+        column_index.data_ptr<int>(),
+        batch_size,
+        num_heads,
+        num_rows,
+        nnz_vertical,
+        nnz_slash
+    );
 
-  return {block_count, block_offset, column_count, column_index};
+    return { block_count, block_offset, column_count, column_index };
 }
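
Note: the Python-facing registration of convert_vertical_slash_indexes is not part of the hunk shown above. The following is a minimal sketch of what such a binding typically looks like, assuming the file is built with torch.utils.cpp_extension (which defines the TORCH_EXTENSION_NAME macro); the actual registration in this repository may live elsewhere and differ.

// Hypothetical binding sketch (not shown in this diff): exposes the host entry
// point above to Python. The returned list is
// (block_count, block_offset, column_count, column_index).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("convert_vertical_slash_indexes", &convert_vertical_slash_indexes,
          "Convert per-head vertical and slash indexes into block-sparse attention metadata");
}

From Python, the four returned tensors describe, per (batch, head, row block of BLOCK_SIZE_M query positions), how many key blocks and how many individual key columns to attend to, along with their offsets, which is the metadata a block-sparse attention kernel can then consume.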