@@ -892,8 +892,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
-            # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
-            res = "llada-moe"
+            # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
+            res = "bailingmoe2"
         if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
             # ref: https://huggingface.co/ibm-granite/granite-docling-258M
             res = "granite-docling"
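For context: each chkhsh compared above fingerprints the model's pre-tokenizer. Earlier in get_vocab_base_pre the HF tokenizer encodes a fixed probe string and the resulting token IDs are hashed, roughly as in this sketch:

    from hashlib import sha256
    chktok = tokenizer.encode(chktxt)  # chktxt: the fixed probe string defined earlier in this file
    chkhsh = sha256(str(chktok).encode()).hexdigest()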
@@ -8055,6 +8055,103 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("BailingMoeV2ForCausalLM")
+class BailingMoeV2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.BAILINGMOE2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
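+        # the optional NextN/MTP prediction layers are stored after the regular
+        # transformer blocks, so extend the block count and rebuild the tensor
+        # name map to cover them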
+        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
+            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if (rope_dim := hparams.get("head_dim")) is None:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
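+        # only `partial_rotary_factor` of each head's dimensions is rotary (half by default)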
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+        else:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
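+        # the first `first_k_dense_replace` blocks use a dense FFN instead of MoE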
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
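+        # grouped routing (DeepSeek-V3 style): experts are split into `n_group`
+        # groups and tokens only route within the `topk_group` best-scoring groups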
+        self.gguf_writer.add_expert_group_count(hparams["n_group"])
+        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["score_function"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["score_function"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
+
+        if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(nextn_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "mlp.experts" in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            tensors: list[tuple[str, Tensor]] = []
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
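+                # each expert contributes gate_proj, up_proj and down_proj, hence n_experts * 3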
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+
+            return tensors
+
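+        # give the bare expert_bias tensor a ".bias" suffix so the standard name mapping matches it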
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
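A minimal sketch of what the expert merge in modify_tensors produces, with toy shapes rather than the real Ling-mini-2.0 dimensions: the per-expert 2D projection weights are stacked along a new leading expert axis, giving one 3D tensor per projection per block.

    import torch

    n_experts, rows, cols = 4, 8, 16  # toy sizes, not real hparams
    per_expert = [torch.randn(rows, cols) for _ in range(n_experts)]
    merged = torch.stack(per_expert, dim=0)  # shape: (n_experts, rows, cols)
    assert merged.shape == (n_experts, rows, cols)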
 @ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
 class GroveMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROVEMOE