[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. #5365


Merged: 9 commits merged on Feb 6, 2024
Changes from all commits
29 changes: 15 additions & 14 deletions colossalai/inference/core/engine.py
@@ -115,8 +115,9 @@ def _shardformer(
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

Returns:
nn.Module: _description_
nn.Module: The model optimized by Shardformer.
"""

shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
@@ -149,25 +150,25 @@ def generate(
Returns:
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

output_seqs_list = []
output_tokens_list = []
output_seqs_list = []
output_tokens_list = []

while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()

output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

for seq in output_seqs_list:
output_tokens_list.append(seq.input_token_id + seq.output_token_id)
for seq in output_seqs_list:
output_tokens_list.append(seq.input_token_id + seq.output_token_id)

output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)

return output_str
return output_str

def add_request(
self,
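For readers following the autograd change above, here is a minimal standalone sketch (toy model and decoding loop, illustrative names only, not the actual ColossalAI engine): once generate() runs under torch.inference_mode(), no autograd graph is recorded anywhere below it, so the per-function @torch.no_grad decorators removed in the files below become unnecessary.

import torch
import torch.nn as nn

# Toy stand-ins for the engine and model; names here are illustrative only.
class ToyLM(nn.Module):
    def __init__(self, vocab=32, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, tokens):                     # [bsz, seq_len] -> [bsz, seq_len, vocab]
        return self.head(self.embed(tokens))

def generate(model, input_ids, max_new_tokens=8):
    # One context manager at the top of generation replaces scattered
    # @torch.no_grad decorators: everything below runs without autograd tracking.
    with torch.inference_mode():
        tokens = input_ids
        for _ in range(max_new_tokens):            # stands in for the request_handler / step() loop
            logits = model(tokens)
            next_token = logits[:, -1:, :].argmax(dim=-1)
            tokens = torch.cat([tokens, next_token], dim=-1)
        return tokens

out = generate(ToyLM(), torch.randint(0, 32, (1, 4)))
print(out.shape, out.requires_grad)                # torch.Size([1, 12]) False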
9 changes: 0 additions & 9 deletions colossalai/inference/modeling/layers/attention.py
@@ -6,7 +6,6 @@
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


@torch.no_grad
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
"""
Func: copy key/value into key/value cache.
@@ -41,7 +40,6 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache


@torch.no_grad
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
@@ -81,7 +79,6 @@ class PagedAttention:
"""

@staticmethod
@torch.no_grad
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
@@ -97,14 +94,12 @@ def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
return padded_tensor

@staticmethod
@torch.no_grad
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask

@staticmethod
@torch.no_grad
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
"""
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@@ -122,7 +117,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)

@staticmethod
@torch.no_grad
def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
@@ -191,7 +185,6 @@ def nopad_context_forward(
return attn_output

@staticmethod
@torch.no_grad
def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
@@ -249,7 +242,6 @@ def pad_context_forward(
return attn_output

@staticmethod
@torch.no_grad
def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
@@ -306,7 +298,6 @@ def pad_decoding_forward(
return attn_output

@staticmethod
@torch.no_grad
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]
32 changes: 14 additions & 18 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -32,7 +32,6 @@
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")


@torch.no_grad()
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
@@ -58,7 +57,6 @@ def llama_causal_lm_forward(
return logits


@torch.no_grad()
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
@@ -120,7 +118,6 @@ def llama_model_forward(
return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
@@ -139,7 +136,7 @@ def llama_decoder_layer_forward(
"""This function will replace the forward function of LlamaDecoderLayer.

Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -154,8 +151,8 @@
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
residual = hidden_states

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states, norm_output)
# Self Attention
hidden_states = self.self_attn(
@@ -240,7 +237,6 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttentio
return attn_layer

# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
@@ -258,8 +254,8 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj.
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -321,7 +317,7 @@ def forward(
sm_scale=sm_scale,
)

attn_output = attn_output.reshape(-1, self.hidden_size)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)

return attn_output
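As a side note on the hunk above, a small self-contained check (toy shapes, not the real attention tensors) of the two details it relies on: torch.addmm(residual, x, w) fuses the residual add with the output projection, and view() avoids the silent copy that reshape() may perform on non-contiguous inputs.

import torch

# Standalone check with toy shapes; the weight is assumed already transposed,
# matching how this PR stores projection weights.
token_num, hidden = 6, 8
x = torch.randn(token_num, hidden)              # flattened attention output
w = torch.randn(hidden, hidden)                 # transposed o_proj weight
residual = torch.randn(token_num, hidden)

# addmm computes "residual + x @ w" in one call.
fused = torch.addmm(residual, x, w)
print(torch.allclose(fused, residual + x @ w))  # True

# view() never copies and raises on non-contiguous input; reshape() may silently copy.
y = torch.randn(2, 3, 4)
print(y.view(-1, 4).data_ptr() == y.data_ptr()) # True: same storage, no copy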
@@ -345,9 +341,10 @@ def __init__(
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
self.gate_proj = None
self.up_proj = None

@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -371,15 +368,14 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

return mlp_layer

@torch.no_grad()
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj.
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
"""
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
tmp_out = act_out * up_proj_out
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
tmp_out = act_out * gate_up_proj_out[1]
return torch.addmm(residual, tmp_out, self.down_proj.weight)
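To illustrate the fusion this PR introduces, a minimal numerical sketch (random toy weights, assuming the pre-transposed weight layout used above): stacking the transposed gate_proj and up_proj weights into one [2, hidden, intermediate] tensor lets a single torch.bmm produce both projections at once, replacing the two separate torch.mm calls of the previous code path.

import torch
import torch.nn.functional as F

# Toy shapes and random weights; layout assumes the transposed weights stored in this PR.
token_num, hidden, inter = 5, 8, 16
x = torch.randn(token_num, hidden)
gate_w = torch.randn(hidden, inter)                        # transposed gate_proj weight
up_w = torch.randn(hidden, inter)                          # transposed up_proj weight

# Unfused reference: two matmuls, then SwiGLU-style gating.
ref = F.silu(x @ gate_w) * (x @ up_w)

# Fused path, mirroring the new NopadLlamaMLP.forward:
gate_up_weight = torch.stack([gate_w, up_w], dim=0)        # [2, hidden, inter]
x_expanded = x.expand(2, -1, -1)                           # broadcasted view, no copy
gate_up_out = torch.bmm(x_expanded, gate_up_weight)        # [2, token_num, inter]
fused = F.silu(gate_up_out[0]) * gate_up_out[1]

print(torch.allclose(ref, fused, atol=1e-5))               # True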