Add plamo2 #13930

Draft · wants to merge 15 commits into master
178 changes: 178 additions & 0 deletions convert_hf_to_gguf.py
@@ -3476,6 +3476,184 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        return [(new_name, data_torch)]


@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
class Plamo2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.PLAMO2

    def set_vocab(self):
        # PLaMo 2 uses a custom tokenizer with a .jsonl file
        # We need to handle this specially
        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
        tokenizer_config_path = self.dir_model / "tokenizer_config.json"

        if not tokenizer_jsonl_path.is_file():
            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")

        # Load tokenizer config
        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
            tokenizer_config = json.load(f)

        # Load tokens from JSONL file (actually a list format)
        tokens = []
        scores = []
        toktypes = []

        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f):
                if line.strip():
                    token_data = json.loads(line)
                    # Format: [token, score, type, ?, ?, ?, ?]
                    token = token_data[0].encode("utf-8")
                    score = float(token_data[1])
                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"

                    tokens.append(token)
                    scores.append(score)

                    # Map token type strings to GGUF token types
                    if token_type_str == "UNKNOWN":
                        toktypes.append(gguf.TokenType.UNKNOWN)
                    elif token_type_str == "CONTROL":
                        toktypes.append(gguf.TokenType.CONTROL)
                    elif token_type_str == "BYTE":
                        toktypes.append(gguf.TokenType.BYTE)
                    else:
                        # Check for PLaMo-2 special tokens
                        token_str = token_data[0]
                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
                            toktypes.append(gguf.TokenType.CONTROL)
                        else:
                            toktypes.append(gguf.TokenType.NORMAL)

        # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
        self.gguf_writer.add_tokenizer_model("plamo2")
        self.gguf_writer.add_tokenizer_pre("default")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        # Add special tokens from config
        if "bos_token_id" in tokenizer_config:
            self.gguf_writer.add_bos_token_id(tokenizer_config["bos_token_id"])
        if "eos_token_id" in tokenizer_config:
            self.gguf_writer.add_eos_token_id(tokenizer_config["eos_token_id"])
        if "pad_token_id" in tokenizer_config:
            self.gguf_writer.add_pad_token_id(tokenizer_config["pad_token_id"])
        if "unk_token_id" in tokenizer_config:
            self.gguf_writer.add_unk_token_id(tokenizer_config["unk_token_id"])

        self.gguf_writer.add_add_space_prefix(False)
    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
        self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
        self.gguf_writer.add_group_norm_eps(hparams.get("rms_norm_eps", 1e-06))
        self.gguf_writer.add_layer_norm_eps(hparams.get("rms_norm_eps", 1e-06))
        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))

        # Mamba parameters
        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
        self.gguf_writer.add_ssm_num_heads(hparams.get("mamba_num_heads", 64))
        self.gguf_writer.add_ssm_head_dim(hparams.get("hidden_size_per_head", 128))
        self.gguf_writer.add_ssm_inner_size(hparams.get("hidden_size_per_head", 128) * hparams.get("mamba_num_heads", 64))
        self.gguf_writer.add_ssm_time_step_rank(hparams.get("time_step_limit", 192))
        self.gguf_writer.add_ssm_dt_min(hparams.get("time_step_min", 0.001))
        self.gguf_writer.add_ssm_dt_max(hparams.get("time_step_max", 0.1))
        self.gguf_writer.add_hybrid_mamba_step(hparams.get("mamba_step", 2))

        # MLP feed forward parameters (for attention layers)
        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))

        # Which layers are Mamba layers
        # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
        # This logic matches modeling_plamo.py's is_mamba function
        mamba_step = hparams.get("mamba_step", 2)
        mamba_enabled = hparams.get("mamba_enabled", True)
        mamba_layers = []

        if mamba_enabled:
            for i in range(block_count):
                if block_count <= (mamba_step // 2):
                    # use attention in last layer
                    is_mamba = (i != block_count - 1)
                else:
                    is_mamba = (i % mamba_step) != (mamba_step // 2)
                if is_mamba:
                    mamba_layers.append(i)

        if mamba_layers:
            self.gguf_writer.add_hybrid_mamba_layers(mamba_layers)

        self.gguf_writer.add_file_type(self.ftype)
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if name.endswith(".embed_tokens.weight"):
            # If there is no lm_head, we need to map the token embedding to the output layer
            assert self.tensor_names is not None
            if all(['lm_head' not in name for name in self.tensor_names]):
                name_base = name.replace(".embed_tokens.weight", "")
                output_name = "lm_head"

                embed_tokens_mapped = self.map_tensor_name(name)
                output_mapped = self.map_tensor_name(output_name) + ".weight"

                return [(embed_tokens_mapped, data_torch), (output_mapped, data_torch)]
        elif name.endswith(".dt_bias"):
            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
        elif name.endswith(".dt_norm_weight"):
            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
        elif name.endswith(".B_norm_weight"):
            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
        elif name.endswith(".C_norm_weight"):
            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
        elif name.endswith(".k_weight"):
            name = name.rpartition(".k_weight")[0] + ".k.weight"
        elif name.endswith(".q_weight"):
            name = name.rpartition(".q_weight")[0] + ".q.weight"
        elif name.endswith(".conv1d.weight"):
            data_torch = torch.squeeze(data_torch)  # drop the middle dimension of size 1
            assert data_torch.ndim == 2
        elif name.endswith(".pre_mixer_norm.weight"):
            data_torch += 1.0
        elif name.endswith(".post_mixer_norm.weight"):
            data_torch += 1.0 / 5
        elif name.endswith(".pre_mlp_norm.weight"):
            data_torch += 1.0
        elif name.endswith(".post_mlp_norm.weight"):
            data_torch += 1.0 / (5**1.5)
        elif name.endswith(".norm.weight"):
            data_torch += 1.0
        elif name.endswith(".gate_up_proj.weight"):
            # Split the combined gate_up tensor
            split_size = data_torch.shape[0] // 2
            gate_tensor = data_torch[:split_size, :]
            up_tensor = data_torch[split_size:, :]

            # Return both tensors - remove .weight suffix if present
            name_base = name.replace(".gate_up_proj.weight", "")
            gate_name = name_base + ".ffn_gate.weight"
            up_name = name_base + ".ffn_up.weight"

            gate_mapped = self.map_tensor_name(gate_name)
            up_mapped = self.map_tensor_name(up_name)

            return [(gate_mapped, gate_tensor), (up_mapped, up_tensor)]

        new_name = self.map_tensor_name(name)

        return [(new_name, data_torch)]
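
As an aside (not part of the diff), here is a minimal sketch of the gate_up_proj split performed in modify_tensors above: the first half of the rows becomes ffn_gate and the second half becomes ffn_up. The tensor shape used here is invented purely for illustration.

# Illustrative only: splitting a pretend combined [2*ffn, hidden] weight, here [6, 2].
import torch

combined = torch.arange(12.0).reshape(6, 2)
split_size = combined.shape[0] // 2
gate_tensor = combined[:split_size, :]
up_tensor = combined[split_size:, :]
print(gate_tensor.shape, up_tensor.shape)  # torch.Size([3, 2]) torch.Size([3, 2])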


@ModelBase.register("CodeShellForCausalLM")
class CodeShellModel(TextModel):
    model_arch = gguf.MODEL_ARCH.CODESHELL
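
For readers unfamiliar with the tokenizer format handled by set_vocab above, here is a small illustrative sketch (not part of the diff): each non-empty line of tokenizer.jsonl is a JSON list laid out as [token, score, type, ...]. The concrete tokens and scores below are invented for illustration; the list layout, the type strings, and the <|plamo:...|> prefix check come from the code above.

# Illustrative only: parse a few hypothetical tokenizer.jsonl lines the way set_vocab does.
import json

sample_lines = [
    '["<|plamo:pad|>", 0.0, "CONTROL"]',  # hypothetical special token; matches the <|plamo:...|> check
    '["hello", -10.5, "NORMAL"]',         # ordinary token with an illustrative score
    '["<0x0A>", 0.0, "BYTE"]',            # hypothetical byte-fallback entry
]

for line in sample_lines:
    token_data = json.loads(line)
    token, score = token_data[0], float(token_data[1])
    token_type = token_data[2] if len(token_data) > 2 else "NORMAL"
    print(token, score, token_type)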
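A similar sketch of the Mamba/attention layer pattern computed in set_gguf_parameters: with the default mamba_step = 2, layer i is a Mamba layer whenever i % mamba_step differs from mamba_step // 2, so even-indexed layers use Mamba and odd-indexed layers keep attention. Assuming a hypothetical 8-layer config:

# Sketch of the is_mamba selection above for an assumed block_count of 8.
block_count = 8
mamba_step = 2

mamba_layers = []
for i in range(block_count):
    if block_count <= (mamba_step // 2):
        # degenerate case from the code above: use attention only in the last layer
        is_mamba = (i != block_count - 1)
    else:
        is_mamba = (i % mamba_step) != (mamba_step // 2)
    if is_mamba:
        mamba_layers.append(i)

print(mamba_layers)  # [0, 2, 4, 6] -> layers 1, 3, 5, 7 remain attention layers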
2 changes: 1 addition & 1 deletion examples/eval-callback/eval-callback.cpp
@@ -122,7 +122,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {

    if (!ggml_is_quantized(t->type)) {
        uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
        ggml_print_tensor(data, t->type, t->ne, t->nb, 256);
    }

    return true;
55 changes: 55 additions & 0 deletions gguf-py/gguf/constants.py
@@ -172,6 +172,14 @@ class SSM:
        TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
        GROUP_COUNT    = "{arch}.ssm.group_count"
        DT_B_C_RMS     = "{arch}.ssm.dt_b_c_rms"
        DT_MIN         = "{arch}.ssm.dt_min"
        DT_MAX         = "{arch}.ssm.dt_max"
        NUM_HEADS      = "{arch}.ssm.num_heads"
        HEAD_DIM       = "{arch}.ssm.head_dim"

    class Hybrid:
        MAMBA_LAYERS = "{arch}.hybrid.mamba_layers"
        MAMBA_STEP   = "{arch}.hybrid.mamba_step"

    class WKV:
        HEAD_SIZE = "{arch}.wkv.head_size"
@@ -313,6 +321,7 @@ class MODEL_ARCH(IntEnum):
    PHI3 = auto()
    PHIMOE = auto()
    PLAMO = auto()
    PLAMO2 = auto()
    CODESHELL = auto()
    ORION = auto()
    INTERNLM2 = auto()
@@ -433,6 +442,12 @@ class MODEL_TENSOR(IntEnum):
    SSM_D = auto()
    SSM_NORM = auto()
    SSM_OUT = auto()
    SSM_CONV1D_BIAS = auto()
    SSM_DT_BIAS = auto()
    SSM_BCDT = auto()
    SSM_DT_NORM = auto()
    SSM_B_NORM = auto()
    SSM_C_NORM = auto()
    TIME_MIX_W0 = auto()
    TIME_MIX_W1 = auto()
    TIME_MIX_W2 = auto()
@@ -616,6 +631,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.PHI3: "phi3",
    MODEL_ARCH.PHIMOE: "phimoe",
    MODEL_ARCH.PLAMO: "plamo",
    MODEL_ARCH.PLAMO2: "plamo2",
    MODEL_ARCH.CODESHELL: "codeshell",
    MODEL_ARCH.ORION: "orion",
    MODEL_ARCH.INTERNLM2: "internlm2",
@@ -736,6 +752,12 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
    MODEL_TENSOR.SSM_CONV1D_BIAS: "blk.{bid}.ssm_conv1d_bias",
    MODEL_TENSOR.SSM_DT_BIAS: "blk.{bid}.ssm_dt_bias",
    MODEL_TENSOR.SSM_BCDT: "blk.{bid}.ssm_bcdt",
    MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
    MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
    MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
    MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
    MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
    MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -1342,6 +1364,39 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.PLAMO2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_POST_NORM,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_POST_NORM,
        MODEL_TENSOR.SSM_IN,
        MODEL_TENSOR.SSM_CONV1D,
        MODEL_TENSOR.SSM_X,
        MODEL_TENSOR.SSM_DT,
        MODEL_TENSOR.SSM_DT_BIAS,
        MODEL_TENSOR.SSM_A,
        MODEL_TENSOR.SSM_D,
        MODEL_TENSOR.SSM_OUT,
        MODEL_TENSOR.SSM_BCDT,
        MODEL_TENSOR.SSM_DT_NORM,
        MODEL_TENSOR.SSM_B_NORM,
        MODEL_TENSOR.SSM_C_NORM,
    ],
    MODEL_ARCH.GPT2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,
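As a quick illustration (not part of the diff) of how the new SSM and Hybrid key templates from constants.py expand once the architecture name is substituted, using the "plamo2" name registered in MODEL_ARCH_NAMES above:

# Illustrative expansion of the new key templates for arch="plamo2".
arch = "plamo2"
print("{arch}.ssm.dt_min".format(arch=arch))           # plamo2.ssm.dt_min
print("{arch}.ssm.num_heads".format(arch=arch))        # plamo2.ssm.num_heads
print("{arch}.hybrid.mamba_layers".format(arch=arch))  # plamo2.hybrid.mamba_layers
print("{arch}.hybrid.mamba_step".format(arch=arch))    # plamo2.hybrid.mamba_step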
18 changes: 18 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -867,6 +867,24 @@ def add_ssm_group_count(self, value: int) -> None:
    def add_ssm_dt_b_c_rms(self, value: bool) -> None:
        self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

    def add_ssm_dt_min(self, value: float) -> None:
        self.add_float32(Keys.SSM.DT_MIN.format(arch=self.arch), value)

    def add_ssm_dt_max(self, value: float) -> None:
        self.add_float32(Keys.SSM.DT_MAX.format(arch=self.arch), value)

    def add_ssm_num_heads(self, value: int) -> None:
        self.add_uint32(Keys.SSM.NUM_HEADS.format(arch=self.arch), value)

    def add_ssm_head_dim(self, value: int) -> None:
        self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)

    def add_hybrid_mamba_layers(self, layers: list[int]) -> None:
        self.add_array(Keys.Hybrid.MAMBA_LAYERS.format(arch=self.arch), layers)

    def add_hybrid_mamba_step(self, step: int) -> None:
        self.add_uint32(Keys.Hybrid.MAMBA_STEP.format(arch=self.arch), step)

    def add_tokenizer_model(self, model: str) -> None:
        self.add_string(Keys.Tokenizer.MODEL, model)

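Finally, a hedged usage sketch of the new writer methods, mirroring the calls made by Plamo2Model.set_gguf_parameters earlier in this diff. The output path and the GGUFWriter constructor arguments are assumptions for illustration, not part of this change.

# Sketch only: exercising the new add_ssm_* / add_hybrid_* helpers.
# GGUFWriter("plamo2.gguf", "plamo2") is an assumed constructor call; adjust to the real API if it differs.
from gguf import GGUFWriter

writer = GGUFWriter("plamo2.gguf", "plamo2")
writer.add_ssm_dt_min(0.001)
writer.add_ssm_dt_max(0.1)
writer.add_ssm_num_heads(64)
writer.add_ssm_head_dim(128)
writer.add_hybrid_mamba_layers([0, 2, 4, 6])
writer.add_hybrid_mamba_step(2)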