Commit cb03b4f
Merge remote-tracking branch 'origin/master' into gabe-l-hart/nvidia-nemotron-nano-15409

* origin/master: (59 commits)
  scripts: add sqlite3 check for compare-commits.sh (ggml-org#15633)
  kv-cache : remove LLAMA_SET_ROWS checks (ggml-org#15505)
  gguf-py: byteswapping improvements (ggml-org#12851)
  cli : change log to warning to explain reason for stopping (ggml-org#15604)
  model-conversion : add mmproj conversion target (ggml-org#15628)
  cuda: Add cublasLt_static linking when GGML_STATIC is enabled (ggml-org#15622)
  server: higher timeout for tests (ggml-org#15621)
  presets : add qwen3-30B-a3b FIM (ggml-org#15616)
  HIP: Enable support for ggml_backend_cuda_register_host_buffer (ggml-org#15615)
  kv-cache : better estimate of n_kv for multi-sequence batches (ggml-org#15610)
  CANN: refactor mask handling and improve performance in FA (ggml-org#15561)
  ggml-cpu : add basic RVV support for vector f32 ops (ggml-org#15057)
  common : add -m to bash completion for --model [no ci] (ggml-org#15591)
  OpenCL: add fused group_norm/norm, mul, add (ggml-org#15314)
  tests : fix test-opt with GGML_BACKEND_DL (ggml-org#15599)
  SYCL: fix rms_norm_mul_add for tensor dim not a multiple of sg_size (ggml-org#15592)
  mtmd : fix mtmd ios build (ggml-org#15579)
  tests: add performance test for mul mat id (ggml-org#15543)
  llamafile: PowerPC Sgemm Optimization (ggml-org#15558)
  graph : fix assert in memory-less build_attn (ggml-org#15590)
  ...

2 parents: 9d4e0d7 + 55042b3

File tree: 86 files changed (+4328, -1808 lines changed)

.devops/vulkan.Dockerfile

Lines changed: 23 additions & 7 deletions

```diff
@@ -2,14 +2,30 @@ ARG UBUNTU_VERSION=24.04
 
 FROM ubuntu:$UBUNTU_VERSION AS build
 
-# Install build tools
-RUN apt update && apt install -y git build-essential cmake wget
+# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
 
-# Install Vulkan SDK and cURL
-RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
-    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
-    apt update -y && \
-    apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
+# Install build tools
+RUN apt update && apt install -y git build-essential cmake wget xz-utils
+
+# Install Vulkan SDK
+ARG VULKAN_VERSION=1.4.321.1
+RUN ARCH=$(uname -m) && \
+    wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
+    mkdir -p /opt/vulkan && \
+    tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
+    mv /tmp/${ARCH}/* /opt/vulkan/ && \
+    rm -rf /tmp/*
+
+# Install cURL and Vulkan SDK dependencies
+RUN apt install -y libcurl4-openssl-dev curl \
+    libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
+
+# Set environment variables
+ENV VULKAN_SDK=/opt/vulkan
+ENV PATH=$VULKAN_SDK/bin:$PATH
+ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
+ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
+ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
 
 # Build it
 WORKDIR /app
```
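For reference, a minimal sketch of building this image from the repository root; it assumes the Docker SDK for Python (`pip install docker`), which is not part of the commit, and an arbitrary local tag name:

```python
import docker  # Docker SDK for Python: pip install docker

client = docker.from_env()

# Build .devops/vulkan.Dockerfile from the repo root; the tag is arbitrary,
# and buildargs overrides the VULKAN_VERSION ARG pinned in the Dockerfile.
image, logs = client.images.build(
    path=".",
    dockerfile=".devops/vulkan.Dockerfile",
    tag="llama-cpp-vulkan:local",
    buildargs={"VULKAN_VERSION": "1.4.321.1"},
)
for chunk in logs:
    if "stream" in chunk:
        print(chunk["stream"], end="")
```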

common/arg.cpp

Lines changed: 17 additions & 1 deletion

```diff
@@ -1106,7 +1106,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
     printf("\"\n\n");
 
     printf("    case \"$prev\" in\n");
-    printf("        --model)\n");
+    printf("        --model|-m)\n");
     printf("            COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
     printf("            return 0\n");
     printf("            ;;\n");
@@ -3538,6 +3538,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+    add_opt(common_arg(
+        {"--fim-qwen-30b-default"},
+        string_format("use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet)"),
+        [](common_params & params) {
+            params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
+            params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
+            params.port = 8012;
+            params.n_gpu_layers = 99;
+            params.flash_attn = true;
+            params.n_ubatch = 1024;
+            params.n_batch = 1024;
+            params.n_ctx = 0;
+            params.n_cache_reuse = 256;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+
     add_opt(common_arg(
         { "--diffusion-steps" }, "N",
         string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
```

convert_hf_to_gguf.py

Lines changed: 113 additions & 72 deletions

```diff
@@ -1216,6 +1216,55 @@ def _try_set_pooling_type(self) -> None:
             raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
         self.gguf_writer.add_pooling_type(pooling_type)
 
+    def _set_vocab_interns1(self):
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+        vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
+        vocab_size = self.hparams.get("vocab_size", len(vocab))
+        assert max(vocab.values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
+        added_tokens_decoder = tokenizer.added_tokens_decoder
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            else:
+                token: str = reverse_vocab[i]
+                if token in added_vocab:
+                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
+                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
+                    if not added_tokens_decoder[i].normalized:
+                        previous_token = token
+                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
+                        if previous_token != token:
+                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
+
+                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+                tokens.append(token)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab._set_special_token("bos", 151643)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
 
 class MmprojModel(ModelBase):
     model_type = ModelType.MMPROJ
```
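The subtle step in `_set_vocab_interns1` is the encode/decode round-trip: llama.cpp assumes CONTROL and USER_DEFINED tokens are pre-normalized, so any added token flagged `normalized=False` is re-tokenized once before being stored. A standalone sketch of just that step (the checkpoint path here is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint", trust_remote_code=True)

# Added tokens flagged normalized=False may not match what llama.cpp's
# tokenizer would produce; one encode/decode round-trip normalizes them.
for token_id, added in tokenizer.added_tokens_decoder.items():
    if not added.normalized:
        fixed = tokenizer.decode(tokenizer.encode(added.content, add_special_tokens=False))
        if fixed != added.content:
            print(f"token {token_id}: {added.content!r} -> {fixed!r}")
```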
```diff
@@ -2932,7 +2981,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if "language_model." in name:
             name = name.replace("language_model.", "") # for InternVL
         if name.startswith("mlp") or name.startswith("multi_modal_projector") \
-                or name.startswith("vision_model") or name.startswith("audio_tower"):
+                or name.startswith("vision_model") or name.startswith("audio_tower") \
+                or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
             # skip vision and audio tensors
             return []
         yield from super().modify_tensors(data_torch, name, bid)
```
```diff
@@ -3109,7 +3159,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register("Ernie4_5_ForCausalLM")
+@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
 class Ernie4_5Model(TextModel):
     model_arch = gguf.MODEL_ARCH.ERNIE4_5
 
```
```diff
@@ -3604,6 +3654,19 @@ def prepare_tensors(self):
 class Qwen3Model(Qwen2Model):
     model_arch = gguf.MODEL_ARCH.QWEN3
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+        self.origin_hf_arch = hparams.get('architectures', [None])[0]
+
+    def set_vocab(self):
+        # deal with intern-s1-mini
+        if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
+            self._set_vocab_interns1()
+            return
+
+        super().set_vocab()
+
 
 @ModelBase.register("Qwen3MoeForCausalLM")
 class Qwen3MoeModel(Qwen2MoeModel):
```
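The dispatch above keys off the checkpoint's original HF architecture name rather than the converter class: an Intern-S1-mini checkpoint converts as a Qwen3 text model but reports `InternS1ForConditionalGeneration`. A minimal sketch of that lookup (a hypothetical helper, assuming the standard HF `config.json` layout):

```python
import json
from pathlib import Path

def origin_hf_arch(dir_model: Path) -> str | None:
    # config.json carries the original HF class name under "architectures"
    with open(dir_model / "config.json", encoding="utf-8") as f:
        return json.load(f).get("architectures", [None])[0]
```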
```diff
@@ -3620,73 +3683,7 @@ def set_vocab(self):
             self._set_vocab_interns1()
             return
 
-        try:
-            self._set_vocab_sentencepiece()
-        except FileNotFoundError:
-            self._set_vocab_gpt2()
-
-    def _set_vocab_interns1(self):
-        tokens: list[str] = []
-        toktypes: list[int] = []
-
-        from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
-        vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
-        vocab_size = self.hparams.get("vocab_size", len(vocab))
-        assert max(vocab.values()) < vocab_size
-
-        tokpre = self.get_vocab_base_pre(tokenizer)
-
-        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
-        added_vocab = tokenizer.get_added_vocab()
-
-        added_tokens_decoder = tokenizer.added_tokens_decoder
-
-        for i in range(vocab_size):
-            if i not in reverse_vocab:
-                tokens.append(f"[PAD{i}]")
-                toktypes.append(gguf.TokenType.UNUSED)
-            else:
-                token: str = reverse_vocab[i]
-                if token in added_vocab:
-                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
-                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
-                    if not added_tokens_decoder[i].normalized:
-                        previous_token = token
-                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
-                        if previous_token != token:
-                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
-
-                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
-                        toktypes.append(gguf.TokenType.CONTROL)
-                    else:
-                        toktypes.append(gguf.TokenType.USER_DEFINED)
-                else:
-                    toktypes.append(gguf.TokenType.NORMAL)
-                tokens.append(token)
-
-        self.gguf_writer.add_tokenizer_model("gpt2")
-        self.gguf_writer.add_tokenizer_pre(tokpre)
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
-        special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
-        additional_special_tokens = []
-        if special_tokens_map_file.is_file():
-            with open(special_tokens_map_file, encoding = 'utf-8') as f:
-                additional_special_tokens = json.load(f).get('additional_special_tokens', [])
-        tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
-        if tokenizer_cfg_file.is_file():
-            with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
-                added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
-                token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
-                for token in additional_special_tokens:
-                    if token in token2ids_map:
-                        special_vocab._set_special_token(token, token2ids_map[token])
-        special_vocab._set_special_token('eos', 151645)
-        special_vocab._set_special_token("bos", 151643)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        super().set_vocab()
 
 
 @ModelBase.register("GPT2LMHeadModel")
```
```diff
@@ -5854,6 +5851,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("SeedOssForCausalLM")
+class SeedOssModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SEED_OSS
+
+
 @ModelBase.register("Olmo2ForCausalLM")
 class Olmo2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.OLMO2
```
```diff
@@ -6252,9 +6254,11 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("DeepseekV2ForCausalLM")
-@ModelBase.register("DeepseekV3ForCausalLM")
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register(
+    "DeepseekV2ForCausalLM",
+    "DeepseekV3ForCausalLM",
+    "KimiVLForConditionalGeneration",
+)
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
```
```diff
@@ -8554,6 +8558,43 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
             return "mm.2.weight"
         return super().map_tensor_name(name, try_suffixes)
 
+
+@ModelBase.register("KimiVLForConditionalGeneration")
+class KimiVLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = 64 * 14 # for compatibility
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
+        self.gguf_writer.add_vision_use_gelu(True)
+        self.gguf_writer.add_vision_projector_scale_factor(2)
+        # eps is the same as pytorch's default value
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            if "pos_emb.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
+            elif "wqkv" in name:
+                split_dim = 0 if "weight" in name else -1
+                wq, wk, wv = data_torch.chunk(3, dim=split_dim)
+                return [
+                    (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
+                    (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
+                    (self.map_tensor_name(name.replace("wqkv", "wv")), wv)
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
 ###### CONVERSION LOGIC ######
 
 
```
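In `KimiVLModel.modify_tensors`, the fused `wqkv` projection splits into separate Q/K/V tensors with `chunk(3, ...)`: 2-D weights are stacked along dim 0, 1-D biases along the last dim. A toy sketch of the same split, with illustrative shapes only:

```python
import torch

n_embd = 8
wqkv_weight = torch.randn(3 * n_embd, n_embd)  # fused [q; k; v] projection weight
wqkv_bias = torch.randn(3 * n_embd)            # fused bias

# Weights split on dim 0; 1-D biases split on the last dim,
# mirroring: split_dim = 0 if "weight" in name else -1
wq, wk, wv = wqkv_weight.chunk(3, dim=0)
bq, bk, bv = wqkv_bias.chunk(3, dim=-1)

assert wq.shape == (n_embd, n_embd)
assert bq.shape == (n_embd,)
```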

docs/multimodal/minicpmv4.0.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
 
 
 ### Build llama.cpp
-Readme modification time: 20250206
+Readme modification time: 20250731
 
 If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
 
```

docs/multimodal/minicpmv4.5.md

Lines changed: 47 additions & 0 deletions

````diff
@@ -0,0 +1,47 @@
+## MiniCPM-V 4.5
+
+### Prepare models and code
+
+Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from huggingface to "MiniCPM-V-4_5" folder.
+
+
+### Build llama.cpp
+Readme modification time: 20250826
+
+If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
+
+Clone llama.cpp:
+```bash
+git clone https://github.com/ggerganov/llama.cpp
+cd llama.cpp
+```
+
+Build llama.cpp using `CMake`:
+```bash
+cmake -B build
+cmake --build build --config Release
+```
+
+
+### Usage of MiniCPM-V 4
+
+Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) by us)
+
+```bash
+python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
+python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
+python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
+
+# quantize int4 version
+./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
+```
+
+
+Inference on Linux or Mac
+```bash
+# run in single-turn mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# run in conversation mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
+```
````
