Skip to content

Commit 39bf692

Browse files
pwilkinCISC
andauthored
[Model] Qwen3.5 dense and MoE support (no vision) (ggml-org#19435)
* Unified delta net handling * Remove old methods. * Refactor and optimize * Adapt autoregressive version from @ymcki * Change to decay mask approach * Fix bad permute * Qwen 3.5 support * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Further fixes * Use inheritance, remove unneeded conts * Not like this! * Remove ggml.h explicit import * Remove transformers, fix the views * ACTUALLY fix views, make super calls explicit in conversion. * Fix conversion again * Remove extra ggml.h imports --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent e06088d commit 39bf692

14 files changed

Lines changed: 1532 additions & 399 deletions

convert_hf_to_gguf.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4102,39 +4102,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
41024102
# process the experts separately
41034103
name = name.replace("language_model.", "") # InternVL
41044104

4105-
# handle aggregated expert tensors
4106-
# GGUF stores dimensions reversed from PyTorch, so:
4107-
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
4108-
# Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
4109-
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
4105+
# handle pre-packed expert tensors (e.g. Qwen3.5 MoE, Qwen3Next)
4106+
# HF stores these using nn.Linear convention: [n_expert, out_features, in_features]
4107+
# This matches the individual expert stacking path below (which stacks
4108+
# per-expert [out, in] weights into [n_expert, out, in]), so no permute is needed.
41104109
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
41114110
mapped = f"{name}.weight" if not name.endswith(".weight") else name
4112-
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
4113-
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
4114-
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
4115-
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
4116-
permuted = data_torch.permute(0, 2, 1).contiguous()
4117-
yield from super().modify_tensors(permuted, mapped, bid)
4111+
# HF: [n_expert, n_embd, n_ff] → GGML: {n_ff, n_embd, n_expert} ✓
4112+
yield from super().modify_tensors(data_torch, mapped, bid)
41184113
return
41194114

41204115
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
4121-
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
4122-
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
4123-
split_dim = data_torch.shape[-1] // 2
4124-
gate = data_torch[..., :split_dim].contiguous()
4125-
up = data_torch[..., split_dim:].contiguous()
4126-
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
4127-
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
4128-
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
4129-
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
4130-
base_name = name.removesuffix(".weight")
4131-
base = base_name.rsplit('.', 1)[0]
4132-
mapped_gate = f"{base}.gate_proj.weight"
4133-
mapped_up = f"{base}.up_proj.weight"
4134-
perm_gate = gate.permute(0, 2, 1).contiguous()
4135-
perm_up = up.permute(0, 2, 1).contiguous()
4136-
yield from super().modify_tensors(perm_gate, mapped_gate, bid)
4137-
yield from super().modify_tensors(perm_up, mapped_up, bid)
4116+
# HF: [n_expert, 2*n_ff, n_embd] → split on dim=1
4117+
n_ff = data_torch.shape[1] // 2
4118+
gate = data_torch[:, :n_ff, :].contiguous()
4119+
up = data_torch[:, n_ff:, :].contiguous()
4120+
# gate/up: [n_expert, n_ff, n_embd] → GGML: {n_embd, n_ff, n_expert} ✓
4121+
base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj")
4122+
mapped_gate = f"{base_name}.gate_proj.weight"
4123+
mapped_up = f"{base_name}.up_proj.weight"
4124+
yield from super().modify_tensors(gate, mapped_gate, bid)
4125+
yield from super().modify_tensors(up, mapped_up, bid)
41384126
return
41394127

41404128
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
@@ -4344,6 +4332,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
43444332
yield from super().modify_tensors(data_torch, name, bid)
43454333

43464334

4335+
@ModelBase.register("Qwen3_5ForCausalLM", "Qwen3_5TextForCausalLM")
4336+
class Qwen3_5Model(Qwen3NextModel):
4337+
model_arch = gguf.MODEL_ARCH.QWEN3_5
4338+
4339+
# Stores whichever of in_proj_a/in_proj_b is seen first, keyed by layer
4340+
_pending_ba: dict[int | None, tuple[str, Tensor]] = {}
4341+
4342+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
4343+
# Handle split in_proj_b + in_proj_a → concatenated SSM_BETA_ALPHA
4344+
# safetensors sorts alphabetically so in_proj_a arrives before in_proj_b
4345+
if "in_proj_a.weight" in name or "in_proj_b.weight" in name:
4346+
which = "a" if "in_proj_a" in name else "b"
4347+
if bid not in self._pending_ba:
4348+
self._pending_ba[bid] = (which, data_torch)
4349+
return
4350+
prev_which, prev_tensor = self._pending_ba.pop(bid)
4351+
assert prev_which != which, f"duplicate in_proj_{which} for layer {bid}"
4352+
b_tensor = prev_tensor if prev_which == "b" else data_torch
4353+
a_tensor = prev_tensor if prev_which == "a" else data_torch
4354+
ba_combined = torch.cat([b_tensor, a_tensor], dim=0)
4355+
yield (self.format_tensor_name(gguf.MODEL_TENSOR.SSM_BETA_ALPHA, bid, ".weight"), ba_combined)
4356+
return
4357+
else:
4358+
# Qwen3Next uses .qkvz tensor, so we use the super to get the other functionalities
4359+
# (norm correction, A_log to A etc.) for free
4360+
# Qwen2Moe already does the gate_up conversion properly, just use that
4361+
yield from super().modify_tensors(data_torch, name, bid)
4362+
4363+
4364+
@ModelBase.register("Qwen3_5MoeForCausalLM", "Qwen3_5MoeTextForCausalLM")
4365+
class Qwen3_5MoeModel(Qwen3_5Model):
4366+
model_arch = gguf.MODEL_ARCH.QWEN3_5_MOE
4367+
4368+
43474369
@ModelBase.register("RND1")
43484370
class RND1Model(Qwen2MoeModel):
43494371
model_arch = gguf.MODEL_ARCH.RND1

gguf-py/gguf/constants.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,8 @@ class MODEL_ARCH(IntEnum):
382382
QWEN3 = auto()
383383
QWEN3MOE = auto()
384384
QWEN3NEXT = auto()
385+
QWEN3_5 = auto()
386+
QWEN3_5_MOE = auto()
385387
QWEN3VL = auto()
386388
QWEN3VLMOE = auto()
387389
PHI2 = auto()
@@ -812,6 +814,8 @@ class MODEL_TENSOR(IntEnum):
812814
MODEL_ARCH.QWEN3: "qwen3",
813815
MODEL_ARCH.QWEN3MOE: "qwen3moe",
814816
MODEL_ARCH.QWEN3NEXT: "qwen3next",
817+
MODEL_ARCH.QWEN3_5: "qwen3_5",
818+
MODEL_ARCH.QWEN3_5_MOE: "qwen3_5moe",
815819
MODEL_ARCH.QWEN3VL: "qwen3vl",
816820
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
817821
MODEL_ARCH.PHI2: "phi2",
@@ -1784,6 +1788,61 @@ class MODEL_TENSOR(IntEnum):
17841788
MODEL_TENSOR.SSM_BETA_ALPHA,
17851789
MODEL_TENSOR.SSM_OUT
17861790
],
1791+
MODEL_ARCH.QWEN3_5: [
1792+
MODEL_TENSOR.TOKEN_EMBD,
1793+
MODEL_TENSOR.OUTPUT_NORM,
1794+
MODEL_TENSOR.OUTPUT,
1795+
MODEL_TENSOR.ATTN_NORM,
1796+
MODEL_TENSOR.ATTN_Q,
1797+
MODEL_TENSOR.ATTN_Q_NORM,
1798+
MODEL_TENSOR.ATTN_K,
1799+
MODEL_TENSOR.ATTN_K_NORM,
1800+
MODEL_TENSOR.ATTN_V,
1801+
MODEL_TENSOR.ATTN_OUT,
1802+
MODEL_TENSOR.ATTN_POST_NORM,
1803+
MODEL_TENSOR.ATTN_GATE,
1804+
MODEL_TENSOR.ATTN_QKV,
1805+
MODEL_TENSOR.FFN_GATE,
1806+
MODEL_TENSOR.FFN_DOWN,
1807+
MODEL_TENSOR.FFN_UP,
1808+
MODEL_TENSOR.SSM_A,
1809+
MODEL_TENSOR.SSM_CONV1D,
1810+
MODEL_TENSOR.SSM_DT,
1811+
MODEL_TENSOR.SSM_NORM,
1812+
MODEL_TENSOR.SSM_IN,
1813+
MODEL_TENSOR.SSM_BETA_ALPHA,
1814+
MODEL_TENSOR.SSM_OUT,
1815+
],
1816+
MODEL_ARCH.QWEN3_5_MOE: [
1817+
MODEL_TENSOR.TOKEN_EMBD,
1818+
MODEL_TENSOR.OUTPUT_NORM,
1819+
MODEL_TENSOR.OUTPUT,
1820+
MODEL_TENSOR.ATTN_NORM,
1821+
MODEL_TENSOR.ATTN_Q,
1822+
MODEL_TENSOR.ATTN_Q_NORM,
1823+
MODEL_TENSOR.ATTN_K,
1824+
MODEL_TENSOR.ATTN_K_NORM,
1825+
MODEL_TENSOR.ATTN_V,
1826+
MODEL_TENSOR.ATTN_OUT,
1827+
MODEL_TENSOR.ATTN_POST_NORM,
1828+
MODEL_TENSOR.ATTN_GATE,
1829+
MODEL_TENSOR.ATTN_QKV,
1830+
MODEL_TENSOR.FFN_GATE_INP,
1831+
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
1832+
MODEL_TENSOR.FFN_UP_SHEXP,
1833+
MODEL_TENSOR.FFN_DOWN_SHEXP,
1834+
MODEL_TENSOR.FFN_GATE_SHEXP,
1835+
MODEL_TENSOR.FFN_DOWN_EXP,
1836+
MODEL_TENSOR.FFN_UP_EXP,
1837+
MODEL_TENSOR.FFN_GATE_EXP,
1838+
MODEL_TENSOR.SSM_A,
1839+
MODEL_TENSOR.SSM_CONV1D,
1840+
MODEL_TENSOR.SSM_DT,
1841+
MODEL_TENSOR.SSM_NORM,
1842+
MODEL_TENSOR.SSM_IN,
1843+
MODEL_TENSOR.SSM_BETA_ALPHA,
1844+
MODEL_TENSOR.SSM_OUT,
1845+
],
17871846
MODEL_ARCH.QWEN3VL: [
17881847
MODEL_TENSOR.TOKEN_EMBD,
17891848
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class TensorNameMap:
228228
"transformer_encoder.{bid}.qkv", # neobert
229229
"layers.{bid}.attn.Wqkv", # modern-bert
230230
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
231+
"model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5
231232
),
232233

