
Commit 249644c

[Inference] Replace Attention layer and MLP layer with shardformer to optimize the weight transpose operation; add fused_qkv and fused linear_add (#5340)
* add fused qkv
* replace attn and mlp by shardformer
* fix bugs in mlp
* add docstrings
* fix test_inference_engine.py
* add optimize unbind
* add fused_addmm
* rm squeeze(1)
* refactor codes
* fix ci bugs
* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention
* Removed the dependency on LlamaFlashAttention2
* rollback test_inference_engine.py
1 parent f8e456d commit 249644c
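The headline optimization is that the q/k/v projection weights are fused and pre-transposed once at load time, so the decode path avoids repeated per-step weight transposes. Below is a minimal sketch of the fused-QKV idea, assuming standard multi-head attention with equal q/k/v projection sizes; the helper names are illustrative, not the commit's actual code.

import torch


def fuse_qkv(q_proj: torch.nn.Linear, k_proj: torch.nn.Linear, v_proj: torch.nn.Linear) -> torch.Tensor:
    # Stack the three [out_features, hidden_size] weights and transpose once up front,
    # yielding a [3, hidden_size, out_features] buffer that is matmul-ready at decode time.
    return torch.stack([q_proj.weight, k_proj.weight, v_proj.weight], dim=0).transpose(1, 2).contiguous()


def fused_qkv_forward(hidden_states: torch.Tensor, qkv_weight: torch.Tensor):
    # hidden_states: [num_tokens, hidden_size]; one broadcasted matmul produces Q, K and V together.
    qkv = torch.matmul(hidden_states, qkv_weight)  # [3, num_tokens, out_features]
    query, key, value = torch.unbind(qkv, dim=0)   # cf. the "add optimize unbind" item above
    return query, key, value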

File tree

8 files changed: +510 −341 lines changed

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 232 additions & 72 deletions
Large diffs are not rendered by default.

colossalai/inference/modeling/models/padding_llama.py

Lines changed: 223 additions & 100 deletions
Large diffs are not rendered by default.

colossalai/inference/modeling/policy/nopadding_llama.py

Lines changed: 26 additions & 34 deletions
@@ -1,25 +1,18 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaMLP,
-    LlamaModel,
-    LlamaRMSNorm,
-    LlamaSdpaAttention,
-)
+from torch.nn import Parameter
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.modeling.models.nopadding_llama import (
-    llama_attn_forward,
+    NopadLlamaAttention,
+    NopadLlamaMLP,
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
-    nopad_mlp,
 )
 from colossalai.inference.utils import init_to_get_rotary
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

@@ -50,6 +43,27 @@ def __init__(self) -> None:
 
     def module_policy(self):
         policy = super().module_policy()
+
+        decoder_attribute_replacement = {
+            "lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
+        }
+        policy[LlamaForCausalLM] = ModulePolicyDescription(
+            attribute_replacement=decoder_attribute_replacement,
+        )
+
+        policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="mlp",
+                    target_module=NopadLlamaMLP,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="self_attn",
+                    target_module=NopadLlamaAttention,
+                ),
+            ]
+        )
+
         self.shard_config._infer()
 
         infer_forward = llama_causal_lm_forward

@@ -68,28 +82,6 @@ def module_policy(self):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )
 
-        infer_forward = nopad_mlp
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaAttention
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
-        )
-
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()
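The attribute_replacement above stores lm_head.weight already transposed to [hidden_size, vocab_size], so the forward pass can run a plain torch.mm instead of transposing the weight on every decoding step; the "fused_addmm" item in the commit message follows the same pattern for projection plus residual add. A rough sketch of what that buys, with made-up shapes and random tensors rather than the commit's actual module code:

import torch

num_tokens, hidden_size, vocab_size = 8, 4096, 32000
hidden_states = torch.randn(num_tokens, hidden_size)
residual = torch.randn(num_tokens, hidden_size)

# Weight kept pre-transposed as [hidden_size, vocab_size]: no transpose in the hot loop.
lm_head_weight_t = torch.randn(hidden_size, vocab_size)
logits = torch.mm(hidden_states, lm_head_weight_t)

# Fused linear + residual add: torch.addmm computes residual + hidden_states @ W in one call.
o_proj_weight_t = torch.randn(hidden_size, hidden_size)
attn_output = torch.addmm(residual, hidden_states, o_proj_weight_t)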

colossalai/inference/modeling/policy/padding_llama.py

Lines changed: 10 additions & 125 deletions
@@ -1,18 +1,10 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaModel,
-    LlamaRMSNorm,
-    LlamaSdpaAttention,
-)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.modeling.models.padding_llama import (
-    llama_attn_forward,
+    PadLlamaAttention,
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,

@@ -49,105 +41,16 @@ def __init__(self) -> None:
 
     def module_policy(self):
         policy = super().module_policy()
-        decoder_attribute_replacement = {
-            "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-            "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
-            // self.shard_config.tensor_parallel_size,
-        }
-        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
-            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
-
-            policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.q_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.k_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.v_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.o_proj",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.gate_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.up_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.down_proj",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                ],
-            )
 
-        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
-            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
-            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
-                ColW8A8BFP32OFP32Linear,
-                RowW8A8B8O8Linear,
-                RowW8A8BFP32O32LinearSiLU,
-                RowW8A8BFP32OFP32Linear,
-            )
+        policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn",
+                    target_module=PadLlamaAttention,
+                ),
+            ]
+        )
 
-            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.q_proj",
-                        target_module=RowW8A8B8O8Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.k_proj",
-                        target_module=RowW8A8B8O8Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.v_proj",
-                        target_module=RowW8A8B8O8Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.o_proj",
-                        target_module=ColW8A8BFP32OFP32Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.gate_proj",
-                        target_module=RowW8A8BFP32O32LinearSiLU,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.up_proj",
-                        target_module=RowW8A8BFP32OFP32Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.down_proj",
-                        target_module=ColW8A8BFP32OFP32Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                ],
-            )
         self.shard_config._infer()
 
         infer_forward = llama_causal_lm_forward

@@ -166,24 +69,6 @@ def module_policy(self):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )
 
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaAttention
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
-        )
-
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()

colossalai/kernel/triton/flash_decoding.py

Lines changed: 4 additions & 6 deletions
@@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel(
     stride_o_lset,
     stride_o_lseh,
     stride_o_lseb,
-    stride_ob,
-    stride_ol,
+    stride_ot,
     stride_oh,
     stride_od,
     BLOCK_KV: tl.constexpr,

@@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
     m_i = m_ij
 
     acc = acc / l
-    offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel
+    offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
     tl.store(O + offsets_O, acc.to(O.type.element_ty))
     return
 

@@ -212,7 +211,7 @@ def flash_decoding_attention(
            records the (kv) sequence lengths incorporating past kv sequence lengths.
        block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
        max_seq_len_in_batch (int): Maximum sequence length in the batch.
-       output (torch.Tensor): [bsz, 1, num_heads, head_dim]
+       output (torch.Tensor): [bsz, num_heads, head_dim]
        mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
            Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
        mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]

@@ -294,7 +293,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )
 
-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
 
     grid = (triton.next_power_of_2(bsz), num_heads)
 

@@ -314,7 +313,6 @@ def flash_decoding_attention(
         output.stride(0),
         output.stride(1),
         output.stride(2),
-        output.stride(3),
         BLOCK_KV=block_size,
         HEAD_DIM=head_dim,
     )
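With the trailing singleton dimension dropped, flash_decoding_attention now writes a 3-D output: decoding produces one token per sequence, so the output is [bsz, num_heads, head_dim] and the reduce kernel needs only three output strides (stride_ot, stride_oh, stride_od). A caller-side sketch of the new pre-allocation contract, with illustrative shapes and assuming a CUDA device is available:

import torch

bsz, num_heads, head_dim = 16, 32, 128

# Previously: torch.empty((bsz, 1, num_heads, head_dim), ...) followed by a squeeze(1).
# Now the buffer is allocated flat and passed straight to flash_decoding_attention.
output = torch.empty((bsz, num_heads, head_dim), dtype=torch.float16, device="cuda")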

examples/inference/run_benchmark.sh

Lines changed: 12 additions & 2 deletions
@@ -25,10 +25,20 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 # benchmark llama2-7b one single GPU
 
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
 done
 
 
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
+done
+
+
+for bsz in 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
+done
+
+
+for bsz in 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
 done

tests/test_infer_ops/triton/kernel_utils.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def torch_attn_ref(
         f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
     )
     out = out.transpose(1, 2).contiguous()
+    out = out.squeeze(1)
     return out
 
 

tests/test_infer_ops/triton/test_decoding_attn.py

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )

@@ -189,7 +189,7 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
     # the maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
