comfy/text_encoders/llama.py: 18 changes (15 additions, 3 deletions)
@@ -32,6 +32,7 @@ class Llama2Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Qwen25_3BConfig:
@@ -53,6 +54,7 @@ class Qwen25_3BConfig:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Gemma2_2B_Config:
@@ -96,6 +99,7 @@ class Gemma2_2B_Config:
     k_norm = None
     sliding_attention = None
     rope_scale = None
+    final_norm: bool = True
 
 @dataclass
 class Gemma3_4B_Config:
@@ -118,6 +122,7 @@ class Gemma3_4B_Config:
     k_norm = "gemma3"
     sliding_attention = [False, False, False, False, False, 1024]
     rope_scale = [1.0, 8.0]
+    final_norm: bool = True
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
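Note for reviewers: `final_norm` defaults to `True` in every config, so existing model definitions keep their trailing RMSNorm and only configs that explicitly opt out change behavior. A minimal sketch of opting out, assuming a simplified stand-in dataclass (`ExampleConfig` below is hypothetical, not one of the real config classes in this file):

from dataclasses import dataclass

@dataclass
class ExampleConfig:
    hidden_size: int = 4096
    rms_norm_eps: float = 1e-5
    final_norm: bool = True  # new flag: the model builds self.norm only when True

cfg = ExampleConfig(final_norm=False)  # a model constructed from this config skips the final norm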
@@ -366,7 +371,12 @@ def __init__(self, config, device=None, dtype=None, ops=None):
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
             for i in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        if config.final_norm:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        else:
+            self.norm = None
+
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
 
     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -421,14 +431,16 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
             if i == intermediate_output:
                 intermediate = x.clone()
 
-        x = self.norm(x)
+        if self.norm is not None:
+            x = self.norm(x)
+
         if all_intermediate is not None:
             all_intermediate.append(x.unsqueeze(1).clone())
 
         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)
 
-        if intermediate is not None and final_layer_norm_intermediate:
+        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)
 
         return x, intermediate
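For context, a self-contained sketch of the pattern these two hunks introduce: the trailing norm is applied only if it was built, and the same guard covers the normalized intermediate output. `TinyModel` is a hypothetical stand-in (with `nn.LayerNorm` standing in for this file's `RMSNorm`); only the `final_norm` handling mirrors the diff.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, hidden_size, num_layers=2, final_norm=True):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        # Build the trailing norm only when requested, as in the __init__ hunk above.
        self.norm = nn.LayerNorm(hidden_size) if final_norm else None

    def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == intermediate_output:
                intermediate = x.clone()
        # Apply the final norm only when it exists, mirroring the forward hunk.
        if self.norm is not None:
            x = self.norm(x)
        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
            intermediate = self.norm(intermediate)
        return x, intermediate

x, mid = TinyModel(8, final_norm=False)(torch.randn(1, 4, 8), intermediate_output=0)
# With final_norm=False, both x and the intermediate come back un-normalized.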