@@ -143,7 +143,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -155,8 +155,10 @@ def forward(
             query: shape = [batch_size, seq_len, num_heads * head_size]
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
             value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            key_cache = [num_kv_heads, num_blocks, block_size, head_size]
-            value_cache = [num_kv_heads, num_blocks, block_size, head_size]
+            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
+            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
+            NOTE: kv_cache[0] and kv_cache[1] will be empty tensors
+            with shape [0] for a profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
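
A minimal sketch (not the vLLM code itself) of the kv_cache convention the new signature assumes: during a profiling run both cache entries are empty placeholders with shape [0], while during a real run they carry the block layout documented above. The shape values are illustrative only.

```python
import torch

num_kv_heads, num_blocks, block_size, head_size = 8, 16, 16, 128

# Profiling run: empty tensors stand in for the not-yet-allocated KV cache.
profiling_kv_cache = (torch.empty(0), torch.empty(0))

# Normal run: key/value caches with the documented layout.
kv_cache = (
    torch.zeros(num_kv_heads, num_blocks, block_size, head_size),
    torch.zeros(num_kv_heads, num_blocks, block_size, head_size),
)

# numel() distinguishes the two cases without typing the cache as Optional.
assert profiling_kv_cache[0].numel() == 0
assert kv_cache[0].numel() > 0
```
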
@@ -173,7 +175,7 @@ def forward(
         value = value.view(batch_size, seq_len, self.num_kv_heads,
                            self.head_size)
 
-        if kv_cache[0] is not None:
+        if kv_cache[0].numel() > 0:
             slot_mapping = attn_metadata.slot_mapping
             key_cache, value_cache = kv_cache
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
@@ -205,7 +207,7 @@ def forward(
             output = output.permute(0, 2, 1, 3)
         else:
             # Decoding run.
-            assert kv_cache is not None
+            assert kv_cache[0].numel() > 0
 
             pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
             if self.megacore_mode == "batch" and batch_size % 2 != 0:
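
A small sketch of the gating pattern the diff switches to: the cache entries are always tensors, and `numel() > 0` decides whether they are populated, replacing the `Optional`/`is not None` check. `maybe_write_kv` and its arguments are hypothetical stand-ins for illustration, not vLLM functions.

```python
from typing import Tuple

import torch


def maybe_write_kv(key: torch.Tensor, value: torch.Tensor,
                   kv_cache: Tuple[torch.Tensor, torch.Tensor]) -> bool:
    """Write into the cache only when it is populated; skip on profiling runs."""
    if kv_cache[0].numel() == 0:
        # Profiling run: the cache is an empty placeholder, nothing to write.
        return False
    key_cache, value_cache = kv_cache
    # A real implementation would scatter `key`/`value` into the cache blocks
    # using the slot mapping; here we only show the guard.
    return True


# Usage: an empty cache tuple (profiling run) skips the write.
wrote = maybe_write_kv(torch.zeros(1), torch.zeros(1),
                       (torch.empty(0), torch.empty(0)))
assert wrote is False
```
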