Skip to content

model : gemma3n text-only #14400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 118 additions & 6 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.POSNET_NORM2,
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
Expand All @@ -320,7 +322,11 @@ def prepare_tensors(self):
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.TOKEN_EMBD,
gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
gguf.MODEL_TENSOR.OUTPUT,
gguf.MODEL_TENSOR.ALTUP_ROUTER,
gguf.MODEL_TENSOR.LAUREL_L,
gguf.MODEL_TENSOR.LAUREL_R,
)
):
if self.ftype in (
Expand Down Expand Up @@ -921,13 +927,16 @@ def _create_vocab_sentencepiece(self):
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))

vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
vocab_size = self.find_hparam([
"vocab_size_per_layer_input", # gemma3n
"vocab_size",
], optional=True) or tokenizer.vocab_size()

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

for token_id in range(tokenizer.vocab_size()):
for token_id in range(vocab_size):
piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)
Expand All @@ -942,6 +951,10 @@ def _create_vocab_sentencepiece(self):
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE

if token_id >= vocab_size:
logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
break

tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
Expand Down Expand Up @@ -4217,6 +4230,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(TextModel):
model_arch = gguf.MODEL_ARCH.GEMMA3
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

def set_vocab(self):
self._set_vocab_sentencepiece()
Expand All @@ -4238,9 +4252,8 @@ def set_gguf_parameters(self):
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
# both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
# attn_logit_softcapping is removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
Expand All @@ -4252,7 +4265,7 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.startswith("language_model."):
if "language_model." in name:
name = name.replace("language_model.", "")

elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
Expand All @@ -4267,8 +4280,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
# note: this is not the case on gemma3n
if name.endswith("norm.weight"):
data_torch = data_torch + 1
data_torch = data_torch + self.norm_shift

return [(self.map_tensor_name(name), data_torch)]

Expand Down Expand Up @@ -4325,6 +4339,104 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [] # skip other tensors


@ModelBase.register("Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA3N
norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code

_altup_proj: list[Tensor] = []
_altup_unembd: list[Tensor] = []

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
self._altup_proj = [
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
]
self._altup_unembd = [
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
]

def set_vocab(self):
with open(self.dir_model / "chat_template.jinja") as f:
# quick hack to make sure chat template is added
self.gguf_writer.add_chat_template(f.read())
super().set_vocab()

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])

activation_sparsity_scale = []
for s in self.hparams["activation_sparsity_pattern"]:
normal_dist = torch.distributions.normal.Normal(0, 1)
std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
activation_sparsity_scale.append(std_multiplier.item())
self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)

sliding_window_pattern = []
for t in self.hparams["layer_types"]:
sliding_window_pattern.append(t == "sliding_attention")
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)

def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
has_all = all(m.numel() > 0 for m in matrices)
if not has_all:
return None
else:
return torch.stack(matrices, dim=0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.endswith("_scale"):
name = name + ".weight"

# TODO: implement self.prediction_coefs.weight.clamp_(...)

if "language_model." not in name:
return [] # skip non-language model tensors

if "altup_unembed_projections" in name:
data_torch = data_torch.to(device="cpu")
if ".0." in name:
self._altup_unembd[0] = data_torch
elif ".1." in name:
self._altup_unembd[1] = data_torch
elif ".2." in name:
self._altup_unembd[2] = data_torch
else:
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_unembd)
if out is not None:
return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
else:
return []

if "altup_projections" in name:
data_torch = data_torch.to(device="cpu")
if ".0." in name:
self._altup_proj[0] = data_torch
elif ".1." in name:
self._altup_proj[1] = data_torch
elif ".2." in name:
self._altup_proj[2] = data_torch
else:
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_proj)
if out is not None:
return [(self.map_tensor_name("model.altup_projections.weight"), out)]
else:
return []

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Starcoder2ForCausalLM")
class StarCoder2Model(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER2
Expand Down
75 changes: 75 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class LLM:
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
Expand All @@ -142,6 +146,8 @@ class Attention:
SCALE = "{arch}.attention.scale"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand Down Expand Up @@ -314,6 +320,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
GEMMA3N = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
Expand Down Expand Up @@ -399,6 +406,22 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
PER_LAYER_TOKEN_EMBD = auto() # gemma3n
PER_LAYER_MODEL_PROJ = auto() # gemma3n
PER_LAYER_INP_GATE = auto() # gemma3n
PER_LAYER_PROJ = auto() # gemma3n
PER_LAYER_PROJ_NORM = auto() # gemma3n
PER_LAYER_POST_NORM = auto() # gemma3n
ALTUP_PROJ = auto() # gemma3n
ALTUP_UNEMBD_PROJ = auto() # gemma3n
ALTUP_CORRECT_COEF = auto() # gemma3n
ALTUP_CORRECT_SCALE = auto() # gemma3n
ALTUP_PREDICT_COEF = auto() # gemma3n
ALTUP_ROUTER = auto() # gemma3n
ALTUP_ROUTER_NORM = auto() # gemma3n
LAUREL_L = auto() # gemma3n
LAUREL_R = auto() # gemma3n
LAUREL_POST_NORM = auto() # gemma3n
SSM_IN = auto()
SSM_CONV1D = auto()
SSM_X = auto()
Expand Down Expand Up @@ -597,6 +620,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
Expand Down Expand Up @@ -682,6 +706,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n
MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n
MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n
MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n
MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n
MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n
MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n
MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n
MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n
MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n
MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n
MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n
MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
Expand Down Expand Up @@ -1486,6 +1526,41 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GEMMA3N: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
# altup / laurel
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
MODEL_TENSOR.PER_LAYER_INP_GATE,
MODEL_TENSOR.PER_LAYER_PROJ,
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
MODEL_TENSOR.ALTUP_PROJ,
MODEL_TENSOR.ALTUP_UNEMBD_PROJ,
MODEL_TENSOR.ALTUP_CORRECT_COEF,
MODEL_TENSOR.ALTUP_CORRECT_SCALE,
MODEL_TENSOR.ALTUP_PREDICT_COEF,
MODEL_TENSOR.ALTUP_ROUTER,
MODEL_TENSOR.ALTUP_ROUTER_NORM,
MODEL_TENSOR.LAUREL_L,
MODEL_TENSOR.LAUREL_R,
MODEL_TENSOR.LAUREL_POST_NORM,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
18 changes: 18 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,18 @@ def add_parallel_residual(self, use: bool) -> None:
def add_decoder_start_token_id(self, id: int) -> None:
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)

def add_embedding_length_per_layer_input(self, value: int) -> None:
self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)

def add_altup_active_idx(self, val: int) -> None:
self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)

def add_altup_num_inputs(self, val: int) -> None:
self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)

def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)

def add_head_count(self, count: int | Sequence[int]) -> None:
if isinstance(count, int):
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
Expand Down Expand Up @@ -702,6 +714,12 @@ def add_max_alibi_bias(self, bias: float) -> None:
def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

def add_shared_kv_layers(self, value: float) -> None:
self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)

def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

Expand Down
Loading
Loading