
Commit 873de5f

KV cache implementation for using llama models for text generation. (Comfy-Org#12195)
1 parent aa6f7a8 commit 873de5f
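
This commit threads an optional KV cache through the attention stack in llama.py: each attention layer accepts a past_key_value tuple of (past_key, past_value, index) and returns its updated cache alongside its output, Llama2_.forward gains a past_key_values argument and a third return value carrying the per-layer caches, position ids and the causal mask are offset by the cached length, and every config gains an lm_head flag so the output projection can be instantiated for generation.

A minimal sketch (not part of the commit) of how a caller might drive the new plumbing for greedy decoding; generate is a hypothetical helper, model is assumed to be a Llama2_ built with lm_head = True, and tokens a [1, seq] LongTensor of prompt ids:

import torch

@torch.no_grad()
def generate(model, tokens, max_new_tokens=32):
    # Prefill: an empty list (rather than None) asks every layer to start a cache,
    # which makes forward return a third value of per-layer (key, value, index) tuples.
    x, _, past_key_values = model(tokens, past_key_values=[])
    for _ in range(max_new_tokens):
        logits = model.lm_head(x[:, -1:])     # [1, 1, vocab]; requires config.lm_head = True
        next_token = logits.argmax(dim=-1)    # greedy pick; sampling would slot in here
        tokens = torch.cat((tokens, next_token), dim=1)
        # Decode step: feed only the new token; position_ids and the causal mask
        # are derived from the cached index inside forward.
        x, _, past_key_values = model(next_token, past_key_values=past_key_values)
    return tokens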

1 file changed: +74 -17 lines changed

comfy/text_encoders/llama.py

Lines changed: 74 additions & 17 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from dataclasses import dataclass
-from typing import Optional, Any
+from typing import Optional, Any, Tuple
 import math
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
@@ -32,6 +32,7 @@ class Llama2Config:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Mistral3Small24BConfig:
@@ -54,6 +55,7 @@ class Mistral3Small24BConfig:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Qwen25_3BConfig:
@@ -76,6 +78,7 @@ class Qwen25_3BConfig:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Qwen3_06BConfig:
@@ -98,6 +101,7 @@ class Qwen3_06BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Qwen3_4BConfig:
@@ -120,6 +124,7 @@ class Qwen3_4BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Qwen3_8BConfig:
@@ -142,6 +147,7 @@ class Qwen3_8BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Ovis25_2BConfig:
@@ -164,6 +170,7 @@ class Ovis25_2BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Gemma2_2B_Config:
@@ -209,6 +217,7 @@ class Gemma2_2B_Config:
     sliding_attention = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Gemma3_4B_Config:
@@ -232,6 +241,7 @@ class Gemma3_4B_Config:
     sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
+    lm_head: bool = False
 
 @dataclass
 class Gemma3_12B_Config:
@@ -255,6 +265,7 @@ class Gemma3_12B_Config:
     sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
+    lm_head: bool = False
     vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
     mm_tokens_per_image = 256
 
@@ -356,6 +367,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
         xq = self.q_proj(hidden_states)
@@ -373,11 +385,30 @@
 
         xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
 
+        present_key_value = None
+        if past_key_value is not None:
+            index = 0
+            num_tokens = xk.shape[2]
+            if len(past_key_value) > 0:
+                past_key, past_value, index = past_key_value
+                if past_key.shape[2] >= (index + num_tokens):
+                    past_key[:, :, index:index + xk.shape[2]] = xk
+                    past_value[:, :, index:index + xv.shape[2]] = xv
+                    xk = past_key[:, :, :index + xk.shape[2]]
+                    xv = past_value[:, :, :index + xv.shape[2]]
+                    present_key_value = (past_key, past_value, index + num_tokens)
+                else:
+                    xk = torch.cat((past_key[:, :, :index], xk), dim=2)
+                    xv = torch.cat((past_value[:, :, :index], xv), dim=2)
+                    present_key_value = (xk, xv, index + num_tokens)
+            else:
+                present_key_value = (xk, xv, index + num_tokens)
+
         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
         xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
 
         output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
-        return self.o_proj(output)
+        return self.o_proj(output), present_key_value
 
 class MLP(nn.Module):
     def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
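
The cache update above has two paths: if the caller handed in preallocated key/value buffers with spare capacity, the new keys and values are written in place and views of the filled prefix are used for attention; otherwise the cache grows by concatenation. A standalone illustration of that rule (mirroring the logic above, not code from the commit):

import torch

B, H, D = 1, 2, 4
past = (torch.zeros(B, H, 8, D), torch.zeros(B, H, 8, D), 0)  # room preallocated for 8 tokens

def update(past_key_value, xk, xv):
    past_key, past_value, index = past_key_value
    num_tokens = xk.shape[2]
    if past_key.shape[2] >= index + num_tokens:
        # Fast path: copy the new keys/values into the preallocated buffer in place.
        past_key[:, :, index:index + num_tokens] = xk
        past_value[:, :, index:index + num_tokens] = xv
        view_k = past_key[:, :, :index + num_tokens]
        view_v = past_value[:, :, :index + num_tokens]
        return view_k, view_v, (past_key, past_value, index + num_tokens)
    # Slow path: the buffer is exhausted, so fall back to concatenation.
    new_k = torch.cat((past_key[:, :, :index], xk), dim=2)
    new_v = torch.cat((past_value[:, :, :index], xv), dim=2)
    return new_k, new_v, (new_k, new_v, index + num_tokens)

k, v, past = update(past, torch.randn(B, H, 3, D), torch.randn(B, H, 3, D))  # 3-token prompt
k, v, past = update(past, torch.randn(B, H, 1, D), torch.randn(B, H, 1, D))  # one decode step
print(k.shape)  # torch.Size([1, 2, 4, 4]): 3 cached positions + 1 new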
@@ -408,15 +439,17 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
-        x = self.self_attn(
+        x, present_key_value = self.self_attn(
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
             optimized_attention=optimized_attention,
+            past_key_value=past_key_value,
         )
         x = residual + x
 
@@ -426,7 +459,7 @@
         x = self.mlp(x)
         x = residual + x
 
-        return x
+        return x, present_key_value
 
 class TransformerBlockGemma2(nn.Module):
     def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
@@ -451,6 +484,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
@@ -468,11 +502,12 @@
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
-        x = self.self_attn(
+        x, present_key_value = self.self_attn(
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
             optimized_attention=optimized_attention,
+            past_key_value=past_key_value,
         )
 
         x = self.post_attention_layernorm(x)
@@ -485,7 +520,7 @@
         x = self.post_feedforward_layernorm(x)
         x = residual + x
 
-        return x
+        return x, present_key_value
 
 class Llama2_(nn.Module):
     def __init__(self, config, device=None, dtype=None, ops=None):
@@ -516,9 +551,10 @@ def __init__(self, config, device=None, dtype=None, ops=None):
         else:
             self.norm = None
 
-        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
+        if config.lm_head:
+            self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
 
-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
         if embeds is not None:
             x = embeds
         else:
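
The previously commented-out lm_head is now built whenever config.lm_head is set. Its call site is not part of this diff; as a hedged illustration of what the head computes, with nn.Linear standing in for ops.Linear:

import torch
import torch.nn as nn

hidden_size, vocab_size = 8, 32
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for ops.Linear(..., bias=False)
hidden = torch.randn(1, 1, hidden_size)                   # final hidden state of the last token
logits = lm_head(hidden)                                  # [1, 1, vocab_size] next-token scores
print(logits.argmax(dim=-1))                              # greedy next-token id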
@@ -527,8 +563,13 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         if self.normalize_in:
             x *= self.config.hidden_size ** 0.5
 
+        seq_len = x.shape[1]
+        past_len = 0
+        if past_key_values is not None and len(past_key_values) > 0:
+            past_len = past_key_values[0][2]
+
         if position_ids is None:
-            position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
+            position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
 
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
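
With tokens already cached, RoPE has to rotate each new token by its absolute position, so position_ids now start at the cached length instead of zero. A quick standalone check of that bookkeeping:

import torch

past_len, seq_len = 5, 1   # 5 tokens already cached, decoding 1 new token
position_ids = torch.arange(past_len, past_len + seq_len).unsqueeze(0)
print(position_ids)        # tensor([[5]]): RoPE rotates the new token by its absolute position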
@@ -539,14 +580,16 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
 
         mask = None
         if attention_mask is not None:
-            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
 
-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
-        if mask is not None:
-            mask += causal_mask
-        else:
-            mask = causal_mask
+        if seq_len > 1:
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+            if mask is not None:
+                mask += causal_mask
+            else:
+                mask = causal_mask
+
         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
 
         intermediate = None
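
The causal mask is now built only for multi-token (prefill) passes: a single decode-step query comes after every cached position, so it may attend to all of them unmasked. A standalone sketch of the rule:

import torch

def build_causal_mask(past_len, seq_len, dtype=torch.float32):
    if seq_len <= 1:
        return None   # single decode-step query: every cached position is visible
    total = past_len + seq_len
    return torch.empty(total, total, dtype=dtype).fill_(float("-inf")).triu_(1)

print(build_causal_mask(0, 3))  # 3x3 upper-triangular -inf mask for a 3-token prompt
print(build_causal_mask(3, 1))  # None: no mask needed for one new token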
@@ -562,16 +605,27 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         elif intermediate_output < 0:
             intermediate_output = len(self.layers) + intermediate_output
 
+        next_key_values = []
         for i, layer in enumerate(self.layers):
             if all_intermediate is not None:
                 if only_layers is None or (i in only_layers):
                     all_intermediate.append(x.unsqueeze(1).clone())
-            x = layer(
+
+            past_kv = None
+            if past_key_values is not None:
+                past_kv = past_key_values[i] if len(past_key_values) > 0 else []
+
+            x, current_kv = layer(
                 x=x,
                 attention_mask=mask,
                 freqs_cis=freqs_cis,
                 optimized_attention=optimized_attention,
+                past_key_value=past_kv,
             )
+
+            if current_kv is not None:
+                next_key_values.append(current_kv)
+
             if i == intermediate_output:
                 intermediate = x.clone()
 
@@ -588,7 +642,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)
 
-        return x, intermediate
+        if len(next_key_values) > 0:
+            return x, intermediate, next_key_values
+        else:
+            return x, intermediate
 
 
 class Gemma3MultiModalProjector(torch.nn.Module):
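
Since the attention update writes in place whenever the cached tensors still have room, a caller can preallocate fixed-size buffers and avoid the torch.cat growth path entirely. A hedged sketch of such a helper (make_static_cache is illustrative, not from the commit):

import torch

def make_static_cache(num_layers, batch, kv_heads, max_len, head_dim, device=None, dtype=None):
    # One (past_key, past_value, index) entry per layer, matching what the layers return.
    return [
        (torch.zeros(batch, kv_heads, max_len, head_dim, device=device, dtype=dtype),
         torch.zeros(batch, kv_heads, max_len, head_dim, device=device, dtype=dtype),
         0)
        for _ in range(num_layers)
    ]

Passing this list as past_key_values instead of the empty list keeps memory flat: the in-place branch is taken until index plus the new token count would exceed max_len, after which the layers fall back to concatenation.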
