llama: add initial support for Falcon-H1 model family #14534


Draft · wants to merge 26 commits into master
155 changes: 153 additions & 2 deletions convert_hf_to_gguf.py
@@ -686,6 +686,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e":
# ref: https://huggingface.co/tiiuae/Falcon3-7B-Base
res = "falcon3"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/collections/tiiuae/falcon-h1-6819f2795bc406da60fab8df
res = "falcon_h1"
if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7":
# ref: https://huggingface.co/BAAI/bge-large-zh-v1.5
res = "bert-bge-large"
@@ -4879,6 +4882,9 @@ def set_vocab(self):
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# if architecture is FalconH1, don't pad vocab size
if self.hparams.get("architectures", [None])[0] == "FalconH1ForCausalLM":
pad_vocab = 1
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size

@@ -4905,8 +4911,11 @@ def set_gguf_parameters(self):

# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
# skip the assertion for FalconH1 Model
architectures = self.hparams.get("architectures")
if architectures is None or architectures[0] != "FalconH1ForCausalLM":
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
Collaborator comment on lines 4913 to +4918:

That check isn't strictly necessary anyway, and probably should be removed altogether (and also in src/llama-model.cpp).

If you want to keep it for now, but not for Falcon-H1, does self.model_arch correspond to gguf.MODEL_ARCH.FALCON_H1 when it's that arch? (Might be simpler than reading hparams)
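
A minimal sketch of the reviewer's suggested check, assuming self.model_arch is set to gguf.MODEL_ARCH.FALCON_H1 on the Falcon-H1 converter class (it is, further down in this diff):

# Sketch only: gate the assertions on the registered arch instead of reading hparams.
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
    assert d_inner == 2 * d_model
    assert d_inner % head_dim == 0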


self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
@@ -4945,6 +4954,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
architectures = self.hparams.get("architectures")
if architectures is not None and architectures[0] == "FalconH1ForCausalLM":
# FalconH1 has a different d_inner
d_inner = self.hparams.get("mamba_d_ssm")
Collaborator comment on lines 4955 to +4960:

Another (maybe simpler) approach would be to add "mamba_d_ssm" to the find_hparam call for d_inner above.
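
A sketch of that alternative, replacing the architecture special case with an extra key in the earlier find_hparam call (this assumes "mamba_d_ssm" only appears in Falcon-H1 configs):

# Sketch only: prefer "mamba_d_ssm" when the config provides it (Falcon-H1),
# otherwise fall back to the existing keys.
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model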

data_torch = data_torch.reshape((n_group, d_inner // n_group))

if name.endswith(".A_log"):
@@ -6535,6 +6548,144 @@ def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1

def __init__(self, *args, **kwargs):
# Set the hparam prefixes for Falcon Mamba2
self.hparam_prefixes = ["mamba"]

# Initialize the base Mamba2Model
super().__init__(*args, **kwargs)

# Use Llama conversion for attention
self._transformer_model_class = LlamaModel

# n_group and d_inner are used during reshape_tensors for mamba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model

# Initialize any Falcon-H1 specific attributes
self.has_attention = True # Falcon-H1 has attention components

# Load Falcon-H1 multipliers from hyperparameters
self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
self.intermediate_size = self.find_hparam(["intermediate_size"])

def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)

def _generate_mup_vector(self, block_id: int) -> torch.Tensor:
zxbcdt_multipliers = self.hparams["ssm_multipliers"]
intermediate_size = self.hparams["mamba_d_ssm"]
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])

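# Note added for clarity: the vector spans the mamba in_proj ("zxBCdt") output,
# laid out as [z (d_ssm), x (d_ssm), B (n_groups * d_state), C (n_groups * d_state), dt (n_heads)];
# each of those five segments is scaled by its own multiplier below.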
mup_vector = torch.ones(1, 1, vector_shape)
mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3]
mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4]

return mup_vector

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, tensor in super().get_tensors():
Collaborator comment on lines +6608 to +6609 (a sketch of this follows after the get_tensors method below):

If possible, modify_tensors should be overridden instead of get_tensors to rename and/or insert tensors.

if name.startswith("model.backbone") or name.startswith("model.lm_head"):
name = name.removeprefix("model.")
yield name, tensor

if self.ssm_multipliers is not None:
# Insert MUP vector after mamba.dt_bias
if "mamba.dt_bias" in name:
block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name)
if block_match:
block_id = int(block_match.group(1))
# Generate MUP vector with correct name format
mup_tensor = self._generate_mup_vector(block_id)
mup_name = f"blk.{block_id}.ssm_mup_vec"
logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}")
yield mup_name, mup_tensor
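
A rough sketch of the reviewer's suggestion (overriding modify_tensors instead of get_tensors). This is an assumed shape, not the PR's code; the model.backbone / model.lm_head prefix renaming would still need separate handling:

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    # Let the Mamba2 base class map and rename the tensor first.
    tensors = list(super().modify_tensors(data_torch, name, bid))
    # Emit the per-block MUP vector alongside the mapped dt_bias tensor.
    if self.ssm_multipliers is not None and bid is not None and name.endswith(".mamba.dt_bias"):
        tensors.append((f"blk.{bid}.ssm_mup_vec", self._generate_mup_vector(bid)))
    return tensors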

def set_gguf_parameters(self):
super().set_gguf_parameters()

## General Params ##
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

## Attention params ##
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_key_length(self.hparams["head_dim"])
self.gguf_writer.add_value_length(self.hparams["head_dim"])
self.gguf_writer.add_float64("falcon_h1.key_multiplier", self.hparams["key_multiplier"])

## Other params
self.gguf_writer.add_float64("falcon_h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
self.gguf_writer.add_float64("falcon_h1.embedding_multiplier", self.hparams["embedding_multiplier"])

## Validation ##
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"


# Add Falcon Mamba2 specific configuration
self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_chunk_size", self.hparams["mamba_chunk_size"])
self.gguf_writer.add_uint32("falcon_h1.attention.head_dim", self.hparams["head_dim"])
self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"])
self.gguf_writer.add_uint32("falcon_h1.num_attention_heads", self.find_hparam(["num_attention_heads"]))
self.gguf_writer.add_uint32("falcon_h1.num_key_value_heads",
self.find_hparam(["num_key_value_heads"], optional=True) or
self.find_hparam(["num_attention_heads"]))

# Add multipliers as metadata instead of tensors
self.gguf_writer.add_float64("falcon_h1.attention_in_multiplier", self.attention_in_multiplier)
self.gguf_writer.add_float64("falcon_h1.attention_out_multiplier", self.attention_out_multiplier)
self.gguf_writer.add_float64("falcon_h1.ssm_in_multiplier", self.ssm_in_multiplier)
self.gguf_writer.add_float64("falcon_h1.ssm_out_multiplier", self.ssm_out_multiplier)

# Add MLP multipliers
if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
self.gguf_writer.add_float64("falcon_h1.mlp_gate_multiplier", self.mlp_multipliers[0])
self.gguf_writer.add_float64("falcon_h1.mlp_down_multiplier", self.mlp_multipliers[1])

# Add has MuP flag if SSM multipliers are present
if self.ssm_multipliers is not None:
self.gguf_writer.add_bool("falcon_h1.ssm.has_mup", True)

# Add any other Falcon Mamba2 specific configuration
self.gguf_writer.add_bool("falcon_h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True))
self.gguf_writer.add_bool("falcon_h1.mamba_norm_before_gate", self.find_hparam(["mamba_norm_before_gate"], optional=True))
self.gguf_writer.add_bool("falcon_h1.mamba_rms_norm", self.find_hparam(["mamba_rms_norm"], optional=True))
self.gguf_writer.add_float64("falcon_h1.rope_theta", self.find_hparam(["rope_theta"], optional=True))

###### CONVERSION LOGIC ######


40 changes: 40 additions & 0 deletions gguf-py/gguf/constants.py
@@ -172,6 +172,7 @@ class SSM:
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
HEAD_DIM = "{arch}.ssm.head_dim"
Collaborator comment:

The head dimension in Mamba-2 is also the time step rank.

I guess it could be clearer to use a more appropriate name like this, though.

I'm not against, this is only to at least let you know.


class WKV:
HEAD_SIZE = "{arch}.wkv.head_size"
@@ -288,6 +289,7 @@ class MODEL_ARCH(IntEnum):
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
FALCON_H1 = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
@@ -525,6 +527,7 @@ class MODEL_TENSOR(IntEnum):
POSNET_ATTN_K = auto()
POSNET_ATTN_V = auto()
POSNET_ATTN_OUT = auto()
SSM_MUP_VEC = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -660,6 +663,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.FALCON_H1: "falcon_h1",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -736,6 +740,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_MUP_VEC: "blk.{bid}.ssm_mup_vec",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -2211,6 +2216,41 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,

# Input layernorm
MODEL_TENSOR.ATTN_NORM,

# Attention components
MODEL_TENSOR.ATTN_Q, # Query projection
MODEL_TENSOR.ATTN_K, # Key projection
MODEL_TENSOR.ATTN_V, # Value projection
MODEL_TENSOR.ATTN_OUT, # Output projection

# SSM components (Mamba2 specific)
MODEL_TENSOR.SSM_MUP_VEC, # Mup vector
MODEL_TENSOR.SSM_IN, # Input projection for SSM
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
MODEL_TENSOR.SSM_DT, # Delta time projection
MODEL_TENSOR.SSM_A, # A parameter (log form)
MODEL_TENSOR.SSM_D, # D parameter
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
MODEL_TENSOR.SSM_OUT, # Output projection

# Pre-feedforward layernorm
MODEL_TENSOR.FFN_PRE_NORM,

# Feed-forward network components
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
MODEL_TENSOR.FFN_DOWN, # Down projection
MODEL_TENSOR.FFN_UP, # Up projection

# Post-feedforward layernorm
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
],
# TODO
}

9 changes: 9 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -867,6 +867,15 @@ def add_ssm_group_count(self, value: int) -> None:
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

def add_ssm_head_dim(self, value: int) -> None:
self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)

def add_attn_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)

def add_key_value_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)

Collaborator comment on lines +873 to +878:

Don't these already exist?
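
They do appear to duplicate the writer's existing head-count setters, which the Falcon-H1 converter above already calls as gguf_writer.add_head_count / add_head_count_kv. A rough sketch of what those existing methods presumably look like; signatures are approximate:

def add_head_count(self, count: int | Sequence[int]) -> None:
    if isinstance(count, int):
        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
    else:
        self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)

def add_head_count_kv(self, count: int | Sequence[int]) -> None:
    if isinstance(count, int):
        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
    else:
        self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)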

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

17 changes: 17 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -286,12 +286,14 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"model.layers.{bid}.pre_ff_layernorm.weight",
),

# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.feed_forward.up_proj",
),

MODEL_TENSOR.FFN_GATE_INP: (
@@ -362,6 +364,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.feed_forward.down_proj",
),

# AWQ-activation gate
@@ -547,11 +550,13 @@ class TensorNameMap:
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.mamba.in_proj",
),

MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.mamba.conv1d",
),

MODEL_TENSOR.SSM_X: (
@@ -562,16 +567,19 @@ class TensorNameMap:
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.mamba.dt_proj",
),

MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.mamba.A_log",
),

MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.mamba.D",
),

MODEL_TENSOR.SSM_NORM: (
@@ -581,6 +589,7 @@ class TensorNameMap:
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
"model.layers.{bid}.mamba.out_proj", # falcon-h1
),

MODEL_TENSOR.TIME_MIX_W0: (
@@ -1168,6 +1177,14 @@ class TensorNameMap:
"resampler.attn.out_proj",
),

MODEL_TENSOR.SSM_MUP_VEC: (
"model.layers.{bid}.mamba.mup_vector", # falcon_h1
),

MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm",
),

MODEL_TENSOR.V_RESMPL_KV: (
"resampler.kv_proj",
),