@@ -77,6 +77,28 @@ class Qwen25_3BConfig:
7777 rope_scale = None
7878 final_norm : bool = True
7979
@dataclass
class Qwen3_06BConfig:
    """Architecture hyperparameters for the Qwen3 0.6B text model.

    Mirrors the layout of the sibling Qwen config dataclasses in this file
    and is consumed by the Qwen3_06B wrapper, which forwards it to Llama2_.

    NOTE(review): only the annotated attributes below are dataclass fields
    and can be overridden through the ``config_dict`` a caller passes in;
    the unannotated ones (head_dim, rms_norm_add, ...) are plain class
    attributes shared by all instances — this matches the other configs in
    this file and appears intentional, but confirm no caller tries to set
    them via config_dict.
    """
    vocab_size: int = 151936
    hidden_size: int = 1024
    intermediate_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 16
    # Fewer KV heads than attention heads (8 vs 16).
    num_key_value_heads: int = 8
    max_position_embeddings: int = 32768
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    # Selects the transformer implementation inside Llama2_.
    transformer_type: str = "llama"
    # Class attributes (not dataclass fields) — see NOTE in the docstring.
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = False
    rope_dims = None
    # Per-head Q/K normalization style tags interpreted by the model code.
    q_norm = "gemma3"
    k_norm = "gemma3"
    rope_scale = None
    final_norm: bool = True
101+
80102@dataclass
81103class Qwen3_4BConfig :
82104 vocab_size : int = 151936
@@ -641,6 +663,15 @@ def __init__(self, config_dict, dtype, device, operations):
641663 self .model = Llama2_ (config , device = device , dtype = dtype , ops = operations )
642664 self .dtype = dtype
643665
class Qwen3_06B(BaseLlama, torch.nn.Module):
    """Qwen3 0.6B model wrapper.

    Builds a Qwen3_06BConfig from ``config_dict`` and instantiates the
    shared Llama2_ backbone with it, following the same pattern as the
    other model wrappers in this file.
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.dtype = dtype
        cfg = Qwen3_06BConfig(**config_dict)
        self.num_layers = cfg.num_hidden_layers
        self.model = Llama2_(cfg, device=device, dtype=dtype, ops=operations)
674+
644675class Qwen3_4B (BaseLlama , torch .nn .Module ):
645676 def __init__ (self , config_dict , dtype , device , operations ):
646677 super ().__init__ ()
0 commit comments