[Inference] Adapt to Fused rotary #5348

Merged: 14 commits, Feb 7, 2024

5 changes: 2 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -282,11 +282,10 @@ def forward(
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
 
-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-
         block_size = k_cache.size(-2)
 
         if is_prompts:
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -301,7 +300,7 @@ def forward(
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], k_cache, block_tables, sequence_lengths)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
                 q=query_states,
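
The net effect in the model code: prefill (is_prompts) keeps the standalone rotary call before context attention, while the decode branch now hands k_cache, block_tables, and sequence_lengths to rotary_embedding, which rotates q/k and writes the rotated keys into the blocked KV cache in one kernel, so only value_states still needs an explicit copy_kv_to_blocked_cache. The snippet below is a minimal PyTorch sketch of that fused decode-step semantics, not the PR's Triton kernel; fused_rotary_and_cache_ref is a made-up name, and the shapes follow the layouts used in this PR (q/k: [num_seqs, num_kv_heads, head_dim], k_cache: [num_blocks, num_kv_heads, block_size, head_dim], cos/sin: [num_seqs, head_dim // 2]).

import torch

def fused_rotary_and_cache_ref(q, k, cos, sin, k_cache, block_tables, kv_lengths):
    # Rotate q/k in place, using the same first-half/second-half convention as the Triton kernels.
    half = q.size(-1) // 2
    c, s = cos[:, None, :], sin[:, None, :]  # broadcast over the head dimension
    q0, q1 = q[..., :half], q[..., half:]
    k0, k1 = k[..., :half], k[..., half:]
    q_out = torch.cat((q0 * c - q1 * s, q0 * s + q1 * c), dim=-1)
    k_out = torch.cat((k0 * c - k1 * s, k0 * s + k1 * c), dim=-1)
    q.copy_(q_out)
    k.copy_(k_out)
    # ...and, in the same step, write each sequence's newest rotated key into its last cache block.
    block_size = k_cache.size(-2)
    for seq in range(q.size(0)):
        pos = int(kv_lengths[seq]) - 1                      # slot of the token being decoded
        block_id = block_tables[seq, pos // block_size]
        k_cache[block_id, :, pos % block_size, :] = k_out[seq]
    return q, k

The Python loop stands in for the kernel's per-(head, token) programs and is for illustration only.
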
1 change: 0 additions & 1 deletion colossalai/kernel/triton/kvcache_copy.py
@@ -75,7 +75,6 @@ def copy_kv_to_blocked_cache(
     block_size = k_cache.size(-2)
 
     num_warps = 8 if head_dim > 128 else 4
-
     grid = (bsz, num_kv_heads)
     _copy_to_kvcache_seqlen1_kernel[grid](
         k,
136 changes: 128 additions & 8 deletions colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
     out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
     out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]  # total_tokens, head_num, head_dim
 
-    past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
+    past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1
 
     last_block_idx = past_kv_seq_len // block_size
     block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
-    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
+    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
     offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
 
     kv_range0 = (
@@ -274,6 +274,122 @@ def fused_rotary_embedding_kernel(
     )
 
+
+@triton.jit
+def fused_rotary_embedding_kernel_v2(
+    q,
+    k,
+    cos,
+    sin,
+    kv_cache,
+    BLOCK_TABLES,
+    context_lengths,
+    q_token_stride,
+    q_head_stride,
+    k_token_stride,
+    k_head_stride,
+    head_dim_stride,
+    cos_token_stride,
+    cos_stride,
+    cacheb_stride,
+    cacheh_stride,
+    cachebs_stride,
+    cached_stride,
+    bts_stride,
+    btb_stride,
+    block_size,
+    q_total_tokens,
+    Q_HEAD_NUM: tl.constexpr,
+    K_HEAD_NUM: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+):
+    block_head_index = tl.program_id(0)
+    if block_head_index >= Q_HEAD_NUM:
+        return
+    block_token_index = tl.program_id(1)
+
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
+    off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
+    off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
+    off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride
+
+    loaded_q0 = tl.load(
+        q + off_q0,
+    )
+    loaded_q1 = tl.load(
+        q + off_q1,
+    )
+
+    loaded_k0 = tl.load(
+        k + off_k0,
+    )
+
+    loaded_k1 = tl.load(
+        k + off_k1,
+    )
+
+    off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
+
+    loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
+    loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
+
+    out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
+    out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
+
+    out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
+    out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos  # total_tokens, head_num, head_dim
+
+    past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
+
+    last_block_idx = past_kv_seq_len // block_size
+    block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
+    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
+    offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
+
+    kv_range0 = (
+        block_ids * cacheb_stride
+        + block_head_index * cacheh_stride
+        + offsets_in_last_block
+        + dim_range0 * cached_stride
+    )
+    kv_range1 = (
+        block_ids * cacheb_stride
+        + block_head_index * cacheh_stride
+        + offsets_in_last_block
+        + dim_range1 * cached_stride
+    )
+
+    tl.store(
+        kv_cache + kv_range0,
+        out_k0,
+    )
+    tl.store(
+        kv_cache + kv_range1,
+        out_k1,
+    )
+
+    # concat
+    tl.store(
+        q + off_q0,
+        out_q0,
+    )
+    tl.store(
+        q + off_q1,
+        out_q1,
+    )
+    tl.store(
+        k + off_k0,
+        out_k0,
+    )
+    tl.store(
+        k + off_k1,
+        out_k1,
+    )
+
+
 @torch.no_grad()
 def rotary_embedding(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -297,12 +413,13 @@ def rotary_embedding(
     assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
     BLOCK_TOKENS = 4
-    grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))
 
-    if head_dim >= 256:
+    if head_dim >= 1024:
         num_warps = 32
-    elif head_dim >= 128:
+    elif head_dim >= 512:
         num_warps = 16
+    elif head_dim >= 256:
+        num_warps = 8
     else:
         num_warps = 4
 
@@ -318,6 +435,10 @@ def rotary_embedding(
     cos_token_stride = cos.stride(0)
     cos_stride = cos.stride(1)
     if k_cache == None:
+        grid = lambda META: (
+            triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
+            triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
+        )
         rotary_embedding_kernel[grid](
             q,
             k,
@@ -339,7 +460,8 @@ def rotary_embedding(
             num_warps=num_warps,
         )
     else:
-        fused_rotary_embedding_kernel[grid](
+        grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
+        fused_rotary_embedding_kernel_v2[grid](
             q,
             k,
             cos,
@@ -365,8 +487,6 @@ def rotary_embedding(
             Q_HEAD_NUM=q_head_num,
             K_HEAD_NUM=k_head_num,
             HEAD_DIM=head_dim,
-            BLOCK_HEAD=BLOCK_HEAD,
-            BLOCK_TOKENS=BLOCK_TOKENS,
             num_warps=num_warps,
         )
     return
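
The wrapper now picks its kernel and grid by whether a blocked k_cache is passed: with no cache it keeps the tiled rotary_embedding_kernel on a (cdiv(q_head_num, BLOCK_HEAD), cdiv(q_total_tokens, BLOCK_TOKENS)) grid; with a cache it launches fused_rotary_embedding_kernel_v2 with one program per (head, token) pair on a (next_power_of_2(q_head_num), q_total_tokens) grid, and the early return at the top of that kernel discards the programs introduced by rounding the head count up to a power of two. A small helper mirroring that sizing logic (illustrative only, not part of the PR; rotary_launch_grid is a made-up name):

import triton

def rotary_launch_grid(q_head_num: int, q_total_tokens: int, fused: bool,
                       block_head: int = 4, block_tokens: int = 4):
    """Illustrative only: mirrors how the wrapper above sizes its launch grid."""
    if not fused:  # k_cache is None -> tiled kernel over head/token blocks
        return (triton.cdiv(q_head_num, block_head), triton.cdiv(q_total_tokens, block_tokens))
    # Fused kernel: one program per (head, token); heads are padded to a power of two,
    # and the kernel's early-return guard drops the padded head programs.
    return (triton.next_power_of_2(q_head_num), q_total_tokens)

print(rotary_launch_grid(32, 16, fused=False))  # (8, 4)
print(rotary_launch_grid(32, 16, fused=True))   # (32, 16)

Padding the head dimension of the grid keeps the launch shape simple for any head count; for example, 40 query heads would launch 64 head programs, 24 of which exit at the guard.
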
1 change: 1 addition & 0 deletions examples/inference/run_benchmark.sh
@@ -1,4 +1,5 @@
 ROOT=$(realpath $(dirname $0))
+echo $ROOT
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
 mode=$1
40 changes: 30 additions & 10 deletions tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
@@ -3,7 +3,7 @@
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
-from colossalai.kernel.triton import rotary_embedding
+from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding
 from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
 
 try:
@@ -94,8 +94,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
         x_names=["num_tokens"],
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
-        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
-        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
+        line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
@@ -110,23 +110,43 @@ def benchmark_rotary_emb(
     num_tokens: int,
     num_kv_heads: int,
 ):
+    BATCH_SIZE = 4
+    SEQ_LEN = num_tokens // BATCH_SIZE
+    max_num_blocks_per_seq = 8
+    block_size = 64
     warmup = 10
     rep = 100
 
-    head_dim = 128
+    head_dim = 256
     dtype = torch.float16
 
     q_shape = (num_tokens, num_kv_heads, head_dim)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (num_tokens, num_kv_heads, head_dim)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
     cos_shape = (num_tokens, head_dim // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    v = torch.randn_like(k)
+    v_cache = torch.zeros_like(k_cache)
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
+    new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")
+
-    if provider == "torch_rotary_emb_func":
-        fn = lambda: torch_rotary_emb(q, cos, sin)
-    elif provider == "triton_rotary_emb_func":
-        fn = lambda: rotary_embedding(q, k, cos, sin)
+    if provider == "no_fused_rotary_emb_func":
+        fn = lambda: [
+            rotary_embedding(new_q, new_k, cos, sin),
+            copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables),
+        ]
+    elif provider == "fused_triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
     else:
         raise ValueError("Undefined provider")
 
@@ -135,5 +155,5 @@ def benchmark_rotary_emb(
 
 
 if __name__ == "__main__":
-    test_rotary_emb(4, 64, 32, 64, torch.float32)
-    # benchmark_rotary_emb.run(save_path=".",print_data=True)
+    # test_rotary_emb(4, 64, 32, 64, torch.float32)
+    benchmark_rotary_emb.run(save_path=".", print_data=True)
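
The two benchmark providers compare the pre-existing two-step decode path (rotate q/k, then copy the rotated key of each sequence's newest token into the blocked cache) against the single fused call. As an illustrative sanity check, not a test shipped in this PR, both paths should leave the same rotated queries and the same keys in the cache; the shapes below loosely mirror the benchmark setup and require a CUDA device with ColossalAI installed.

import torch

from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding

bsz, num_kv_heads, head_dim, block_size, max_blocks = 4, 8, 128, 64, 8
dtype, device = torch.float16, "cuda"

new_q = torch.randn(bsz, num_kv_heads, head_dim, dtype=dtype, device=device)
new_k = torch.randn_like(new_q)
cos = torch.randn(bsz, head_dim // 2, dtype=dtype, device=device)
sin = torch.randn_like(cos)
kv_seq_lengths = torch.full((bsz,), 33, dtype=torch.int32, device=device)  # one new token per sequence
block_tables = torch.arange(bsz * max_blocks, dtype=torch.int32, device=device).view(bsz, max_blocks)

k_cache_unfused = torch.zeros(bsz * max_blocks, num_kv_heads, block_size, head_dim, dtype=dtype, device=device)
k_cache_fused = torch.zeros_like(k_cache_unfused)

# Two-step path: rotate in place, then copy the rotated keys into the blocked cache.
q1, k1 = new_q.clone(), new_k.clone()
rotary_embedding(q1, k1, cos, sin)
copy_kv_to_blocked_cache(k1, k_cache_unfused, kv_lengths=kv_seq_lengths, block_tables=block_tables)

# Fused path: one call rotates q/k and writes the rotated keys into the cache.
q2, k2 = new_q.clone(), new_k.clone()
rotary_embedding(q2, k2, cos, sin, k_cache_fused, block_tables, kv_seq_lengths)

assert torch.allclose(q1, q2) and torch.allclose(k_cache_unfused, k_cache_fused)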