12 changes: 9 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -95,6 +95,8 @@ def llama_model_forward(
)
sm_scale = 1.0 / (batch.head_dim**0.5)

norm_output = torch.empty_like(hidden_states)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
@@ -107,13 +109,15 @@
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
)

if batch.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
hidden_states = self.norm(hidden_states)
norm_output = torch.empty_like(hidden_states)
hidden_states = self.norm(hidden_states, norm_output)

return hidden_states

@@ -131,6 +135,7 @@ def llama_decoder_layer_forward(
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
@@ -148,11 +153,12 @@
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.input_layernorm(hidden_states, norm_output)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -171,7 +177,7 @@

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
hidden_states = self.mlp(hidden_states, residual)

return hidden_states
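The change above (and its twin in padding_llama.py below) follows one pattern: allocate a single norm_output buffer in llama_model_forward with torch.empty_like, thread it through every decoder layer, and let the RMSNorm forward write into that buffer instead of allocating a fresh output tensor on each call. A minimal sketch of the buffer-reuse idea in plain PyTorch; rms_norm_into and the toy loop are illustrative stand-ins, not the ColossalAI API:

```python
import torch

def rms_norm_into(x: torch.Tensor, weight: torch.Tensor, eps: float, out: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the Triton RMSNorm: write the result into a
    # caller-provided buffer instead of allocating a new tensor per call.
    variance = x.pow(2).mean(-1, keepdim=True)
    torch.mul(x * torch.rsqrt(variance + eps), weight, out=out)
    return out

# Toy forward pass: the scratch buffer is created once and reused by every layer.
hidden_states = torch.randn(8, 4096)           # (total_tokens, hidden_size)
weight = torch.ones(4096)
norm_output = torch.empty_like(hidden_states)  # allocated once, outside the layer loop

for _ in range(32):                            # stands in for iterating over self.layers
    normed = rms_norm_into(hidden_states, weight, 1e-6, norm_output)
    hidden_states = hidden_states + normed     # stands in for attention/MLP plus residual
```

In the diff itself the buffer is re-created once more on the prompt path, because only the last token of each sequence is kept before the final self.norm call and the buffer shape must match that smaller tensor.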
12 changes: 9 additions & 3 deletions colossalai/inference/modeling/models/padding_llama.py
@@ -135,6 +135,8 @@ def llama_model_forward(
)
sm_scale = 1.0 / (batch.head_dim**0.5)

norm_output = torch.empty_like(hidden_states)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
@@ -149,12 +151,14 @@
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
)

if batch.is_prompts:
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
hidden_states = self.norm(hidden_states)
norm_output = torch.empty_like(hidden_states)
hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)

return hidden_states

@@ -174,6 +178,7 @@ def llama_decoder_layer_forward(
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
@@ -191,11 +196,12 @@
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -217,7 +223,7 @@

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

4 changes: 2 additions & 2 deletions colossalai/inference/modeling/policy/nopadding_llama.py
@@ -29,8 +29,8 @@
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)

return _triton_rmsnorm_forward
else:
4 changes: 2 additions & 2 deletions colossalai/inference/modeling/policy/padding_llama.py
@@ -27,8 +27,8 @@
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)

return _triton_rmsnorm_forward
else:
10 changes: 4 additions & 6 deletions colossalai/kernel/triton/rms_layernorm.py
@@ -50,12 +50,10 @@ def _rmsnorm_kernel(
tl.store(Y + cols, y.to(tl.float16), mask=mask)

@torch.no_grad()
def rms_layernorm(x, weight, eps):
def rms_layernorm(x, weight, eps, norm_output=None):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor, (total token, hidden_size)
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
y = torch.empty_like(x) if norm_output is None else norm_output
M, N = x.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()

@@ -67,5 +65,5 @@ def rms_layernorm(x, weight, eps):
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)

# enqueue kernel
_rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
_rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
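After this change rms_layernorm no longer reshapes internally, so the caller is expected to hand it a 2-D (total_tokens, hidden_size) tensor, and it may pass a preallocated norm_output buffer to skip the torch.empty_like allocation. A hedged usage sketch, assuming rms_layernorm is re-exported from colossalai.kernel.triton and that fp16 CUDA tensors are used (the kernel stores with tl.float16):

```python
import torch
from colossalai.kernel.triton import rms_layernorm  # assumed re-export of this module's function

hidden_size = 4096
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")  # (total_tokens, hidden_size)
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")  # RMSNorm weight
norm_output = torch.empty_like(x)  # reused across calls to avoid a per-call allocation

y_alloc = rms_layernorm(x, weight, eps=1e-6)                           # old behaviour: allocates its own output
y_reuse = rms_layernorm(x, weight, eps=1e-6, norm_output=norm_output)  # new behaviour: writes into the buffer
assert y_reuse.data_ptr() == norm_output.data_ptr()
```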
3 changes: 2 additions & 1 deletion examples/inference/benchmark_llama.py
@@ -9,7 +9,8 @@

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.inference import InferenceEngine
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

GIGABYTE = 1024**3
24 changes: 6 additions & 18 deletions examples/inference/run_benchmark.sh
@@ -23,22 +23,10 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1

# benchmark llama2-7b one single GPU

for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
done


for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
done


for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
done


for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
for input_len in 128 512 1024; do
for output_len in 128 256; do
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
done
done
done