@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import ast
 import logging
 import argparse
 import contextlib
@@ -298,9 +299,12 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.POS_EMBD,
                             gguf.MODEL_TENSOR.TOKEN_TYPES,
                             gguf.MODEL_TENSOR.SSM_CONV1D,
+                            gguf.MODEL_TENSOR.TIME_MIX_FIRST,
+                            gguf.MODEL_TENSOR.TIME_MIX_W1,
+                            gguf.MODEL_TENSOR.TIME_MIX_W2,
                         )
                     )
-                    or not name.endswith(".weight")
+                    or not new_name.endswith(".weight")
                 ):
                     data_qtype = gguf.GGMLQuantizationType.F32
 
@@ -2716,6 +2720,84 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
+@Model.register("Rwkv6ForCausalLM")
+class Rwkv6Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV6
+
+    def set_vocab(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
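+                # each vocab line is "<id> <token literal> <byte length>",
+                # e.g. a str entry like 65 'A' 1, or a bytes entry like
+                # 510 b'\xe4\xb8\xad' 3 (ids here are illustrative)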
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+            remainder = vocab_size - len(tokens)
+            assert remainder >= 0
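+            # pad the token list out to the declared vocab size with unused entries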
+            for i in range(len(tokens), vocab_size):
+                tokens.append(f"[PAD{i}]".encode("utf-8"))
+                toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_size = self.hparams["head_size"]
+        hidden_size = self.hparams["hidden_size"]
+        layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        rescale_every_n_layers = self.hparams["rescale_every"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
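+        # widths of the low-rank time-mix / time-decay projections;
+        # the wider values apply to the larger (hidden_size 4096) models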
+        time_mix_extra_dim = 64 if hidden_size == 4096 else 32
+        time_decay_extra_dim = 128 if hidden_size == 4096 else 64
+
+        # RWKV is an RNN, so it has no fixed context window; store a large placeholder
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+        self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        new_name = self.map_tensor_name(name)
+
+        if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+            new_name += ".weight"
+
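+        # transpose / permute the low-rank projection matrices into the
+        # layout the llama.cpp RWKV6 graph expects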
+        if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
+            data_torch = data_torch.transpose(0, 1)
+
+        if new_name.endswith("time_mix_w2.weight"):
+            data_torch = data_torch.permute(0, 2, 1)
+
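+        # bake in the reference implementation's fp16-overflow rescaling:
+        # halve the block output projections every `rescale_every` layers,
+        # e.g. with rescale_every = 6, block 13 is divided by 2 ** (13 // 6) = 4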
+        rescale_every_n_layers = self.hparams["rescale_every"]
+        if rescale_every_n_layers > 0:
+            if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+                data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+
+        yield (new_name, data_torch)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
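
With this change an RWKV6 checkpoint converts like any other HF model. A sketch (model path, output name, and output type are placeholders):

    python convert_hf_to_gguf.py ./rwkv6-model --outfile rwkv6.gguf --outtype f16

Note that rwkv_vocab_v20230424.txt must be present in the model directory, since set_vocab() reads it directly instead of loading a HF tokenizer.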