Skip to content

Memcpy kernel for flash attention #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions benchmark/benchmark_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import functools
import random
import time

import torch

from cacheflow import cache_ops


def benchmark(name, f, size: int, num_warmup = 10, num_iters = 100):
for _ in range(num_warmup):
f()
torch.cuda.synchronize()

start = time.time()
for _ in range(num_iters):
f()
torch.cuda.synchronize()
end = time.time()
avg_time = (end - start) / num_iters
print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')


@torch.inference_mode()
def test_gather_cached_kv(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
f'head_size: {head_size}, block_size: {block_size}, '
f'num_blocks: {num_blocks}, dtype: {dtype}')

num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)

x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')

# Run Flash attention.
def run():
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

benchmark('gather_cached_kv', run,
size=num_tokens * num_heads * head_size * 2 * qkv.element_size())


if __name__ == '__main__':
BLOCK_SIZE = 8
NUM_BLOCKS = 1024
DTYPE = torch.half

# LLaMA-13B and OPT-13B
NUM_HEADS = 40
HEAD_SIZE = 128

run_benchmark = functools.partial(
test_gather_cached_kv,
num_heads=NUM_HEADS,
head_size=HEAD_SIZE,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
dtype=DTYPE,
)

for i in range(6, 12):
run_benchmark(num_tokens=2 ** i)
11 changes: 11 additions & 0 deletions csrc/cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ void reshape_and_cache(
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);

void gather_cached_kv(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"swap_blocks",
Expand All @@ -33,4 +40,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
m.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
}
157 changes: 157 additions & 0 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,124 @@ __global__ void reshape_and_cache_kernel(
}
}

// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x) {
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;

const int num_tokens = num_heads * head_size;
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
const int tgt_key_idx = token_idx * key_stride + i;
const int tgt_value_idx = token_idx * value_stride + i;

const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
const int x_offset = head_offset % x;

const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int src_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;

key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
}
}

template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int *__restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x)
{
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;

const int dim = num_heads * head_size;
assert(dim % 4 == 0); // this is true for known use cases
const int unroll_factor = 4;
const int unrolled_dim = dim / unroll_factor;

for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
{
int tgt_key_indices[unroll_factor];
int tgt_value_indices[unroll_factor];
int src_key_indices[unroll_factor];
int src_value_indices[unroll_factor];
scalar_t keys_to_store[unroll_factor];
scalar_t values_to_store[unroll_factor];

#pragma unroll
for (int j = 0; j < unroll_factor; ++j)
{
int index = i + j * unrolled_dim;

const int tgt_key_idx = token_idx * key_stride + index;
const int tgt_value_idx = token_idx * value_stride + index;

const int head_idx = index / head_size;
const int head_offset = index % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;

const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int src_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;

tgt_key_indices[j] = tgt_key_idx;
tgt_value_indices[j] = tgt_value_idx;
src_key_indices[j] = src_key_idx;
src_value_indices[j] = src_value_idx;

keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
}

#pragma unroll
for (int j = 0; j < unroll_factor; ++j)
{
key[tgt_key_indices[j]] = keys_to_store[j];
value[tgt_value_indices[j]] = values_to_store[j];
}
}
}

} // namespace cacheflow

void reshape_and_cache(
Expand Down Expand Up @@ -215,3 +333,42 @@ void reshape_and_cache(
x);
});
}


void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [in] [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);

int key_stride = key.stride(0);
int value_stride = value.stride(0);

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {
cacheflow::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
44 changes: 44 additions & 0 deletions tests/kernels/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,47 @@ def test_reshape_and_cache(
assert torch.allclose(value_cache, cloned_value_cache)


def test_gather_cached_kv(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)

qkv_clone = qkv.clone()
_, cloned_key, cloned_value = qkv_clone.unbind(dim=1)

x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')

cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

# Reference implementation.
for i in range(num_tokens):
reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
block_offset = slot_mapping[i] % block_size
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
cloned_value[i] = value_cache[block_idx, :, :, block_offset]

assert torch.allclose(key, cloned_key)
assert torch.allclose(value, cloned_value)


@torch.inference_mode()
def test_cache() -> None:
test_copy_blocks(
Expand All @@ -107,6 +148,9 @@ def test_cache() -> None:
test_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=torch.half)
test_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=torch.half)


if __name__ == '__main__':
Expand Down