Commit 4f28cb4

[inference] Optimize the usage of the mid tensors space in flash attn (#5304)

* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
1 parent af8359c commit 4f28cb4
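At its core, this commit pre-allocates the attention output buffer and the flash-decoding intermediate ("mid") buffers once per batch and reuses them across decoder layers and generation steps, instead of allocating them inside each kernel call. A rough sketch of that pattern, with illustrative shapes and names that are not taken from the ColossalAI code:

import torch

max_batch_size, num_heads, head_dim, kv_max_split_num = 8, 32, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocated once, reused across decoder layers and decoding steps.
output = torch.empty((max_batch_size, 1, num_heads, head_dim), dtype=torch.float16, device=device)
mid_output = torch.empty((max_batch_size, num_heads, kv_max_split_num, head_dim), dtype=torch.float16, device=device)
mid_output_lse = torch.empty((max_batch_size, num_heads, kv_max_split_num), dtype=torch.float16, device=device)

# Each attention call then writes into the shared buffers, roughly:
# attn_out = flash_decoding_attention(q, k_cache, v_cache, ...,
#                                     output=output, mid_output=mid_output, mid_output_lse=mid_output_lse)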

File tree

16 files changed, +199 -57 lines changed


colossalai/inference/config.py

Lines changed: 10 additions & 0 deletions
@@ -55,6 +55,7 @@ class InferenceConfig:
     def __post_init__(self):
         self._init_batch_size()
         self._verify_config()
+        self._get_dtype()

     def _init_batch_size(self):
         """
@@ -84,6 +85,7 @@ def _verify_config(self) -> None:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+
         assert self.dtype in [
             "fp16",
             "fp32",
@@ -97,3 +99,11 @@ def _verify_config(self) -> None:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
+
+    def _get_dtype(self) -> None:
+        if self.dtype == "fp32" or self.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif self.dtype == "fp16" or self.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
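For reference, the mapping that _get_dtype applies can be sketched as a standalone helper (normalize_dtype is a hypothetical name; the real method rewrites self.dtype in place):

import torch

def normalize_dtype(dtype):
    # Mirrors InferenceConfig._get_dtype: accept "fp32"/"fp16" strings or torch dtypes,
    # and fall back to bfloat16 for anything else (e.g. "bf16" or torch.bfloat16).
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    if dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    return torch.bfloat16

assert normalize_dtype("fp16") is torch.float16
assert normalize_dtype(torch.bfloat16) is torch.bfloat16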

colossalai/inference/core/engine.py

Lines changed: 3 additions & 10 deletions
@@ -51,17 +51,10 @@ def __init__(
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
+        self.dtype = inference_config.dtype

         model = model.eval()
-
-        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
-            self.dtype = torch.float16
-            model.half()
-        else:
-            self.dtype = torch.bfloat16
-            model.to(torch.bfloat16)
+        model.to(self.dtype)

         if model_policy is None:
             model_policy = model_policy_map[self.model_config.model_type]()
@@ -217,6 +210,7 @@ def add_request(
                 None,
                 block_table,
                 self.tokenizer.eos_token_id,
+                self.tokenizer.pad_token_id,
                 self.inference_config.max_output_len,
             )
             self.request_handler.add_sequence(sequence)
@@ -241,7 +235,6 @@ def step(self) -> List[str]:
             batch,
             self.k_cahce,
             self.v_cache,
-            padding_id=self.tokenizer.pad_token_id,
         )

         logits = logits[:, -1, :]

colossalai/inference/core/request_handler.py

Lines changed: 46 additions & 5 deletions
@@ -4,6 +4,7 @@
 from transformers.configuration_utils import PretrainedConfig

 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
@@ -69,20 +70,60 @@ class RequestHandler:
     Args:
         inference_config: Configuration for initialize and manage kv cache.
         model_config: Configuration for model
+        dtype (torch.dtype): The data type for weights and activations.
     """

     def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.inference_config = inference_config
-        self._init_cache(model_config)
-
         self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
         self.waiting_list: List[List] = [[], [], []]
         self.done_list: List[Sequence] = []
-        device = torch.cuda.current_device()
-        self.running_batch = BatchInfo(is_prompts=False, device=device)
-        self.prefill_batch = BatchInfo(is_prompts=True, device=device)
+        self.dtype = inference_config.dtype
         self.max_batch_size = inference_config.max_batch_size

+        # initialize cache
+        self._init_cache(model_config)
+
+        # initialize batch
+        device = torch.cuda.current_device()
+        kv_max_split_num = (
+            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
+        ) // inference_config.block_size
+        head_dim = model_config.hidden_size // model_config.num_attention_heads
+
+        fd_inter_tensor = FDIntermTensors()
+        fd_inter_tensor.initialize(
+            max_batch_size=self.max_batch_size,
+            num_attn_heads=model_config.num_attention_heads,
+            kv_max_split_num=kv_max_split_num,
+            head_dim=head_dim,
+            dtype=self.dtype,
+            device=device,
+        )
+
+        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
+        # which may cause bugs and this issue should be fixed later.
+        self.running_batch = BatchInfo(
+            max_batch_size=self.max_batch_size,
+            kv_max_split_num=kv_max_split_num,
+            num_heads=model_config.num_attention_heads,
+            head_dim=head_dim,
+            is_prompts=False,
+            device=device,
+            dtype=self.dtype,
+            fd_inter_tensor=fd_inter_tensor,
+        )
+        self.prefill_batch = BatchInfo(
+            max_batch_size=self.max_batch_size,
+            kv_max_split_num=kv_max_split_num,
+            num_heads=model_config.num_attention_heads,
+            head_dim=head_dim,
+            is_prompts=True,
+            device=device,
+            dtype=self.dtype,
+            fd_inter_tensor=fd_inter_tensor,
+        )
+
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
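kv_max_split_num above is a ceiling division: the largest number of cache blocks a sequence can span, which bounds how many partial flash-decoding results must be stored per head. A quick worked example with illustrative numbers:

# Ceiling division: how many KV-cache blocks (and thus flash-decoding splits)
# the longest possible sequence can occupy. Numbers below are illustrative.
max_input_len, max_output_len, block_size = 256, 128, 16

kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size
assert kv_max_split_num == 24  # ceil(384 / 16)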

colossalai/inference/kv_cache/kvcache_manager.py

Lines changed: 1 addition & 6 deletions
@@ -58,12 +58,7 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         # Parallel settings
         self.tp_size = config.tp_size
         # Model settings
-        if config.dtype == "fp32" or config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif config.dtype == "fp16" or config.dtype == torch.float16:
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.bfloat16
+        self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         # For now we focus on MHA only, TODO add handling for MQA and GQA
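elem_size_in_bytes is taken from an empty tensor of the (now pre-normalized) dtype; for the dtypes the config accepts it evaluates as follows:

import torch

# element_size() returns bytes per element, used when sizing the KV-cache blocks.
assert torch.tensor([], dtype=torch.float16).element_size() == 2
assert torch.tensor([], dtype=torch.bfloat16).element_size() == 2
assert torch.tensor([], dtype=torch.float32).element_size() == 4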

colossalai/inference/modeling/models/llama.py

Lines changed: 64 additions & 8 deletions
@@ -4,6 +4,7 @@
 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
@@ -50,15 +51,13 @@ def llama_causal_lm_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
-        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
-    attention_mask = batch.get_attn_mask(padding_id)
+    attention_mask = batch.get_attn_mask()

     if attention_mask is not None:
         if HAS_TRITON:
@@ -84,6 +82,7 @@ def llama_model_forward(
     else:
         sequence_lengths = batch.get_sequence_lengths()

+    batch_size, _ = input_ids.shape
     kv_seq_len = sequence_lengths.max().item()

     if attention_mask is not None:
@@ -102,7 +101,22 @@ def llama_model_forward(

     hidden_states = self.embed_tokens(input_ids)

-    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
+    # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
+    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
+    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
+    # cos_sin = (cos, sin)
+
+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+
+    if batch.is_prompts:
+        output_tensor = torch.zeros(
+            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    else:
+        output_tensor = torch.zeros(
+            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    sm_scale = 1.0 / (batch.head_dim**0.5)

     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
@@ -116,6 +130,9 @@ def llama_model_forward(
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            output_tensor=output_tensor,
+            sm_scale=sm_scale,
         )

     hidden_states = self.norm(hidden_states)
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
-    sequence_lengths: int = None,
+    sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states

@@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
         cos_sin=cos_sin,
+        fd_inter_tensor=fd_inter_tensor,
+        output_tensor=output_tensor,
+        sm_scale=sm_scale,
     )

     hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

@@ -206,15 +232,35 @@ def llama_attn_forward(

         if is_prompts:
             attn_output = context_attention_unpadded(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+                q=query_states,
+                k=key_states,
+                v=value_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                context_lengths=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                output=output_tensor,
+                max_seq_len=kv_seq_len,
+                sm_scale=sm_scale,
             )
             if attention_mask is not None:
                 attn_output = pad_input(attn_output, indices, bsz, q_len)
         else:
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
-                query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+                q=query_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                kv_seq_len=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                max_seq_len_in_batch=kv_seq_len,
+                output=output_tensor,
+                mid_output=fd_inter_tensor.mid_output,
+                mid_output_lse=fd_inter_tensor.mid_output_lse,
+                sm_scale=sm_scale,
             )
             attn_output = attn_output.squeeze(1)
     else:
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_

 @torch.no_grad()
 def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    """
+    Get cos and sin for the cache, and return nopad format.
+    Args:
+        lengths: shape(num_seqs,), stores length of each sequence.
+        cos_cache: shape(max_rotary_position(e.g. 2048), head_dim), cos cache constructed in model.
+        sin_cache: shape(max_rotary_position(e.g. 2048), head_dim), sin cache constructed in model.
+        is_prompts: bool, mark if in prefill mode.
+        dtype: The data type of this inference process.
+    """
+
     if is_prompts:
         index_arrays = [torch.arange(length) for length in lengths]
     else:
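The new output_tensor in llama_model_forward is sized differently for the two phases: prefill uses an unpadded layout with one row per prompt token across the whole batch, while decode produces exactly one query token per sequence; sm_scale is the usual 1/sqrt(head_dim) softmax scale. A small sketch with illustrative numbers:

import torch

# Illustrative shapes only; batch.num_heads / batch.head_dim come from the model config.
num_heads, head_dim = 32, 128
sequence_lengths = torch.tensor([5, 9, 3])   # per-sequence prompt lengths
batch_size = sequence_lengths.numel()

# Prefill: unpadded layout, one row per token of every prompt in the batch.
prefill_out = torch.zeros((int(sequence_lengths.sum()), num_heads, head_dim), dtype=torch.float16)

# Decode: a single new token per sequence.
decode_out = torch.zeros((batch_size, 1, num_heads, head_dim), dtype=torch.float16)

# Softmax scale passed to both Triton kernels.
sm_scale = 1.0 / (head_dim ** 0.5)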