233234
# Attention query
@@ -358,8 +359,9 @@ class TensorNameMap:
358359
),
359360

360361
MODEL_TENSOR.ATTN_GATE: (
361-
"model.layers.{bid}.self_attn.gate_proj", # afmoe
362-
"model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
362+
"model.layers.{bid}.self_attn.gate_proj", # afmoe
363+
"model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
364+
"model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5
363365
),
364366

365367
# Feed-forward norm

src/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ add_library(llama
5757
models/deci.cpp
5858
models/deepseek.cpp
5959
models/deepseek2.cpp
60+
models/delta.cpp
6061
models/dots1.cpp
6162
models/dream.cpp
6263
models/ernie4-5-moe.cpp
@@ -122,6 +123,8 @@ add_library(llama
122123
models/qwen3vl-moe.cpp
123124
models/qwen3moe.cpp
124125
models/qwen3next.cpp
126+
models/qwen3-5.cpp
127+
models/qwen3-5moe.cpp
125128
models/refact.cpp
126129
models/rnd1.cpp
127130
models/rwkv6-base.cpp

src/llama-arch.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
3535
{ LLM_ARCH_QWEN3, "qwen3" },
3636
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
3737
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
38+
{ LLM_ARCH_QWEN3_5, "qwen3_5" },
39+
{ LLM_ARCH_QWEN3_5_MOE, "qwen3_5moe" },
3840
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
3941
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
4042
{ LLM_ARCH_PHI2, "phi2" },
@@ -985,6 +987,63 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
985987
LLM_TENSOR_SSM_NORM,
986988
LLM_TENSOR_SSM_OUT,
987989
};
990+
case LLM_ARCH_QWEN3_5:
991+
return {
992+
LLM_TENSOR_TOKEN_EMBD,
993+
LLM_TENSOR_OUTPUT_NORM,
994+
LLM_TENSOR_OUTPUT,
995+
LLM_TENSOR_ATTN_NORM,
996+
LLM_TENSOR_ATTN_POST_NORM,
997+
LLM_TENSOR_ATTN_Q,
998+
LLM_TENSOR_ATTN_Q_NORM,
999+
LLM_TENSOR_ATTN_K,
1000+
LLM_TENSOR_ATTN_K_NORM,
1001+
LLM_TENSOR_ATTN_V,
1002+
LLM_TENSOR_ATTN_OUT,
1003+
LLM_TENSOR_ATTN_QKV,
1004+
LLM_TENSOR_ATTN_GATE,
1005+
LLM_TENSOR_FFN_GATE,
1006+
LLM_TENSOR_FFN_DOWN,
1007+
LLM_TENSOR_FFN_UP,
1008+
LLM_TENSOR_SSM_A_NOSCAN,
1009+
LLM_TENSOR_SSM_CONV1D,
1010+
LLM_TENSOR_SSM_DT,
1011+
LLM_TENSOR_SSM_BETA_ALPHA,
1012+
LLM_TENSOR_SSM_IN,
1013+
LLM_TENSOR_SSM_NORM,
1014+
LLM_TENSOR_SSM_OUT,
1015+
};
1016+
case LLM_ARCH_QWEN3_5_MOE:
1017+
return {
1018+
LLM_TENSOR_TOKEN_EMBD,
1019+
LLM_TENSOR_OUTPUT_NORM,
1020+
LLM_TENSOR_OUTPUT,
1021+
LLM_TENSOR_ATTN_NORM,
1022+
LLM_TENSOR_ATTN_POST_NORM,
1023+
LLM_TENSOR_ATTN_Q,
1024+
LLM_TENSOR_ATTN_Q_NORM,
1025+
LLM_TENSOR_ATTN_K,
1026+
LLM_TENSOR_ATTN_K_NORM,
1027+
LLM_TENSOR_ATTN_V,
1028+
LLM_TENSOR_ATTN_OUT,
1029+
LLM_TENSOR_ATTN_QKV,
1030+
LLM_TENSOR_ATTN_GATE,
1031+
LLM_TENSOR_FFN_GATE_INP,
1032+
LLM_TENSOR_FFN_GATE_EXPS,
1033+
LLM_TENSOR_FFN_DOWN_EXPS,
1034+
LLM_TENSOR_FFN_UP_EXPS,
1035+
LLM_TENSOR_FFN_GATE_INP_SHEXP,
1036+
LLM_TENSOR_FFN_GATE_SHEXP,
1037+
LLM_TENSOR_FFN_DOWN_SHEXP,
1038+
LLM_TENSOR_FFN_UP_SHEXP,
1039+
LLM_TENSOR_SSM_A_NOSCAN,
1040+
LLM_TENSOR_SSM_CONV1D,
1041+
LLM_TENSOR_SSM_DT,
1042+
LLM_TENSOR_SSM_BETA_ALPHA,
1043+
LLM_TENSOR_SSM_IN,
1044+
LLM_TENSOR_SSM_NORM,
1045+
LLM_TENSOR_SSM_OUT,
1046+
};
9881047
case LLM_ARCH_QWEN3VL:
9891048
case LLM_ARCH_CHAMELEON:
9901049
case LLM_ARCH_HUNYUAN_DENSE:
@@ -2674,6 +2733,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
26742733
case LLM_ARCH_NEMOTRON_H:
26752734
case LLM_ARCH_NEMOTRON_H_MOE:
26762735
case LLM_ARCH_QWEN3NEXT:
2736+
case LLM_ARCH_QWEN3_5:
2737+
case LLM_ARCH_QWEN3_5_MOE:
26772738
case LLM_ARCH_KIMI_LINEAR:
26782739
return true;
26792740
default:

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ enum llm_arch {
3939
LLM_ARCH_QWEN3,
4040
LLM_ARCH_QWEN3MOE,
4141
LLM_ARCH_QWEN3NEXT,
42+
LLM_ARCH_QWEN3_5,
43+
LLM_ARCH_QWEN3_5_MOE,
4244
LLM_ARCH_QWEN3VL,
4345
LLM_ARCH_QWEN3VLMOE,
4446
LLM_ARCH_PHI2,

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2013,7 +2013,7 @@ void llama_context::output_reorder() {
20132013
//
20142014

20152015
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
2016-
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) {
2016+
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN3_5 || model.arch == LLM_ARCH_QWEN3_5_MOE || model.arch == LLM_ARCH_KIMI_LINEAR) {
20172017
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
20182018
}
20192019
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());

0 commit comments

Comments
 (0)