|  | 
| 3 | 3 | 
 | 
| 4 | 4 | #include <assert.h> | 
| 5 | 5 | 
 | 
|  | 6 | +#include <pybind11/numpy.h> | 
| 6 | 7 | #include <pybind11/pybind11.h> | 
| 7 | 8 | #include <pybind11/stl.h> | 
| 8 |  | -#include <pybind11/numpy.h> | 
| 9 | 9 | #include <torch/extension.h> | 
| 10 | 10 | 
 | 
| 11 | 11 | #include <cuda.h> | 
| 12 | 12 | 
 | 
// Append the start offsets of every block_size-wide block tiling the
// half-open column range [range_start, range_end) into block_offset,
// bumping block_count once per offset written.
__device__ void save_blocks(int *block_offset, int range_start, int range_end,
                            int block_size, int &block_count) {
  int offset = range_start;
  while (offset < range_end) {
    block_offset[block_count] = offset;
    ++block_count;
    offset += block_size;
  }
}
| 18 | 19 | 
 | 
// Build the per-row-block sparse layout for vertical-slash attention.
//
// Grid mapping: blockIdx.x = head, blockIdx.y = batch, and
// blockIdx.z * blockDim.x + threadIdx.x = row-block index. Each thread owns
// one BLOCK_SIZE_M-row block of the attention matrix and emits:
//   * block_offset[.., NNZ_S]: start columns of BLOCK_SIZE_N-wide blocks
//     covering the (merged) slash diagonals intersecting this row block,
//     with block_count holding how many offsets were written;
//   * column_index[.., NNZ_V]: vertical column indexes NOT already covered
//     by a slash range, with column_count holding how many were written.
// NOTE(review): the merge assumes vertical_indexes is sorted ascending and
// slash_indexes sorted descending per (batch, head), with NNZ_V, NNZ_S >= 1
// — confirm with callers; none of this is checked here.
__global__ void convert_vertical_slash_indexes_kernel(
    const int *seqlens,          // [BATCH, ]
    const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int *slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int *block_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int N_HEADS, int N_ROWS, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int NNZ_V,
    int NNZ_S) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  // Rows handled by this thread: [start_m, end_m).
  int seqlen = seqlens[batch_idx];
  int block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int start_m = block_idx_m * BLOCK_SIZE_M;
  if (start_m >= seqlen) {
    return; // row block lies past the end of this sequence; nothing to emit
  }
  int end_m = start_m + BLOCK_SIZE_M;
  // Advance all pointers to this (batch, head) slice, and the outputs further
  // to this thread's row block.
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  int tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int s = 0, v = 0;
  int v_idx = vertical_indexes[v++];
  int s_idx = slash_indexes[s++];
  // Skip slash diagonals that do not intersect this row block.
  // NOTE(review): assumes some slash index satisfies s_idx < end_m (e.g. the
  // main diagonal is always present) — otherwise this reads past
  // slash_indexes[NNZ_S - 1]. Confirm with callers.
  while (s_idx >= end_m) {
    s_idx = slash_indexes[s++];
  }
  // Convert the diagonal index into a column coordinate measured from end_m;
  // the max() clamp keeps range_start (computed below) non-negative.
  s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
  // [range_start, range_end) is the current run of columns covered by merged
  // slash diagonals.
  int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  while (1) {
    if (v_idx < range_end) {
      // Vertical column falls before the end of the current slash run; keep
      // it only if the run does not already cover that column.
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        // Verticals exhausted: sentinel beyond any possible range so the
        // remaining slash diagonals drain through the else-branch.
        v_idx = end_m + BLOCK_SIZE_M;
      }
    } else {
      if (s < NNZ_S) {
        s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        // Slashes exhausted: flush the last pending run and stop.
        save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N,
                    tmp_blk_cnt);
        break;
      }
      if (s_idx > range_end + BLOCK_SIZE_M) {
        // New diagonal is disjoint from the current run: flush the run and
        // start a fresh one.
        save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N,
                    tmp_blk_cnt);
        range_start = s_idx - BLOCK_SIZE_M;
        range_end = s_idx;
      } else if (s_idx > range_end) {
        // Overlapping/adjacent diagonal: extend the current run.
        range_end += BLOCK_SIZE_M;
      }
    }
  }

  // Publish per-row counts (pointers were offset to this row above).
  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
