model : add hunyuan moe #14425
base: master
@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
        if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
            # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
            res = "minerva-7b"
        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
            res = "hunyuan"

        if res is None:
            logger.warning("\n")
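For context, the chkhsh values matched here are derived by tokenizing a fixed probe string and hashing the resulting token IDs; a rough sketch of the idea, assuming access to the tencent/Hunyuan-A13B-Instruct repo (the probe text below is a stand-in, not the real chktxt used by convert_hf_to_gguf.py):

```python
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "example probe text"  # stand-in; the real probe string is a long multilingual text
tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
chktok = tokenizer.encode(chktxt)
chkhsh = sha256(str(chktok).encode()).hexdigest()  # compared against the hard-coded hashes above
print(chkhsh)
```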
@@ -6436,6 +6439,155 @@ def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):

Review comment: Could you align with Hunyuan's naming, with a version V1 suffix?

    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE

Review comment: Could you also add the version suffix to the arch name, like the arch name in the model's config.json?
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # For handling tied embeddings
        self._tok_embd = None

    def set_vocab(self):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

        # 1. Get the pre-tokenizer identifier hash
        tokpre = self.get_vocab_base_pre(tokenizer)

        # 2. Reverse-engineer the merges list from mergeable_ranks
        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
        for token, rank in mergeable_ranks.items():
            vocab[QwenModel.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            if len(merged) == 2:  # todo this is an assert in Qwen, why?
                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

        # 3. Generate the tokens and toktypes lists
        vocab_size = self.hparams["vocab_size"]
        assert tokenizer.vocab_size == vocab_size
        special_tokens = tokenizer.special_tokens
        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
        tokens: list[str] = []
        toktypes: list[int] = []
        for i in range(vocab_size):
            if i not in reverse_vocab:
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            else:
                token = reverse_vocab[i]
                tokens.append(token)
                if i in special_tokens.values():
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.NORMAL)

        # 4. Write all vocab-related fields to the GGUF writer
        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_token_merges(merges)

        # 5. Add special tokens and chat templates
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
        special_vocab.add_to_gguf(self.gguf_writer)
        # FIX for BOS token: Overwrite incorrect id read from config.json
        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>
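The "reverse-engineer the merges" step above uses the tiktoken-style trick of re-running BPE on each multi-byte token while allowing only lower-ranked merges; a self-contained toy version of the idea (the ranks and byte strings are made up, and the real code goes through QwenModel.bpe and token_bytes_to_string):

```python
# Toy tiktoken-style ranks: lower rank = merged earlier.
mergeable_ranks = {b"h": 0, b"i": 1, b"hi": 2}

def bpe_split(token: bytes, max_rank: int) -> list[bytes]:
    # Apply only merges with rank < max_rank; what remains is the pair that merges into `token`.
    parts = [bytes([b]) for b in token]
    while True:
        best = None
        for i in range(len(parts) - 1):
            rank = mergeable_ranks.get(parts[i] + parts[i + 1])
            if rank is not None and rank < max_rank and (best is None or rank < best[1]):
                best = (i, rank)
        if best is None:
            return parts
        i = best[0]
        parts = parts[:i] + [parts[i] + parts[i + 1]] + parts[i + 2:]

print(bpe_split(b"hi", mergeable_ranks[b"hi"]))  # [b'h', b'i'] -> merge entry "h i"
```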
||
def set_gguf_parameters(self): | ||
super().set_gguf_parameters() | ||
hparams = self.hparams | ||
|
||
self.gguf_writer.add_expert_count(hparams["num_experts"]) | ||
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) | ||
|
||
moe_intermediate_size = hparams["moe_intermediate_size"] | ||
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) | ||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) | ||
|
||
moe_topk = hparams["moe_topk"] | ||
assert all(topk == moe_topk[0] for topk in moe_topk) | ||
self.gguf_writer.add_expert_used_count(moe_topk[0]) | ||
|
||
moe_shared_expert = hparams["num_shared_expert"] | ||
assert all(n == moe_shared_expert[0] for n in moe_shared_expert) | ||
self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) | ||
|
||
# Rope | ||
rope_scaling = hparams.get("rope_scaling", {}) | ||
if rope_scaling.get("type") == "dynamic": | ||
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ | ||
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf) | ||
alpha = rope_scaling.get("alpha", 1000) | ||
base = hparams.get("rope_theta", 10000.0) | ||
dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128 | ||
scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251 | ||
self.gguf_writer.add_rope_freq_base(scaled_base) | ||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) | ||
self.gguf_writer.add_rope_scaling_factor(1) | ||
# There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k | ||
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length | ||
self.gguf_writer.add_context_length(256 * 1024) # 256k context length | ||
|
||
# if any of our assumptions about the values are wrong, something has changed and this may need to be updated | ||
assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ | ||
"HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" | ||
|
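As a quick sanity check on the NTK-aware alpha scaling above, re-deriving the value from the inline comment with alpha=1000, base=10000 and dim=128:

```python
alpha, base, dim = 1000, 10000.0, 128
scaled_base = base * (alpha ** (dim / (dim - 2)))
print(scaled_base)  # ~11158839.9251, the rope_freq_base written to the GGUF
```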
    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if name == "model.embed_tokens.weight":
            self._tok_embd = data_torch.clone()

        if name == "lm_head.weight":
            if self.hparams.get("tie_word_embeddings", False):
                logger.info("Skipping tied output layer 'lm_head.weight'")
                return []

        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                tensors: list[tuple[str, Tensor]] = []
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)
                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    new_name = self.map_tensor_name(merged_name)
                    tensors.append((new_name, data_torch))

                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]
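A minimal sketch of the expert-merging step in modify_tensors above, showing the shape of the stacked tensor (the dimensions are hypothetical, purely for illustration):

```python
import torch

n_experts, n_ff, n_embd = 64, 3072, 4096  # hypothetical sizes, not Hunyuan-A13B's real ones
gate_projs = [torch.zeros(n_ff, n_embd) for _ in range(n_experts)]  # one 2D weight per expert
merged = torch.stack(gate_projs, dim=0)  # -> [n_experts, n_ff, n_embd], mapped to ffn_gate_exps
print(merged.shape)  # torch.Size([64, 3072, 4096])
```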
    def prepare_tensors(self):
        super().prepare_tensors()
        if self._experts is not None:
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######
@@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
    {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
    {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
    {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
]

Review comment: The model name should be "hunyuan a13b". From my sources, they will release more LLM models soon, so we'd better add some identifier for this model.
@@ -355,6 +355,7 @@ class MODEL_ARCH(IntEnum):
    DOTS1 = auto()
    ARCEE = auto()
    ERNIE4_5 = auto()
    HUNYUAN_MOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):

@@ -656,6 +657,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.DOTS1: "dots1",
    MODEL_ARCH.ARCEE: "arcee",
    MODEL_ARCH.ERNIE4_5: "ernie4_5",
    MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
}

Review comment: "hunyuan-moe-v1" would be a better name for later model updates.

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -2193,6 +2195,27 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.HUNYUAN_MOE: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_SHEXP,
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
    ],
    # TODO
}
@@ -117,6 +117,7 @@ extern "C" {
    LLAMA_VOCAB_PRE_TYPE_LLAMA4     = 33,
    LLAMA_VOCAB_PRE_TYPE_PIXTRAL    = 34,
    LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
    LLAMA_VOCAB_PRE_TYPE_HUNYUAN    = 36,
};

Review comment: Adding a version suffix to the vocab pre-tokenizer type would be better.

enum llama_rope_type {
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_DOTS1, "dots1" },
    { LLM_ARCH_ARCEE, "arcee" },
    { LLM_ARCH_ERNIE4_5, "ernie4_5" },
    { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
    { LLM_ARCH_UNKNOWN, "(unknown)" },
};

Review comment: Same comment about the version suffix applies here.

@@ -1676,6 +1677,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_HUNYUAN_MOE,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
    {
        LLM_ARCH_UNKNOWN,
        {
@@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
    { "bailing", LLM_CHAT_TEMPLATE_BAILING },
    { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
    { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
    { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {

@@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
        return LLM_CHAT_TEMPLATE_LLAMA4;
    } else if (tmpl_contains("<|endofuserprompt|>")) {
        return LLM_CHAT_TEMPLATE_DOTS1;
    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
        return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
    }
    return LLM_CHAT_TEMPLATE_UNKNOWN;
}

@@ -665,6 +668,21 @@ int32_t llm_chat_apply_template(
        if (add_ass) {
            ss << "<|response|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
        // tencent/Hunyuan-A13B-Instruct

Review comment: Shouldn't the chat template of Hunyuan A13B be a much more complex one, with a quick-think and a slow-think option? Also, the model enables slow thinking by default; does llama.cpp have an option like enable_thinking in the Hugging Face example?

        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << "<|startoftext|>" << message->content << "<|extra_4|>";
            } else if (role == "assistant") {
                ss << "<|startoftext|>" << message->content << "<|eos|>";
            } else {
                ss << "<|startoftext|>" << message->content << "<|extra_0|>";
            }
        }
        if (add_ass) {
            ss << "<|startoftext|>";
        }
    } else {
        // template not supported
        return -1;
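To make the new branch concrete, this is what it renders for a short system + user exchange with add_ass enabled (derived directly from the loop above; the message contents are invented):

```python
# Rendered Hunyuan-A13B prompt for [system, user] messages with add_ass=True:
prompt = (
    "<|startoftext|>You are a helpful assistant.<|extra_4|>"  # system -> <|extra_4|>
    "<|startoftext|>Hello!<|extra_0|>"                        # user   -> <|extra_0|>
    "<|startoftext|>"                                         # add_ass: open the assistant turn
)
```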
Review comment: The model name would be better as "hunyuan A13B".