
Commit 4b13c34

MollySophia, LaylBongers, compilade, and ggerganov authored and committed
llama : support RWKV v6 models (ggml-org#8980)
* convert_hf_to_gguf: Add support for RWKV v6
* Add RWKV tokenization
* Fix build
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build
* Load more tensors for rwkv v6
* Fix rwkv tokenizer
* ggml: Add unary operator Exp
* RWKV v6 graph building
* Add ``rescale_every_n_layers`` parameter
* Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters
* Fix offloading layers to CUDA
* Fix parallel inferencing for RWKV
* Remove trailing whitespaces
* build_rwkv: Avoid using inplace operations
* convert_hf_to_gguf: rwkv: Avoid using ``eval``
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually
* Update convert_hf_to_gguf.py (co-authored by compilade)
* ggml: Add backward computation for unary op ``exp``
* Update convert_hf_to_gguf.py (co-authored by compilade)
* Update convert_hf_to_gguf.py (co-authored by compilade)
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
* build_rwkv6: Simplify graph
* llama: rwkv6: Detect model.type
* llama: rwkv6: Fix tensor loading for 7B/14B models
* llama: rwkv6: Fix group_norm assertion failure with Metal
* llama: rwkv6: Clean up
* llama: rwkv6: Add quantization tensor exclusion
* llama: rwkv6: Use the new advanced batch splits
* Update src/llama.cpp (co-authored by compilade)
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm`` (co-authored by compilade)
* llama: rwkv6: Apply code style and misc changes
* converter: Use class name ``Rwkv6Model``
* llama: rwkv6: Make use of key ``feed_forward_length``
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32
* llama: rwkv6: Remove unused nodes
* llama: rwkv6: Apply code format changes
* llama: rwkv6: Add lora for some supported tensors (currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight)
* rwkv : speed-up tokenization using trie (see the sketch after this list)
* minor : style + indentation
* llama: rwkv6: Avoid division by zero (co-authored by compilade)
* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
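One of the squashed changes above, "rwkv : speed-up tokenization using trie", replaces per-token scanning of the vocabulary with a byte trie. The actual implementation is C++ in src/llama.cpp; the following is only a minimal Python sketch of the idea, greedy longest-match over a trie, using hypothetical helper names (TrieNode, trie_insert, trie_tokenize) that do not appear in the commit:

class TrieNode:
    def __init__(self) -> None:
        self.children: dict[int, "TrieNode"] = {}
        self.token_id: int | None = None  # set when a token ends at this node

def trie_insert(root: TrieNode, token: bytes, token_id: int) -> None:
    node = root
    for b in token:
        node = node.children.setdefault(b, TrieNode())
    node.token_id = token_id

def trie_tokenize(root: TrieNode, data: bytes) -> list[int]:
    # Greedy longest match: walk the trie from the current position,
    # remember the last position where a complete token ended, emit it,
    # then restart right after the matched bytes.
    ids: list[int] = []
    pos = 0
    while pos < len(data):
        node, best_id, best_len = root, None, 0
        for i in range(pos, len(data)):
            node = node.children.get(data[i])
            if node is None:
                break
            if node.token_id is not None:
                best_id, best_len = node.token_id, i - pos + 1
        assert best_id is not None, "input byte not covered by the vocab"
        ids.append(best_id)
        pos += best_len
    return ids

With a vocabulary that covers every single byte (which the RWKV world vocab is designed to do), the assert never fires, and each emitted token costs a walk bounded by the longest token rather than a scan over the whole vocabulary.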
1 parent 44ea31d commit 4b13c34

File tree

9 files changed: +1266 -103 lines changed


convert_hf_to_gguf.py

Lines changed: 83 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import ast
 import logging
 import argparse
 import contextlib
@@ -298,9 +299,12 @@ def prepare_tensors(self):
                         gguf.MODEL_TENSOR.POS_EMBD,
                         gguf.MODEL_TENSOR.TOKEN_TYPES,
                         gguf.MODEL_TENSOR.SSM_CONV1D,
+                        gguf.MODEL_TENSOR.TIME_MIX_FIRST,
+                        gguf.MODEL_TENSOR.TIME_MIX_W1,
+                        gguf.MODEL_TENSOR.TIME_MIX_W2,
                     )
                 )
-                or not name.endswith(".weight")
+                or not new_name.endswith(".weight")
             ):
                 data_qtype = gguf.GGMLQuantizationType.F32
@@ -2716,6 +2720,84 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
+@Model.register("Rwkv6ForCausalLM")
+class Rwkv6Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV6
+
+    def set_vocab(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_size = self.hparams["head_size"]
+        hidden_size = self.hparams["hidden_size"]
+        layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        rescale_every_n_layers = self.hparams["rescale_every"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
+        time_mix_extra_dim = 64 if hidden_size == 4096 else 32
+        time_decay_extra_dim = 128 if hidden_size == 4096 else 64
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+        self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        new_name = self.map_tensor_name(name)
+
+        if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+            new_name += ".weight"
+
+        if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
+            data_torch = data_torch.transpose(0, 1)
+
+        if new_name.endswith("time_mix_w2.weight"):
+            data_torch = data_torch.permute(0, 2, 1)
+
+        rescale_every_n_layers = self.hparams["rescale_every"]
+        if rescale_every_n_layers > 0:
+            if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+                data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+
+        yield (new_name, data_torch)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
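For context on the set_vocab logic above: each line of rwkv_vocab_v20230424.txt holds an index, a Python literal for the token (either a str or a bytes literal), and the token's byte length, which is why the converter round-trips through ast.literal_eval and repr instead of escaping sequences by hand. A standalone sketch of the per-line parsing (the sample line is illustrative, not taken from the real file):

import ast

line = "5678 'token' 5\n"  # hypothetical line: index, Python-literal token, byte length
parts = line.split(' ')
token = ast.literal_eval(' '.join(parts[1:-1]))  # handles both 'str' and b'bytes' literals
token = token.encode("utf-8") if isinstance(token, str) else token
assert len(token) == int(parts[-1])              # declared byte length must match

The repr(token)[2:-1] step then stores the escaped form (e.g. \xff) so that even non-UTF-8 byte tokens fit in the GGUF string table, leaving the loader responsible for un-escaping them.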

ggml/include/ggml.h

Lines changed: 19 additions & 0 deletions
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
 
         GGML_OP_UNARY,
 
@@ -548,6 +549,7 @@ extern "C" {
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
+        GGML_UNARY_OP_EXP,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1165,6 +1167,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_exp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1913,6 +1923,15 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
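For reference, ggml_rwkv_wkv evaluates the RWKV time-mix (WKV) recurrence over the per-head state. In the notation of the RWKV-5/6 papers (receptance r_t, key k_t, value v_t, the "time first" bonus u corresponding to the tf argument, a data-dependent decay w_t derived from td, and an N x N per-head state S), one step can be written roughly as follows; transposition and indexing conventions vary between write-ups:

\[
y_t = r_t \left( \operatorname{diag}(u)\, k_t^{\top} v_t + S_{t-1} \right), \qquad
S_t = \operatorname{diag}(w_t)\, S_{t-1} + k_t^{\top} v_t, \qquad
w_t = \exp\!\left(-\exp(d_t)\right).
\]

The exp(-exp(.)) parameterization keeps the decay in (0, 1), which is plausibly why this commit adds the GGML_UNARY_OP_EXP operator (and, per the commit message, its backward pass) alongside the fused GGML_OP_RWKV_WKV op.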
