2 changes: 1 addition & 1 deletion xinference/core/chat_interface.py
@@ -359,7 +359,7 @@ def predict(history, bot, max_tokens, temperature, stream):
if "content" not in delta:
continue
else:
response_content += delta["content"]
response_content += html.escape(delta["content"])
bot[-1][1] = response_content
yield history, bot
history.append(
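For context, the `html.escape` call keeps streamed model output that happens to contain markup from being interpreted as HTML by the Gradio chatbot. A minimal sketch of the effect (the `delta` payload below is a made-up example, not taken from this PR):

```python
import html

# Hypothetical streamed delta whose content contains markup.
delta = {"content": "<b>bold</b> and <script>alert('x')</script>"}

raw = delta["content"]                # would be rendered as HTML by the chatbot widget
safe = html.escape(delta["content"])  # rendered literally instead

print(safe)
# &lt;b&gt;bold&lt;/b&gt; and &lt;script&gt;alert(&#x27;x&#x27;)&lt;/script&gt;
```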
94 changes: 94 additions & 0 deletions xinference/model/llm/llm_family.json
@@ -18576,5 +18576,99 @@
"#system_numpy#"
]
}
},
{
"version": 2,
"context_length": 65536,
"model_name": "glm-4.1v-thinking",
"model_lang": [
"en",
"zh"
],
"model_ability": [
"chat",
"vision",
"reasoning"
],
"model_description": "GLM-4.1V-9B-Thinking, designed to explore the upper limits of reasoning in vision-language models.",
"model_specs": [
{
"model_format": "pytorch",
"model_size_in_billions": 9,
"model_src": {
"huggingface": {
"quantizations": [
"none"
],
"model_id": "THUDM/GLM-4.1V-9B-Thinking",
"model_revision": "b627c82cd8fc9175ff2b82b33fb439eba260055f"
},
"modelscope": {
"quantizations": [
"none"
],
"model_id": "ZhipuAI/GLM-4.1V-9B-Thinking",
"model_revision": "master"
}
}
},
{
"model_format": "awq",
"model_size_in_billions": 9,
"model_src": {
"huggingface": {
"quantizations": [
"Int4"
],
"model_id": "dengcao/GLM-4.1V-9B-Thinking-AWQ"
},
"modelscope": {
"quantizations": [
"Int4"
],
"model_id": "dengcao/GLM-4.1V-9B-Thinking-AWQ",
"model_revision": "master"
}
}
},
{
"model_format": "gptq",
"model_size_in_billions": 9,
"model_src": {
"huggingface": {
"quantizations": [
"Int4-Int8Mix"
],
"model_id": "dengcao/GLM-4.1V-9B-Thinking-GPTQ-Int4-Int8Mix"
},
"modelscope": {
"quantizations": [
"Int4-Int8Mix"
],
"model_id": "dengcao/GLM-4.1V-9B-Thinking-GPTQ-Int4-Int8Mix",
"model_revision": "master"
}
}
}
],
"chat_template": "[gMASK]<sop> {%- for msg in messages %} {%- if msg.role == 'system' %} <|system|> {{ msg.content }} {%- elif msg.role == 'user' %} <|user|>{{ '\n' }} {%- if msg.content is string %} {{ msg.content }} {%- else %} {%- for item in msg.content %} {%- if item.type == 'video' or 'video' in item %} <|begin_of_video|><|video|><|end_of_video|> {%- elif item.type == 'image' or 'image' in item %} <|begin_of_image|><|image|><|end_of_image|> {%- elif item.type == 'text' %} {{ item.text }} {%- endif %} {%- endfor %} {%- endif %} {%- elif msg.role == 'assistant' %} {%- if msg.metadata %} <|assistant|>{{ msg.metadata }} {{ msg.content }} {%- else %} <|assistant|> {{ msg.content }} {%- endif %} {%- endif %} {%- endfor %} {% if add_generation_prompt %}<|assistant|> {% endif %}",
"stop_token_ids": [
151329,
151336,
151338
],
"stop": [
"<|endoftext|>",
"<|user|>",
"<|observation|>"
],
"reasoning_start_tag": "<think>",
"reasoning_end_tag": "</think>",
"virtualenv": {
"packages": [
"transformers>=4.53.2",
"#system_numpy#"
]
}
}
]
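For reference, a minimal sketch of launching the newly registered family once this entry ships. The endpoint, engine choice, image URL, and generate parameters below are placeholders and assumptions, not part of this PR:

```python
from xinference.client import Client

# Placeholder endpoint; point this at your running xinference supervisor.
client = Client("http://localhost:9997")

model_uid = client.launch_model(
    model_name="glm-4.1v-thinking",
    model_engine="transformers",   # assumption: the Transformers backend added in this PR
    model_format="pytorch",
    model_size_in_billions=9,
    quantization="none",
)

model = client.get_model(model_uid)
result = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                # Placeholder image URL.
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    generate_config={"max_tokens": 512},
)
print(result["choices"][0]["message"]["content"])
```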
167 changes: 167 additions & 0 deletions xinference/model/llm/transformers/multimodal/glm4_1v.py
@@ -0,0 +1,167 @@
# Copyright 2022-2025 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, Dict, Iterator, List, Tuple

import torch

from .....model.utils import select_device
from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
from ...utils import _decode_image
from ..core import register_non_default_model
from .core import PytorchMultiModalModel

logger = logging.getLogger(__name__)


@register_transformer
@register_non_default_model("glm-4.1v-thinking")
class Glm4_1VModel(PytorchMultiModalModel):
@classmethod
def match_json(
cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
) -> bool:
family = model_family.model_family or model_family.model_name
if "glm-4.1v" in family.lower():
return True
return False

def decide_device(self):
device = self._pytorch_model_config.get("device", "auto")
self._device = select_device(device)

def load_processor(self):
from transformers import AutoProcessor

self._processor = AutoProcessor.from_pretrained(self.model_path, use_fast=True)
self._tokenizer = self._processor.tokenizer

def load_multimodal_model(self):
from transformers import Glm4vForConditionalGeneration

kwargs = {"device_map": "auto"}
kwargs = self.apply_bnb_quantization(kwargs)

model = Glm4vForConditionalGeneration.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
**kwargs,
)
self._model = model.eval()
self._device = self._model.device

@staticmethod
def _get_processed_msgs(messages: List[Dict]) -> List[Dict]:
res = []
for message in messages:
role = message["role"]
content = message["content"]
if isinstance(content, str):
res.append({"role": role, "content": content})
else:
texts = []
image_urls = []
for c in content:
c_type = c.get("type")
if c_type == "text":
texts.append(c["text"])
else:
assert (
c_type == "image_url"
), "Please follow the image input of the OpenAI API."
image_urls.append(c["image_url"]["url"])
if len(image_urls) > 1:
raise RuntimeError("Only one image per message is supported")
image_futures = []
with ThreadPoolExecutor() as executor:
for image_url in image_urls:
fut = executor.submit(_decode_image, image_url)
image_futures.append(fut)
images = [fut.result() for fut in image_futures]
assert len(images) <= 1
text = " ".join(texts)
if images:
content = [
{"type": "image", "image": images[0]},
{"type": "text", "text": text},
]
res.append({"role": role, "content": content})
else:
res.append(
{"role": role, "content": {"type": "text", "text": text}}
)
return res

def build_inputs_from_messages(
self,
messages: List[Dict],
generate_config: Dict,
):
msgs = self._get_processed_msgs(messages)
inputs = self._processor.apply_chat_template(
msgs,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
) # chat mode
inputs = inputs.to(self._model.device)
return inputs

def get_stop_strs(self) -> List[str]:
return ["<|endoftext|>"]

def get_builtin_stop_token_ids(self) -> Tuple:
from transformers import AutoConfig

return tuple(AutoConfig.from_pretrained(self.model_path).eos_token_id)

def build_generate_kwargs(
self,
generate_config: Dict,
) -> Dict[str, Any]:
return dict(
do_sample=True,
top_p=generate_config.get("top_p", 1e-5),
repetition_penalty=generate_config.get("repetition_penalty", 1.1),
top_k=generate_config.get("top_k", 2),
max_new_tokens=generate_config.get("max_tokens", 512),
)

def build_streaming_iter(
self,
messages: List[Dict],
generate_config: Dict,
) -> Tuple[Iterator, int]:
from transformers import TextIteratorStreamer

generate_kwargs = self.build_generate_kwargs(generate_config)
inputs = self.build_inputs_from_messages(messages, generate_config)
streamer = TextIteratorStreamer(
tokenizer=self._tokenizer,
timeout=60,
skip_prompt=True,
skip_special_tokens=False,
)
kwargs = {
**inputs,
**generate_kwargs,
"streamer": streamer,
}
logger.debug("Generate with kwargs: %s", generate_kwargs)
t = Thread(target=self._model.generate, kwargs=kwargs)
t.start()
return streamer, len(inputs.input_ids[0])
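As a quick illustration of what `_get_processed_msgs` does with an OpenAI-style message before it reaches the processor's chat template (the image URL is a placeholder):

```python
openai_style = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this picture."},
            {"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}},
        ],
    }
]

# Glm4_1VModel._get_processed_msgs(openai_style) decodes the image URL
# and yields roughly:
# [
#     {
#         "role": "user",
#         "content": [
#             {"type": "image", "image": <PIL.Image.Image>},
#             {"type": "text", "text": "Describe this picture."},
#         ],
#     }
# ]
# which build_inputs_from_messages then feeds to
# processor.apply_chat_template(..., tokenize=True, return_tensors="pt").
```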
1 change: 1 addition & 0 deletions xinference/model/llm/vllm/core.py
@@ -264,6 +264,7 @@ class VLLMGenerateConfig(TypedDict, total=False):

if VLLM_INSTALLED and vllm.__version__ >= "0.9.2":
VLLM_SUPPORTED_CHAT_MODELS.append("Ernie4.5")
VLLM_SUPPORTED_VISION_MODEL_LIST.append("glm-4.1v-thinking")


class VLLMModel(LLM):
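With this line, an install with vllm >= 0.9.2 can serve the same family on the vLLM backend. A hypothetical launch (endpoint and parameters are placeholders):

```python
from xinference.client import Client

client = Client("http://localhost:9997")  # placeholder endpoint
model_uid = client.launch_model(
    model_name="glm-4.1v-thinking",
    model_engine="vllm",          # requires vllm >= 0.9.2, per the guard above
    model_format="pytorch",
    model_size_in_billions=9,
    quantization="none",
)
```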