comfy/text_encoders/llama.py — 31 additions, 0 deletions
@@ -77,6 +77,28 @@ class Qwen25_3BConfig:
rope_scale = None
final_norm: bool = True

@dataclass
class Qwen3_06BConfig:
vocab_size: int = 151936
hidden_size: int = 1024
intermediate_size: int = 3072
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True

@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
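
The new Qwen3_06BConfig matches Qwen3's published 0.6B hyperparameters. Two details worth noting: head_dim (128) is decoupled from hidden_size (1024), and num_key_value_heads (8) is half of num_attention_heads (16), i.e. grouped-query attention; q_norm/k_norm reuse the existing "gemma3" variant for Qwen3's per-head QK RMSNorm. Below is a minimal sketch of the attention shapes this implies, assuming the Llama2_ backbone projects q/k/v to num_heads * head_dim rather than hidden_size (an assumption about the backbone, not something shown in this hunk):

# Sketch (not from the PR): attention projection shapes implied by Qwen3_06BConfig,
# assuming the usual decoupled-head_dim convention.
hidden_size = 1024
num_attention_heads = 16
num_key_value_heads = 8
head_dim = 128

q_out = num_attention_heads * head_dim                    # 2048, deliberately != hidden_size
kv_out = num_key_value_heads * head_dim                   # 1024
gqa_groups = num_attention_heads // num_key_value_heads   # 2 query heads share each KV head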
@@ -641,6 +663,15 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen3_06B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
self.num_layers = config.num_hidden_layers

self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
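
Qwen3_06B is a thin wrapper in the same pattern as the neighboring classes (e.g. Qwen3_4B): it instantiates the shared Llama2_ backbone from the new config. A hypothetical construction sketch follows; it is not taken from the PR. comfy.ops.disable_weight_init is assumed as the operations namespace, as other ComfyUI text encoders use, and the empty config_dict relies on every Qwen3_06BConfig field having a default:

import torch
import comfy.ops
from comfy.text_encoders.llama import Qwen3_06B

# Hypothetical usage; in ComfyUI the model loader normally supplies these arguments.
model = Qwen3_06B(
    config_dict={},                            # all Qwen3_06BConfig fields have defaults
    dtype=torch.float16,
    device="cpu",
    operations=comfy.ops.disable_weight_init,  # assumption: typical ComfyUI ops namespace
)
print(model.num_layers)                        # 28, from num_hidden_layers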