llama: add initial support for Falcon-H1 model family #14534
base: master
Changes from all commits
991de6c, f897efd, 71a6848, 03568c9, 0c93ef6, fdd5cff, 14c37ec, 8bea922, 071f4b7, 50eadc7, a39a842, 1415cd8, 243e4d1, cce3549, 22de62c, 2fe057c, d22b4ea, 6c7d9e2, 15138df, a6d0067, 1fd0574, 250b4f1, 3ee7983, 2aa48dd, 9760c8b, 7a25441
convert_hf_to_gguf.py
@@ -686,6 +686,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e":
             # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base
             res = "falcon3"
+        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
+            # ref: https://huggingface.co/collections/tiiuae/falcon-h1-6819f2795bc406da60fab8df
+            res = "falcon_h1"
         if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7":
             # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5
             res = "bert-bge-large"
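For reviewers unfamiliar with this gate: get_vocab_base_pre hashes how the tokenizer encodes a fixed probe string and maps known digests to a pre-tokenizer name. A simplified sketch follows; the probe text and exact hashing details are assumptions, and only the digest-to-name mapping is taken from the hunk above.

from hashlib import sha256

# Simplified illustration of the pre-tokenizer detection (not the exact
# convert_hf_to_gguf.py code): hash the tokenization of a fixed probe
# string and look the digest up in a table of known pre-tokenizers.
def detect_pretokenizer(tokenizer, chktxt: str) -> str | None:
    chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
    known = {
        "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": "falcon3",
        "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86": "falcon_h1",
    }
    return known.get(chkhsh)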
@@ -4879,6 +4882,9 @@ def set_vocab(self):
         pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
         # pad using ceiling division
         # ref: https://stackoverflow.com/a/17511341/22827863
+        # if architecture is FalconH1, don't pad vocab size
+        if self.hparams.get("architectures", [None])[0] == "FalconH1ForCausalLM":
+            pad_vocab = 1
         vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
         self.hparams["vocab_size"] = vocab_size
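A quick sanity check of the padding arithmetic above, with made-up numbers:

# -(n // -m) is ceiling division, so this rounds the vocab size up to the
# next multiple of pad_vocab.
vocab_size, pad_vocab = 65027, 16
assert -(vocab_size // -pad_vocab) * pad_vocab == 65040
# The Falcon-H1 branch sets pad_vocab = 1, which leaves the vocab size unchanged.
assert -(vocab_size // -1) * 1 == vocab_size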
@@ -4905,8 +4911,11 @@ def set_gguf_parameters(self):
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        # skip the assertion for FalconH1 Model
+        architectures = self.hparams.get("architectures")
+        if architectures is None or architectures[0] != "FalconH1ForCausalLM":
+            assert d_inner == 2 * d_model
+            assert d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
@@ -4945,6 +4954,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
             d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
             n_group = self.hparams.get("n_groups", 1)
+            architectures = self.hparams.get("architectures")
+            if architectures is not None and architectures[0] == "FalconH1ForCausalLM":
+                # FalconH1 has a different d_inner
+                d_inner = self.hparams.get("mamba_d_ssm")

Comment on lines +4955 to +4960: Another (maybe simpler) approach would be to add …

             data_torch = data_torch.reshape((n_group, d_inner // n_group))
 
         if name.endswith(".A_log"):
@@ -6535,6 +6548,144 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
 
+@ModelBase.register("FalconH1ForCausalLM")
+class FalconH1Model(Mamba2Model):
+    model_arch = gguf.MODEL_ARCH.FALCON_H1
+
+    def __init__(self, *args, **kwargs):
+        # Set the hparam prefixes for Falcon Mamba2
+        self.hparam_prefixes = ["mamba"]
+
+        # Initialize the base Mamba2Model
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class = LlamaModel
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+        # Initialize any Falcon Mamba2 specific attributes
+        self.has_attention = True  # Falcon Mamba2 has attention components
+
+        # Load Falcon-H1 multipliers from hyperparameters
+        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
+        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
+        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
+        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
+        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
+        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
+        self.intermediate_size = self.find_hparam(["intermediate_size"])
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def _generate_mup_vector(self, block_id: int) -> torch.Tensor:
+        zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+        intermediate_size = self.hparams["mamba_d_ssm"]
+        groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+        vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])
+
+        mup_vector = torch.ones(1, 1, vector_shape)
+        mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
+        mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
+        mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
+        mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3]
+        mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4]
+
+        return mup_vector
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, tensor in super().get_tensors():

Comment on lines +6608 to +6609: If possible, …

+            if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+                name = name.removeprefix("model.")
+            yield name, tensor
+
+            if self.ssm_multipliers is not None:
+                # Insert MUP vector after mamba.dt_bias
+                if "mamba.dt_bias" in name:
+                    block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name)
+                    if block_match:
+                        block_id = int(block_match.group(1))
+                        # Generate MUP vector with correct name format
+                        mup_tensor = self._generate_mup_vector(block_id)
+                        mup_name = f"blk.{block_id}.ssm_mup_vec"
+                        logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}")
+                        yield mup_name, mup_tensor
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        ## General Params ##
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_key_length(self.hparams["head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+        self.gguf_writer.add_float64("falcon_h1.key_multiplier", self.hparams["key_multiplier"])
+
+        ## Other params
+        self.gguf_writer.add_float64("falcon_h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
+        self.gguf_writer.add_float64("falcon_h1.embedding_multiplier", self.hparams["embedding_multiplier"])
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+        # Add Falcon Mamba2 specific configuration
+        self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_chunk_size", self.hparams["mamba_chunk_size"])
+        self.gguf_writer.add_uint32("falcon_h1.attention.head_dim", self.hparams["head_dim"])
+        self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"])
+        self.gguf_writer.add_uint32("falcon_h1.num_attention_heads", self.find_hparam(["num_attention_heads"]))
+        self.gguf_writer.add_uint32("falcon_h1.num_key_value_heads",
+                                    self.find_hparam(["num_key_value_heads"], optional=True) or
+                                    self.find_hparam(["num_attention_heads"]))
+
+        # Add multipliers as metadata instead of tensors
+        self.gguf_writer.add_float64("falcon_h1.attention_in_multiplier", self.attention_in_multiplier)
+        self.gguf_writer.add_float64("falcon_h1.attention_out_multiplier", self.attention_out_multiplier)
+        self.gguf_writer.add_float64("falcon_h1.ssm_in_multiplier", self.ssm_in_multiplier)
+        self.gguf_writer.add_float64("falcon_h1.ssm_out_multiplier", self.ssm_out_multiplier)
+
+        # Add MLP multipliers
+        if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
+            self.gguf_writer.add_float64("falcon_h1.mlp_gate_multiplier", self.mlp_multipliers[0])
+            self.gguf_writer.add_float64("falcon_h1.mlp_down_multiplier", self.mlp_multipliers[1])
+
+        # Add has MuP flag if SSM multipliers are present
+        if self.ssm_multipliers is not None:
+            self.gguf_writer.add_bool("falcon_h1.ssm.has_mup", True)
+
+        # Add any other Falcon Mamba2 specific configuration
+        self.gguf_writer.add_bool("falcon_h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True))
+        self.gguf_writer.add_bool("falcon_h1.mamba_norm_before_gate", self.find_hparam(["mamba_norm_before_gate"], optional=True))
+        self.gguf_writer.add_bool("falcon_h1.mamba_rms_norm", self.find_hparam(["mamba_rms_norm"], optional=True))
+        self.gguf_writer.add_float64("falcon_h1.rope_theta", self.find_hparam(["rope_theta"], optional=True))
+
+
 ###### CONVERSION LOGIC ######
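To make the _generate_mup_vector layout above easier to review, here is a small self-contained check of the segment boundaries. The hyperparameter values are made up; only the slicing mirrors the method, which scales the z, x, B, C and dt blocks of the zxbcdt projection by the five ssm_multipliers.

import torch

# Illustrative sizes, not from a real Falcon-H1 config.
d_ssm, n_groups, d_state, n_heads = 8, 2, 4, 6
zxbcdt = [1.5, 2.0, 0.5, 0.25, 3.0]

gts = n_groups * d_state  # groups_time_state_size
vec = torch.ones(1, 1, 2 * d_ssm + 2 * gts + n_heads)
vec[:, :, :d_ssm] *= zxbcdt[0]                               # z block
vec[:, :, d_ssm:2 * d_ssm] *= zxbcdt[1]                      # x block
vec[:, :, 2 * d_ssm:2 * d_ssm + gts] *= zxbcdt[2]            # B block
vec[:, :, 2 * d_ssm + gts:2 * d_ssm + 2 * gts] *= zxbcdt[3]  # C block
vec[:, :, 2 * d_ssm + 2 * gts:] *= zxbcdt[4]                 # dt block
assert vec.shape == (1, 1, 2 * d_ssm + 2 * gts + n_heads)    # (1, 1, 38) here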
gguf-py/gguf/constants.py
@@ -172,6 +172,7 @@ class SSM:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
+        HEAD_DIM = "{arch}.ssm.head_dim"

Review comment: The head dimension in Mamba-2 is also the time step rank. I guess it could be clearer to use a more appropriate name like this, though. I'm not against it; this is only to let you know.

     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
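These key templates expand per architecture via str.format, so for Falcon-H1 the new entry yields the following key (template string taken from the hunk above, architecture name from the MODEL_ARCH_NAMES addition below):

HEAD_DIM = "{arch}.ssm.head_dim"
assert HEAD_DIM.format(arch="falcon_h1") == "falcon_h1.ssm.head_dim"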
@@ -288,6 +289,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA4 = auto()
     DECI = auto()
     FALCON = auto()
+    FALCON_H1 = auto()
     BAICHUAN = auto()
     GROK = auto()
     GPT2 = auto()
@@ -525,6 +527,7 @@ class MODEL_TENSOR(IntEnum):
     POSNET_ATTN_K = auto()
     POSNET_ATTN_V = auto()
     POSNET_ATTN_OUT = auto()
+    SSM_MUP_VEC = auto()
     # vision
     V_MMPROJ = auto()
     V_MMPROJ_FC = auto()
@@ -660,6 +663,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
+    MODEL_ARCH.FALCON_H1: "falcon_h1",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -736,6 +740,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
     MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+    MODEL_TENSOR.SSM_MUP_VEC: "blk.{bid}.ssm_mup_vec",
     MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -2211,6 +2216,41 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.FALCON_H1: [
+        # Token embedding
+        MODEL_TENSOR.TOKEN_EMBD,
+
+        # Input layernorm
+        MODEL_TENSOR.ATTN_NORM,
+
+        # Attention components
+        MODEL_TENSOR.ATTN_Q,    # Query projection
+        MODEL_TENSOR.ATTN_K,    # Key projection
+        MODEL_TENSOR.ATTN_V,    # Value projection
+        MODEL_TENSOR.ATTN_OUT,  # Output projection
+
+        # SSM components (Mamba2 specific)
+        MODEL_TENSOR.SSM_MUP_VEC,  # Mup vector
+        MODEL_TENSOR.SSM_IN,       # Input projection for SSM
+        MODEL_TENSOR.SSM_CONV1D,   # Convolution layer
+        MODEL_TENSOR.SSM_DT,       # Delta time projection
+        MODEL_TENSOR.SSM_A,        # A parameter (log form)
+        MODEL_TENSOR.SSM_D,        # D parameter
+        MODEL_TENSOR.SSM_NORM,     # Normalization in SSM
+        MODEL_TENSOR.SSM_OUT,      # Output projection
+
+        # Pre-feedforward layernorm
+        MODEL_TENSOR.FFN_PRE_NORM,
+
+        # Feed-forward network components
+        MODEL_TENSOR.FFN_GATE,  # Gate projection (SwiGLU)
+        MODEL_TENSOR.FFN_DOWN,  # Down projection
+        MODEL_TENSOR.FFN_UP,    # Up projection
+
+        # Post-feedforward layernorm
+        MODEL_TENSOR.OUTPUT_NORM,  # Final layer norm
+        MODEL_TENSOR.OUTPUT,       # Output projection (lm_head)
+    ],
     # TODO
 }
gguf-py/gguf/gguf_writer.py
@@ -867,6 +867,15 @@ def add_ssm_group_count(self, value: int) -> None:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
+    def add_ssm_head_dim(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)
+
+    def add_attn_head_count(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+
+    def add_key_value_head_count(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+

Comment on lines +873 to +878: Don't these already exist?

     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
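Regarding the reviewer's question: GGUFWriter already ships add_head_count and add_head_count_kv, which write the same attention head-count keys, so the two new helpers look redundant. Below is a minimal mock illustrating which keys the existing helpers produce; the key strings follow the Keys.Attention pattern used in the hunk, and the writer class is a stand-in, not the real GGUFWriter.

class TinyWriter:
    # Stand-in for GGUFWriter, only to show the keys written by the
    # pre-existing add_head_count / add_head_count_kv helpers.
    def __init__(self, arch: str):
        self.arch = arch
        self.kv: dict[str, int] = {}

    def add_uint32(self, key: str, value: int) -> None:
        self.kv[key] = value

    def add_head_count(self, count: int) -> None:
        self.add_uint32(f"{self.arch}.attention.head_count", count)

    def add_head_count_kv(self, count: int) -> None:
        self.add_uint32(f"{self.arch}.attention.head_count_kv", count)

w = TinyWriter("falcon_h1")
w.add_head_count(32)
w.add_head_count_kv(8)
assert w.kv == {
    "falcon_h1.attention.head_count": 32,
    "falcon_h1.attention.head_count_kv": 8,
}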
Review comment: That check isn't strictly necessary anyway, and probably should be removed altogether (and also in src/llama-model.cpp). If you want to keep it for now, but not for Falcon-H1, does self.model_arch correspond to gguf.MODEL_ARCH.FALCON_H1 when it's that arch? (Might be simpler than reading hparams.)
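If the check is kept, the arch-based gate the reviewer suggests could look roughly like this fragment (a sketch only; it relies on FalconH1Model setting model_arch = gguf.MODEL_ARCH.FALCON_H1 as in the conversion diff above, and would replace the hparams["architectures"] lookups):

# Sketch: gate the Mamba2 shape assertions on the registered model arch
# instead of reading hparams["architectures"].
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
    assert d_inner == 2 * d_model
    assert d_inner % head_dim == 0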