 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
@@ -50,15 +51,13 @@ def llama_causal_lm_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
-        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
-    attention_mask = batch.get_attn_mask(padding_id)
+    attention_mask = batch.get_attn_mask()
 
     if attention_mask is not None:
         if HAS_TRITON:
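The padding_id parameter is dropped above because, for a right-padded batch, a padding mask can be recovered from the per-sequence lengths alone, so get_attn_mask() no longer needs it. A minimal sketch of that idea, using a hypothetical helper rather than BatchInfo's actual implementation:

import torch

def lengths_to_attn_mask(sequence_lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True for real tokens, False for padding; shape (batch_size, max_len).
    positions = torch.arange(max_len, device=sequence_lengths.device)
    return positions.unsqueeze(0) < sequence_lengths.unsqueeze(1)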
@@ -84,6 +82,7 @@ def llama_model_forward(
     else:
         sequence_lengths = batch.get_sequence_lengths()
 
+    batch_size, _ = input_ids.shape
     kv_seq_len = sequence_lengths.max().item()
 
     if attention_mask is not None:
@@ -102,7 +101,22 @@ def llama_model_forward(
 
     hidden_states = self.embed_tokens(input_ids)
 
-    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
+    # In testing, get_xine_cache performed worse than get_cos_sin, so the latter is used here.
+    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
+    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
+    # cos_sin = (cos, sin)
+
+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+
+    if batch.is_prompts:
+        output_tensor = torch.zeros(
+            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    else:
+        output_tensor = torch.zeros(
+            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    sm_scale = 1.0 / (batch.head_dim**0.5)
 
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
@@ -116,6 +130,9 @@ def llama_model_forward(
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            output_tensor=output_tensor,
+            sm_scale=sm_scale,
         )
 
     hidden_states = self.norm(hidden_states)
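For context, the output buffer preallocated above is shaped differently for prefill and decode, and sm_scale is the standard softmax scaling factor. A small shape sketch under assumed example values (three sequences of lengths 5, 2, and 7; 32 heads of head_dim 128; these numbers are illustrative, not from the PR):

# Prefill packs all prompt tokens with no padding:
#   output_tensor.shape == (5 + 2 + 7, 32, 128)   # (total_tokens, num_heads, head_dim)
# Decode produces exactly one new token per sequence:
#   output_tensor.shape == (3, 1, 32, 128)        # (batch_size, 1, num_heads, head_dim)
# Softmax scaling applied inside the attention kernels:
#   sm_scale == 1.0 / (128 ** 0.5)                # ~= 0.088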
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
-    sequence_lengths: int = None,
+    sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: float = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
 
@@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
         cos_sin=cos_sin,
+        fd_inter_tensor=fd_inter_tensor,
+        output_tensor=output_tensor,
+        sm_scale=sm_scale,
     )
 
     hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: float = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
@@ -206,15 +232,35 @@ def llama_attn_forward(
 
         if is_prompts:
             attn_output = context_attention_unpadded(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+                q=query_states,
+                k=key_states,
+                v=value_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                context_lengths=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                output=output_tensor,
+                max_seq_len=kv_seq_len,
+                sm_scale=sm_scale,
             )
             if attention_mask is not None:
                 attn_output = pad_input(attn_output, indices, bsz, q_len)
         else:
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
-                query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+                q=query_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                kv_seq_len=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                max_seq_len_in_batch=kv_seq_len,
+                output=output_tensor,
+                mid_output=fd_inter_tensor.mid_output,
+                mid_output_lse=fd_inter_tensor.mid_output_lse,
+                sm_scale=sm_scale,
            )
             attn_output = attn_output.squeeze(1)
     else:
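fd_inter_tensor carries the per-split intermediates (mid_output, mid_output_lse) that flash decoding writes before its final reduction. The sketch below shows that reduction in plain PyTorch for a single head of a single sequence, assuming mid_output holds each split's locally normalized output and mid_output_lse its log-sum-exp; it illustrates the algorithm, not the Triton kernel used in the PR:

import torch

def combine_splits(mid_output: torch.Tensor, mid_output_lse: torch.Tensor) -> torch.Tensor:
    # mid_output: (num_splits, head_dim); mid_output_lse: (num_splits,)
    lse_total = torch.logsumexp(mid_output_lse, dim=0)        # global log-sum-exp over all splits
    weights = torch.exp(mid_output_lse - lse_total)           # per-split weights, sum to 1
    return (weights.unsqueeze(-1) * mid_output).sum(dim=0)    # combined output, shape (head_dim,)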
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
 
 @torch.no_grad()
 def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    """
+    Get cos and sin for the cache, and return them in nopad format.
+    Args:
+        lengths: shape (num_seqs,), stores the length of each sequence.
+        cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
+        sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
+        is_prompts: bool, marks whether we are in prefill mode.
+        dtype: the data type of this inference process.
+    """
+
     if is_prompts:
         index_arrays = [torch.arange(length) for length in lengths]
     else:
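The diff is cut off inside get_cos_sin. For illustration only, a nopad cos/sin gather of this kind typically concatenates per-token position indices (all positions in prefill, the last position in decode) and indexes the caches with them; the sketch below follows the docstring's shape conventions but is not necessarily the PR's exact body:

import torch

def get_cos_sin_sketch(lengths, cos_cache, sin_cache, is_prompts, dtype):
    if is_prompts:
        # Prefill: positions 0..len-1 of every sequence, concatenated (nopad layout).
        indices = torch.cat([torch.arange(int(l), device=lengths.device) for l in lengths], dim=-1)
    else:
        # Decode: only the last position of each sequence.
        indices = lengths - 1
    # Returns (total_tokens, head_dim) per cache in prefill, (num_seqs, head_dim) in decode.
    return cos_cache[indices].to(dtype=dtype), sin_cache[indices].to(dtype=dtype)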