@@ -6578,6 +6578,173 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.dir_model, trust_remote_code=True
+        )
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
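+        # GLM-4 ships its chat-control tokens as added vocab entries; map them onto
+        # GGUF's special-token slots below (<|endoftext|> doubles as eos/unk/bos,
+        # <|user|> marks the end of an assistant turn).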
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab._set_special_token(
+            "eos", tokenizer.get_added_vocab()["<|endoftext|>"]
+        )
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token(
+            "unk", tokenizer.get_added_vocab()["<|endoftext|>"]
+        )
+        special_vocab._set_special_token(
+            "bos", tokenizer.get_added_vocab()["<|endoftext|>"]
+        )
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
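+        # RoPE is applied to only part of each head dimension; fall back to the
+        # GLM-4 default partial_rotary_factor of 0.5 when the config omits it.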
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = (
+                self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+            )
+        self.gguf_writer.add_rope_dimension_count(
+            int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
+        )
+
+        # MoE parameters
+        if (n_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        # Note: expert_used_count is already set by the parent class from num_experts_per_tok
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
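+        # (routed_scaling_factor weights the summed routed-expert output before it
+        # is combined with the shared-expert path, as in DeepSeek-style MoE blocks)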
+
+        # Normalise top-k probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _shared_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # Handle the special GLM4_MOE layer 46 tensors (nextn prediction layer)
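+        # NOTE: hard-coding index 46 assumes the GLM-4.5 layout, where the MTP
+        # (next-token prediction) block sits right after the 46 regular transformer
+        # blocks; other depths would need this derived from hparams instead.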
+        if bid == 46:
+            # Layer 46 is the nextn prediction layer - skip all of its tensors
+            return []
+
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle routed experts
+        if name.find("mlp.experts") != -1 and "shared_experts" not in name:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
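+                # every expert contributes three projections (down/gate/up), so the
+                # full set for this block has arrived once n_experts * 3 are buffered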
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
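+                    # stacked shape: (n_experts, out_features, in_features); ggml
+                    # stores all experts of one projection as a single 3D tensor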
+                    # Generate GGUF tensor names for the merged experts
+                    if w_name == "down_proj":
+                        new_name = f"blk.{bid}.ffn_down_exps.weight"
+                    elif w_name == "gate_proj":
+                        new_name = f"blk.{bid}.ffn_gate_exps.weight"
+                    elif w_name == "up_proj":
+                        new_name = f"blk.{bid}.ffn_up_exps.weight"
+                    else:
+                        merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                        new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        # Handle the expert routing score correction bias
+        if ".mlp.gate.e_score_correction_bias" in name:
+            new_name = name.replace("model.layers.", "blk.").replace(
+                ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias"
+            )
+            # new_name is already a GGUF tensor name; return it directly
+            return [(new_name, data_torch)]
+
+        # Handle shared expert tensors
+        if ".mlp.ffn_" in name and "_shexp" in name:
+            new_name = name.replace("model.layers.", "blk.")
+            return [(new_name, data_torch)]
+        # Handle regular dense FFN layers (for the hybrid dense/MoE architecture)
+        if ".mlp." in name and "experts" not in name and "_shexp" not in name:
+            if "gate_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.gate_proj.weight", ".ffn_gate.weight"
+                )
+            elif "up_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.up_proj.weight", ".ffn_up.weight"
+                )
+            elif "down_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.down_proj.weight", ".ffn_down.weight"
+                )
+            else:
+                # anything else still carries its HF name; use the generic mapping
+                return [(self.map_tensor_name(name), data_torch)]
+            return [(new_name, data_torch)]
+
+        # Handle the remaining GLM4_MOE-specific tensors (nextn prediction)
+        # NOTE: the top-level model.embed_tokens must not be filtered here - it is
+        # the main token embedding; the nextn copy lives under layer 46, skipped above
+        if (
+            ".shared_head." in name
+            or ".eh_proj." in name
+            or ".enorm." in name
+            or ".hnorm." in name
+        ):
+            # Skip these special tensors - they are only used for nextn prediction
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
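+        # sanity check: every buffered routed-expert tensor should have been merged
+        # in modify_tensors; leftovers indicate an unexpected checkpoint layout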
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHATGLM