Skip to content

model : add hunyuan moe #14425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model name better with hunyuan A13B


if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -6436,6 +6439,155 @@ def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you align with Hunyuan's naming , with version V1 suffix?

model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also could you add the version suffix on the arch name, like the arch name in model 's config.json ?


def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# For handling tied embeddings
self._tok_embd = None

def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

# 1. Get the pre-tokenizer identifier hash
tokpre = self.get_vocab_base_pre(tokenizer)

# 2. Reverse-engineer the merges list from mergeable_ranks
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2: # todo this is an assert in Qwen, why?
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

# 3. Generate the tokens and toktypes lists
vocab_size = self.hparams["vocab_size"]
assert tokenizer.vocab_size == vocab_size
special_tokens = tokenizer.special_tokens
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)

# 4. Write all vocab-related fields to the GGUF writer
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)

# 5. Add special tokens and chat templates
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
# FIX for BOS token: Overwrite incorrect id read from config.json
self.gguf_writer.add_bos_token_id(127959) # <|bos|>

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams

self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

moe_intermediate_size = hparams["moe_intermediate_size"]
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

moe_topk = hparams["moe_topk"]
assert all(topk == moe_topk[0] for topk in moe_topk)
self.gguf_writer.add_expert_used_count(moe_topk[0])

moe_shared_expert = hparams["num_shared_expert"]
assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])

# Rope
rope_scaling = hparams.get("rope_scaling", {})
if rope_scaling.get("type") == "dynamic":
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
alpha = rope_scaling.get("alpha", 1000)
base = hparams.get("rope_theta", 10000.0)
dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
self.gguf_writer.add_rope_freq_base(scaled_base)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
# There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
self.gguf_writer.add_context_length(256 * 1024) # 256k context length

# if any of our assumptions about the values are wrong, something has changed and this may need to be updated
assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
"HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "model.embed_tokens.weight":
self._tok_embd = data_torch.clone()

if name == "lm_head.weight":
if self.hparams.get("tie_word_embeddings", False):
logger.info("Skipping tied output layer 'lm_head.weight'")
return []

if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

###### CONVERSION LOGIC ######


Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model name should be hunyuan a13b, from my source , they will release more llm model soon, we'd better add some identify for the mdoel.

]


Expand Down
23 changes: 23 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ class MODEL_ARCH(IntEnum):
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
HUNYUAN_MOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -656,6 +657,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hunyuan-moe-v1 will be a better name for later model updating.

}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -2193,6 +2195,27 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.HUNYUAN_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO
}

Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
Expand Down Expand Up @@ -362,6 +363,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),

# AWQ-activation gate
Expand Down Expand Up @@ -398,6 +400,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
),

# Feed-forward down
Expand Down Expand Up @@ -447,11 +450,13 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
),

MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
Expand All @@ -461,6 +466,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a version suffix on vocab type will be better.

};

enum llama_rope_type {
Expand Down
24 changes: 24 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also .

{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -1676,6 +1677,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ enum llm_arch {
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_UNKNOWN,
};

Expand Down
18 changes: 18 additions & 0 deletions src/llama-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
Expand Down Expand Up @@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
Expand Down Expand Up @@ -665,6 +668,21 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the chat template of hunyuan a13b shoule be a much complex one ? with a quick and slow think option.

also the model default enable the slow think,

does llama cpp have some option on enable_think like the huggingface exmaple ?

for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << "<|startoftext|>" << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
if (add_ass) {
ss << "<|startoftext|>";
}
} else {
// template not supported
return -1;
Expand Down
1 change: 1 addition & 0 deletions src/llama-chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

Expand Down
Loading
Loading