64 changes: 32 additions & 32 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -2,7 +2,6 @@
from typing import List, Optional, Tuple

import torch
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -82,19 +81,21 @@ def llama_model_forward(

if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
(batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)

norm_output = torch.empty_like(hidden_states)
residual = None

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states, residual = decoder_layer(
hidden_states,
residual=residual,
block_tables=block_tables,
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
@@ -111,15 +112,17 @@ def llama_model_forward(
if batch.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
residual = residual[last_token_indexs - 1].contiguous()
norm_output = torch.empty_like(hidden_states)
hidden_states = self.norm(hidden_states, norm_output)
hidden_states, _ = self.norm(hidden_states, norm_output, residual)

return hidden_states
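For context on the prefill epilogue above: with unpadded (packed) sequences, the cumulative sum of `sequence_lengths` minus one indexes each sequence's final token, and the same gather is now applied to `residual` so the final fused norm sees matching shapes. A minimal standalone sketch with toy values (plain PyTorch, names illustrative):

```python
import torch

# Toy illustration of the last-token gather used in the prefill branch above.
sequence_lengths = torch.tensor([3, 5, 2])                        # three packed sequences
hidden_states = torch.randn(int(sequence_lengths.sum()), 8)       # [10 tokens, embed_dim=8]
last_token_indexs = sequence_lengths.cumsum(dim=-1)               # tensor([3, 8, 10])
last_hidden = hidden_states[last_token_indexs - 1].contiguous()   # rows 2, 7, 9 -> [3, 8]
```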


def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
@@ -136,6 +139,7 @@ def llama_decoder_layer_forward(

Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], the residual carried across sub-layers; it is folded into hidden_states inside the fused RMSNorm.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -151,12 +155,10 @@ def llama_decoder_layer_forward(
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states, norm_output)
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
residual=residual,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
Expand All @@ -170,11 +172,10 @@ def llama_decoder_layer_forward(
)

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
hidden_states = self.mlp(hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
hidden_states = self.mlp(hidden_states)

return hidden_states
return hidden_states, residual
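The diff above re-threads the residual through the layer: each RMSNorm call now performs the residual add itself and returns both the normed activations and the updated residual, so the attention and MLP projections end in a plain `torch.mm` instead of `torch.addmm`. A compact sketch of the resulting data flow, using a plain-PyTorch stand-in for the fused norm (function name, eps, and toy sizes are illustrative, not the ColossalAI API):

```python
import torch

def fused_add_rmsnorm(x, weight, residual=None, eps=1e-6):
    # Stand-in for the fused Triton RMSNorm: add the residual first, normalize,
    # and return the pre-norm sum as the residual for the next call.
    if residual is not None:
        x = x + residual
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight, x

tokens, dim = 4, 8
weight, attn_w, mlp_w = torch.ones(dim), torch.randn(dim, dim), torch.randn(dim, dim)
hidden_states, residual = torch.randn(tokens, dim), None

hidden_states, residual = fused_add_rmsnorm(hidden_states, weight, residual)  # input_layernorm
hidden_states = torch.mm(hidden_states, attn_w)                               # "attention", no addmm
hidden_states, residual = fused_add_rmsnorm(hidden_states, weight, residual)  # post_attention_layernorm
hidden_states = torch.mm(hidden_states, mlp_w)                                # "mlp", no addmm
```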


class NopadLlamaAttention(LlamaAttention):
@@ -198,16 +199,18 @@ def __init__(
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
self.q_proj_weight = attn_qproj_w
self.k_proj_weight = attn_kproj_w
self.v_proj_weight = attn_vproj_w
self.o_proj_weight = attn_oproj_w

if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.q_proj = None
self.k_proj = None
self.v_proj = None

self.q_proj = None
self.k_proj = None
self.v_proj = None
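With equal numbers of query and key/value heads, the three projection matrices are stacked into a single `[3, hidden_size, hidden_size]` tensor. The fused projection itself falls outside this hunk, so the snippet below is only an assumption about how such a stacked weight is typically consumed (one batched matmul instead of three separate `torch.mm` calls), with random toy weights:

```python
import torch

hidden_size, num_heads, tokens = 64, 8, 4
head_dim = hidden_size // num_heads
q_w, k_w, v_w = (torch.randn(hidden_size, hidden_size) for _ in range(3))  # transposed proj weights
qkv_weight = torch.stack([q_w, k_w, v_w], dim=0)                           # [3, hidden, hidden]

x = torch.randn(tokens, hidden_size)
qkv = torch.bmm(x.expand(3, -1, -1), qkv_weight)                           # [3, tokens, hidden]
q, k, v = qkv.view(3, tokens, num_heads, head_dim).unbind(0)
```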

@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
@@ -239,7 +242,6 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
@@ -254,7 +256,6 @@ def forward(
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -270,9 +271,9 @@ def forward(
"""

if self.num_heads != self.num_key_value_heads:
query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
else:
# fused qkv
token_nums = hidden_states.size(0)
@@ -317,8 +318,7 @@ def forward(
sm_scale=sm_scale,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
attn_output = torch.mm(attn_output, self.o_proj_weight)

return attn_output
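Because both attention kernels now write a 2D `[num_tokens, num_heads * head_dim]` buffer (see the kernel changes further down), the `.view(-1, hidden_size)` before the output projection is gone, and the residual add moves into the next fused norm. A toy illustration of the new shape contract (random weights, illustrative sizes):

```python
import torch

num_tokens, num_heads, head_dim = 4, 8, 16
hidden_size = num_heads * head_dim
attn_output = torch.randn(num_tokens, hidden_size)     # already 2D, no .view needed
o_proj_weight = torch.randn(hidden_size, hidden_size)  # transposed o_proj weight
out = torch.mm(attn_output, o_proj_weight)             # residual is added later, in the fused RMSNorm
```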

@@ -341,10 +341,11 @@ def __init__(
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w
self.gate_proj = None
self.up_proj = None
self.down_proj = None

@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -368,14 +369,13 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

return mlp_layer

def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
tmp_out = act_out * gate_up_proj_out[1]
return torch.addmm(residual, tmp_out, self.down_proj.weight)
return torch.mm(tmp_out, self.down_proj_weight)
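The MLP keeps the same fused SwiGLU trick, now on plain tensors and without the trailing `addmm`: gate and up weights are stacked into one `[2, hidden_size, intermediate_size]` tensor and applied with a single `bmm`. A standalone toy version (random weights, small sizes):

```python
import torch
import torch.nn.functional as F

tokens, hidden_size, intermediate_size = 4, 16, 32
gate_up_weight = torch.stack(
    [torch.randn(hidden_size, intermediate_size), torch.randn(hidden_size, intermediate_size)], dim=0
)                                                           # [2, hidden, intermediate]
down_proj_weight = torch.randn(intermediate_size, hidden_size)

x = torch.randn(tokens, hidden_size)
gate_up = torch.bmm(x.expand(2, -1, -1), gate_up_weight)    # [2, tokens, intermediate]
out = torch.mm(F.silu(gate_up[0]) * gate_up[1], down_proj_weight)  # [tokens, hidden], no residual here
```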
6 changes: 4 additions & 2 deletions colossalai/inference/modeling/policy/nopadding_llama.py
@@ -29,8 +29,10 @@
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
def _triton_rmsnorm_forward(
self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)

return _triton_rmsnorm_forward
else:
10 changes: 6 additions & 4 deletions colossalai/kernel/triton/context_attn_unpad.py
@@ -205,15 +205,17 @@ def context_attention_unpadded(
assert k_cache.shape == v_cache.shape
assert context_lengths.shape[0] == block_tables.shape[0]

num_tokens, num_heads, _ = q.shape
num_tokens, num_heads, head_dim = q.shape
num_kv_heads = k.shape[-2]
assert num_kv_heads > 0 and num_heads % num_kv_heads == 0
num_kv_group = num_heads // num_kv_heads

num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
output = torch.zeros_like(q) if output is None else output
output = (
torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
)

# NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
# the size of physical cache block (i.e. `block_size`)
@@ -243,8 +245,8 @@ def forward(
v.stride(1),
v.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
head_dim,
1,
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
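The unpadded context-attention kernel now writes directly into a 2D `[num_tokens, num_heads * head_dim]` buffer, so the per-head output strides passed to the kernel become `head_dim` and `1`. Callers that still want the per-head layout (for example the updated test further down) can reshape without a copy; a small sketch:

```python
import torch

num_tokens, num_heads, head_dim = 4, 8, 32
output = torch.empty(num_tokens, num_heads * head_dim, dtype=torch.float16)
# Strides handed to the kernel for the head and head_dim axes:
assert output.view(num_tokens, num_heads, head_dim).stride()[1:] == (head_dim, 1)
out_3d = output.view(num_tokens, num_heads, head_dim)  # free reshape back to [tokens, heads, head_dim]
```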
12 changes: 6 additions & 6 deletions colossalai/kernel/triton/flash_decoding.py
@@ -211,7 +211,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, num_heads, head_dim]
output (torch.Tensor): [bsz, num_heads * head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
Expand All @@ -220,7 +220,7 @@ def flash_decoding_attention(
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.

Returns:
Output tensor with shape [bsz, num_heads, head_dim]
Output tensor with shape [bsz, num_heads * head_dim]
"""
q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
@@ -261,7 +261,7 @@ def flash_decoding_attention(
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output

_flash_decoding_fwd_kernel[grid](
q,
@@ -294,7 +294,7 @@ def flash_decoding_attention(
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)

grid = (triton.next_power_of_2(bsz), num_heads)

_flash_decoding_fwd_reduce_kernel[grid](
@@ -311,8 +311,8 @@ def forward(
mid_output_lse.stride(1),
mid_output_lse.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
head_dim,
1,
BLOCK_KV=block_size,
HEAD_DIM=head_dim,
)
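The decode path gets the same flattened `[bsz, num_heads * head_dim]` output, and the reduce kernel now re-derives its own 2D launch grid instead of reusing the 3D grid of the split-KV pass. A small sketch of the two grids (the Triton helpers below are pure Python, so no GPU is needed to check the shapes; the toy sizes are illustrative):

```python
import triton

bsz, num_heads, max_seq_len_in_batch, BLOCK_KV = 3, 8, 100, 16
fwd_grid = (triton.next_power_of_2(bsz), num_heads,
            triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))  # split-KV pass
reduce_grid = (triton.next_power_of_2(bsz), num_heads)                            # reduction pass
print(fwd_grid, reduce_grid)   # (4, 8, 8) (4, 8)
```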
54 changes: 51 additions & 3 deletions colossalai/kernel/triton/rms_layernorm.py
@@ -49,7 +49,50 @@ def _rmsnorm_kernel(
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)

def rms_layernorm(x, weight, eps, norm_output=None):
@triton.jit
def _rmsnorm_with_residual_kernel(
X, # pointer to the input
Y, # pointer to the output
R, # pointer to the residual
W, # pointer to the weights
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).

# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
Y += row * stride
X += row * stride
R += row * stride
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x, 0.0)
r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32)
r = tl.where(cols < N, r, 0.0)
x = x + r
_var += x * x
mask = cols < N
tl.store(X + cols, x.to(tl.float16), mask=mask)
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# Normalize and apply linear transformation
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)

def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
# allocate output
y = torch.empty_like(x) if norm_output is None else norm_output
M, N = x.shape
@@ -64,5 +107,10 @@ def rms_layernorm(x, weight, eps, norm_output=None):
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)

# enqueue kernel
_rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
if residual is None:
_rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
else:
_rmsnorm_with_residual_kernel[(M,)](
x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y, x
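A small usage sketch of the extended entry point (requires a CUDA device with Triton; dtype chosen to match the kernel's fp16 stores, and the import path is an assumption). Note the new return convention: the first element is the normed output, the second is the input buffer, which the kernel has updated in place to `x + residual` when a residual is passed, so it can serve as the residual for the next layer.

```python
import torch
from colossalai.kernel.triton import rms_layernorm  # assumed import path; adjust to your tree

M, N = 4, 128
x = torch.randn(M, N, dtype=torch.float16, device="cuda")
residual = torch.randn(M, N, dtype=torch.float16, device="cuda")
weight = torch.ones(N, dtype=torch.float16, device="cuda")

y, new_residual = rms_layernorm(x, weight, eps=1e-5, residual=residual)
# new_residual aliases x, now holding x + residual; y is RMSNorm(x + residual) * weight.
xr = new_residual.float()
ref = xr * torch.rsqrt(xr.pow(2).mean(-1, keepdim=True) + 1e-5) * weight.float()
assert torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2)
```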
3 changes: 2 additions & 1 deletion examples/inference/benchmark_llama.py
@@ -95,7 +95,7 @@ def benchmark_inference(args):
else:
assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

model = model.eval()

@@ -122,6 +122,7 @@ def benchmark_inference(args):
elif args.mode == "vllm":
engine = LLM(
model=args.model_path,
tokenizer="hf-internal-testing/llama-tokenizer",
max_num_seqs=mbsz,
dtype="float16",
enforce_eager=True,
4 changes: 4 additions & 0 deletions tests/test_infer/test_ops/triton/test_context_attn_unpad.py
@@ -100,10 +100,14 @@ def test_context_attention(
k_cache_triton = torch.zeros_like(k_cache_ref)
v_cache_triton = torch.zeros_like(v_cache_ref)

_, num_heads, head_dim = q_unpad.shape

out_triton = context_attention_unpadded(
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
)

out_triton = out_triton.view(-1, num_heads, head_dim)

out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)

assert out_torch.shape == out_triton.shape