
Optimize data movement #20

Merged · 19 commits · Apr 2, 2023
20 changes: 20 additions & 0 deletions cacheflow/models/activation.py
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

def __init__(self):
super().__init__()

def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out
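
For orientation, the fused CUDA kernel this module wraps (see csrc/activation_kernels.cu below) applies SiLU to the first half of each input row and multiplies it element-wise by the second half. A minimal plain-PyTorch reference of that computation, not the implementation used by the PR:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> (num_tokens, d)
    # SwiGLU-style activation: silu(gate) * up, where gate and up are the
    # two halves of the last dimension.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

Fusing the two steps into one kernel avoids the split/contiguous copies and the intermediate silu(gate) tensor that the previous LlamaMLP.forward materialized.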
92 changes: 46 additions & 46 deletions cacheflow/models/attention.py
@@ -1,6 +1,6 @@
from typing import List, Optional
from typing import Optional

from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn

@@ -16,40 +16,38 @@ def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)

self.flash_attn = FlashAttention(softmax_scale=self.scale)

def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
prompt_lens: List[int],
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[2]
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')

device = query.device
prefix_sum = [0]
for prompt_len in prompt_lens:
prefix_sum.append(prefix_sum[-1] + prompt_len)
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
max_prompt_len = max(prompt_lens)

# FIXME(woosuk): Unnecessary copy. Optimize this.
qkv = torch.stack([query, key, value], dim=1)
out = self.flash_attn(
qkv,
cu_seqlens=prefix_sum,
max_s=max_prompt_len,
# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
)[0]
# FIXME(woosuk): Unnecessary copy. Optimize this.
output.copy_(out, non_blocking=True)
return_softmax=False,
)
Comment on lines +35 to +50

Member:

Just curious, so FlashAttention natively supports non-contiguous QKV tensors?

Collaborator (Author):

Yes. It actually requires a qkv tensor of shape [num_tokens, 3, num_heads, head_size]. Previously, we inserted torch.stack to meet this shape requirement, and this PR eliminates that inefficiency.
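
To make the layout point concrete, here is a small standalone sketch (shapes are made up for illustration): slicing q, k, and v out of a fused qkv projection, as llama.py and opt.py now do with qkv.chunk, yields strided views that alias qkv's storage, which is why the torch.stack copy can be dropped.

import torch

num_tokens, num_heads, head_size = 8, 4, 32

# Output of a fused QKV projection: [num_tokens, 3 * num_heads * head_size].
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)

# chunk returns views, not copies; each view strides over the full qkv row.
q, k, v = qkv.chunk(chunks=3, dim=-1)
q = q.view(-1, num_heads, head_size)

print(q.is_contiguous())               # False: the token stride is 3 * num_heads * head_size
print(q.data_ptr() == qkv.data_ptr())  # True: q aliases the start of qkv's storage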


def single_query_cached_kv_attention(
self,
@@ -90,21 +88,18 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Pre-allocate the output tensor.
output = torch.empty_like(query)

# Prune out paddings if any.
query = query[:input_metadata.num_valid_tokens]
key = key[:input_metadata.num_valid_tokens]
value = value[:input_metadata.num_valid_tokens]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].

# Reshape the input tensors.
# Reshape the query, key, and value tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
output = output.view(-1, num_heads, head_size)

# Pre-allocate the output tensor.
output = torch.empty_like(query)

# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,22 +109,31 @@ def forward(
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.prompt_lens,
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
)

# Wait until the cache op is done.
if cache_event is not None:
cache_event.wait()

# Reshape the keys and values and store them in the cache.
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, input_metadata.slot_mapping)
num_valid_tokens = input_metadata.num_valid_tokens
if num_valid_tokens > 0:
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
value[:num_valid_tokens],
key_cache,
value_cache,
input_metadata.slot_mapping,
)

if input_metadata.num_generation_tokens > 0:
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:],
query[num_prompt_tokens:],
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
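
For orientation, the slicing above assumes the following flattened token layout (the concrete numbers below are made up): prompt tokens first, then one token per generating sequence, then padding. Each attention path writes only its own slice of the pre-allocated output, and padded positions are never touched.

import torch

num_prompt_tokens = 12                        # prompt tokens come first
num_generation_tokens = 3                     # then one token per decoding sequence
num_valid_tokens = num_prompt_tokens + num_generation_tokens
num_tokens, num_heads, head_size = 16, 4, 32  # anything past num_valid_tokens is padding

output = torch.empty(num_tokens, num_heads, head_size)
prompt_out = output[:num_prompt_tokens]               # written by multi_query_kv_attention
gen_out = output[num_prompt_tokens:num_valid_tokens]  # written by single_query_cached_kv_attention
# output[num_valid_tokens:] corresponds to padding and is never written.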
@@ -186,19 +190,15 @@ def forward(
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
out_query,
out_key,
positions,
query,
key,
self.cos_sin_cache,
)
return super().forward(
out_query,
out_key,
query,
key,
value,
key_cache,
value_cache,
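The rewrite above drops the out_query and out_key allocations: pos_encoding_ops.rotary_embedding_neox now writes the rotated values back into query and key in place. For orientation only, a rough out-of-place PyTorch reference of NeoX-style rotary embedding follows; the function name, shapes, and the separate cos/sin arguments are illustrative assumptions and do not mirror the kernel's cos_sin_cache layout.

import torch

def rotary_neox_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x:   [num_tokens, num_heads, head_size]
    # cos: [num_tokens, 1, head_size // 2], sin: likewise (per-position values)
    x1, x2 = x.chunk(2, dim=-1)  # NeoX style: rotate the two halves of the head dim
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

Rotating in place in a single kernel saves two [num_tokens, num_heads * head_size] allocations and the corresponding extra reads and writes per layer.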
5 changes: 5 additions & 0 deletions cacheflow/models/input_metadata.py
@@ -12,6 +12,7 @@ def __init__(
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
cumulative_prompt_lens: torch.Tensor,
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
@@ -20,13 +21,15 @@
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.cumulative_prompt_lens = cumulative_prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ def __repr__(self) -> str:
return (f'InputMetadata('
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'max_prompt_len={self.max_prompt_len}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len}), '
f'prompt_lens={self.prompt_lens}, '
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
19 changes: 7 additions & 12 deletions cacheflow/models/llama.py
@@ -11,6 +11,7 @@
from transformers import LlamaConfig

from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
@@ -39,16 +40,14 @@ def __init__(
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
gate, up = torch.split(gate_up, 1, dim=-2)
gate = gate.squeeze(dim=-2).contiguous()
up = up.squeeze(dim=-2).contiguous()
x = self.act_fn(gate) * up
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x

@@ -94,11 +93,7 @@ def forward(
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
7 changes: 2 additions & 5 deletions cacheflow/models/opt.py
@@ -69,17 +69,14 @@ def forward(
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output


class OPTDecoderLayer(nn.Module):

def __init__(self, config: OPTConfig):
8 changes: 8 additions & 0 deletions cacheflow/worker/worker.py
@@ -128,6 +128,11 @@ def prepare_inputs(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

cumulative_prompt_lens: List[int] = [0]
for prompt_len in prompt_lens:
cumulative_prompt_lens.append(
cumulative_prompt_lens[-1] + prompt_len)

# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
@@ -183,11 +188,14 @@ def prepare_inputs(
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=torch.int, device='cuda')
cumulative_prompt_lens_tensor = torch.tensor(
cumulative_prompt_lens, dtype=torch.int, device='cuda')

input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
prompt_lens=prompt_lens,
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
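With this change, the cumulative prompt lengths are computed once per batch on the CPU and uploaded as an int tensor, rather than being rebuilt inside every attention forward as before. A standalone sketch with made-up prompt lengths:

import torch

prompt_lens = [3, 5, 4]  # illustrative per-prompt token counts

# Exclusive prefix sum: the tokens of prompt j occupy
# [cumulative_prompt_lens[j], cumulative_prompt_lens[j + 1]) in the packed batch.
cumulative_prompt_lens = [0]
for prompt_len in prompt_lens:
    cumulative_prompt_lens.append(cumulative_prompt_lens[-1] + prompt_len)

print(cumulative_prompt_lens)  # [0, 3, 8, 12]
cu_seqlens = torch.tensor(cumulative_prompt_lens, dtype=torch.int)
# multi_query_kv_attention passes this tensor (twice, for queries and keys)
# to _flash_attn_forward, together with max(prompt_lens).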
12 changes: 12 additions & 0 deletions csrc/activation.cpp
@@ -0,0 +1,12 @@
#include <torch/extension.h>

void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
}
46 changes: 46 additions & 0 deletions csrc/activation_kernels.cu
@@ -0,0 +1,46 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
// x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [num_tokens, d]
const scalar_t* __restrict__ input, // [num_tokens, 2, d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
}
}

} // namespace cacheflow

void silu_and_mul(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, 2 * d]
{
int num_tokens = input.size(0);
int d = input.size(1) / 2;

dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(),
"silu_and_mul_kernel",
[&] {
cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
d);
});
}
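
A rough way to sanity-check the new op against plain PyTorch; this smoke-test sketch assumes the extension has been built and a CUDA device is available, and the tolerances are arbitrary:

import torch
import torch.nn.functional as F

from cacheflow.models.activation import SiluAndMul

x = torch.randn(16, 2 * 128, dtype=torch.half, device='cuda')
d = x.shape[-1] // 2
ref = F.silu(x[..., :d]) * x[..., d:]   # plain-PyTorch SwiGLU reference
out = SiluAndMul()(x)                   # fused CUDA kernel

torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)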