| 92 | 90 | 
 | 
// Launch convert_vertical_slash_indexes_kernel specialized for
// BLOCK_SIZE_M == BLOCK_SIZE_N == 64.
//
// Grid: (N_HEADS, BATCH_SIZE, cdiv(N_ROWS, 64)); block: 64 threads, one
// thread per 64-row block of the attention matrix. All pointers must be
// device pointers on the current CUDA device. Launches on the default
// stream and does not synchronize.
void convert_vertical_slash_indexes_64x64(
    const int *seqlens,          // [BATCH, ]
    const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int *slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int *block_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int BATCH_SIZE, int N_HEADS, int N_ROWS, int NNZ_V, int NNZ_S) {
  const int BLOCK_SIZE_M = 64;
  const int BLOCK_SIZE_N = 64;
  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
      seqlens, vertical_indexes, slash_indexes, block_count, block_offset,
      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
      NNZ_V, NNZ_S);
  // Kernel launches do not report errors directly; surface bad-configuration
  // failures (e.g. invalid grid dims) instead of dropping them silently.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "convert_vertical_slash_indexes_kernel launch failed: ",
              cudaGetErrorString(launch_err));
}
| 118 | 110 | 
 | 
// Python-facing entry point: convert per-(batch, head) vertical and slash
// index lists into a block-sparse layout for 64x64 attention tiles.
//
// Returns {block_count, block_offset, column_count, column_index}, all int32
// tensors allocated with seqlens' options (same device/dtype as seqlens).
// Only block_size_M == block_size_N == 64 is supported.
std::vector<at::Tensor> convert_vertical_slash_indexes(
    torch::Tensor seqlens,          // [BATCH, ]
    torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int context_size, int block_size_M, int block_size_N) {
  // assert() is compiled out under NDEBUG (typical release builds); use
  // TORCH_CHECK so the preconditions hold in every build and raise a Python
  // exception instead of aborting the process.
  TORCH_CHECK(block_size_M == 64, "only block_size_M == 64 is supported, got ",
              block_size_M);
  TORCH_CHECK(block_size_N == 64, "only block_size_N == 64 is supported, got ",
              block_size_N);
  // The kernel reads raw int32 device pointers with contiguous strides;
  // validate before handing out data_ptr<int>().
  TORCH_CHECK(seqlens.is_cuda() && vertical_indexes.is_cuda() &&
                  slash_indexes.is_cuda(),
              "seqlens, vertical_indexes and slash_indexes must be CUDA "
              "tensors");
  TORCH_CHECK(seqlens.scalar_type() == at::kInt &&
                  vertical_indexes.scalar_type() == at::kInt &&
                  slash_indexes.scalar_type() == at::kInt,
              "seqlens, vertical_indexes and slash_indexes must be int32");
  TORCH_CHECK(seqlens.is_contiguous() && vertical_indexes.is_contiguous() &&
                  slash_indexes.is_contiguous(),
              "seqlens, vertical_indexes and slash_indexes must be "
              "contiguous");

  cudaSetDevice(seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  // Number of 64-row blocks needed to cover the context (ceil-div).
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  // Outputs are zero-initialized so rows past the sequence end (which the
  // kernel skips) report zero counts.
  torch::Tensor block_count =
      torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
  torch::Tensor block_offset = torch::zeros(
      {batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
  torch::Tensor column_count =
      torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
  torch::Tensor column_index = torch::zeros(
      {batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());

  convert_vertical_slash_indexes_64x64(
      seqlens.data_ptr<int>(), vertical_indexes.data_ptr<int>(),
      slash_indexes.data_ptr<int>(), block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
      nnz_vertical, nnz_slash);

  return {block_count, block_offset, column_count, column_index};
}
0 commit comments