
Commit 7cf2f83

Author: xusenlin (committed)

support for model glm-4v-9b

1 parent 9572eb9 commit 7cf2f83

File tree

11 files changed (+137 −25 lines)


README.md

Lines changed: 2 additions & 1 deletion

@@ -20,7 +20,8 @@
 ## 📢 News
 
-+ [2024.06.12] Refactored the project code
++ [2024.06.12] Added support for the `GLM-4V` model: set the environment variables `MODEL_NAME=glm-4v` `PROMPT_NAME=glm-4v` `DTYPE=bfloat16`; see [glm4v](./tests/glm4v.py) for a test example
++ [2024.06.12] Refactored the project code
 
 + [2024.06.08] Added support for the `QWEN2` model: set the environment variables `MODEL_NAME=qwen2` `PROMPT_NAME=qwen2`

api/engine/hf.py

Lines changed: 4 additions & 3 deletions

@@ -91,14 +91,15 @@ def _generate(self, params: Dict[str, Any]) -> Iterator[dict]:
         """
         prompt_or_messages = params.get("prompt_or_messages")
         if isinstance(prompt_or_messages, str):
-            input_ids = self.tokenizer(prompt_or_messages).input_ids
+            inputs = self.tokenizer(prompt_or_messages).input_ids
         else:
-            input_ids = self.template.convert_messages_to_ids(
+            print(prompt_or_messages)
+            inputs = self.template.convert_messages_to_ids(
                 prompt_or_messages,
                 tools=params.get("tools"),
                 max_tokens=params.get("max_tokens", 256),
             )
-        params.update(dict(input_ids=input_ids))
+        params.update(dict(inputs=inputs))
 
         try:
             for output in self.generate_stream_func(self.model, self.tokenizer, params):
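
With this rename, `params["inputs"]` can carry two shapes: a plain `List[int]` from the text tokenizer, or a `BatchEncoding` of tensors from the new multimodal GLM-4V template further below. A minimal sketch of the text path only, using a standard Hugging Face tokenizer (the model name here is just an example, not part of the commit):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any text tokenizer works for this sketch
inputs = tokenizer("hello world").input_ids        # text path: a plain list of token ids
print(type(inputs))                                # <class 'list'>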

api/protocol.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ class ErrorResponse(BaseModel):
 
 
 class ChatCompletionCreateParams(BaseModel):
-    messages: List[ChatCompletionMessageParam]
+    messages: List[Dict[str, Any]]
     """A list of messages comprising the conversation so far.
 
     [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
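
Loosening `messages` from the typed `ChatCompletionMessageParam` to `List[Dict[str, Any]]` presumably lets multimodal requests through request validation, since GLM-4V user messages carry a list of content parts rather than a plain string. A sketch of the message shape the widened field now admits (the URL is a placeholder):

from typing import Any, Dict

# A multimodal user message with text and image_url parts; the widened
# List[Dict[str, Any]] annotation accepts this without schema changes.
message: Dict[str, Any] = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
    ],
}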

api/templates/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
     ChatGLM2ChatTemplate,
     ChatGLM3ChatTemplate,
     ChatGLM4ChatTemplate,
+    GLM4VChatTemplate,
 )
 from api.templates.qwen import QwenChatTemplate, Qwen2ChatTemplate
 from api.templates.registry import register_template, get_template
@@ -17,6 +18,7 @@
     "ChatGLM2ChatTemplate",
     "ChatGLM3ChatTemplate",
     "ChatGLM4ChatTemplate",
+    "GLM4VChatTemplate",
     "QwenChatTemplate",
     "Qwen2ChatTemplate",
     "Llama2ChatTemplate",

api/templates/baichuan.py

Lines changed: 3 additions & 2 deletions

@@ -6,6 +6,7 @@
     Optional,
     Dict,
     Any,
+    Union,
 )
 
 from openai.types.chat import ChatCompletionMessageParam
@@ -16,7 +17,7 @@
 from api.templates.utils import parse_messages
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, BatchEncoding
 
 
 def build_baichuan_chat_input(
@@ -81,7 +82,7 @@ def _convert_messages_to_ids(
         max_tokens: Optional[int] = 256,
         max_window_size: Optional[int] = 6144,
         **kwargs,
-    ) -> List[int]:
+    ) -> Union[List[int], "BatchEncoding"]:
         return build_baichuan_chat_input(
             self.tokenizer,
             messages,

api/templates/base.py

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
 from openai.types.chat import ChatCompletionMessageParam
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, BatchEncoding
 
 
 class ChatTemplate(ABC):
@@ -42,7 +42,7 @@ def convert_messages_to_ids(
         max_tokens: Optional[int] = 256,
         max_window_size: Optional[int] = 6144,
         **kwargs,
-    ) -> List[int]:
+    ) -> Union[List[int], "BatchEncoding"]:
         try:
             token_ids = self._convert_messages_to_ids(
                 messages,
@@ -77,7 +77,7 @@ def _convert_messages_to_ids(
         max_tokens: Optional[int] = 256,
         max_window_size: Optional[int] = 6144,
         **kwargs,
-    ) -> List[int]:
+    ) -> Union[List[int], "BatchEncoding"]:
         raise NotImplementedError
 
     def apply_chat_template(
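
The widened return annotation reflects that text templates still return a flat `List[int]`, while GLM-4V's template returns a `BatchEncoding`, the dict-like container produced by `apply_chat_template(..., return_tensors="pt", return_dict=True)`. A rough sketch of the multimodal shape; the `images` key and its dimensions are assumptions for illustration, not taken from the commit:

import torch

# Illustrative stand-in for a multimodal BatchEncoding; the real keys and
# shapes depend on the model's processor.
batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),        # dummy prompt token ids, batch of 1
    "attention_mask": torch.ones(1, 3, dtype=torch.long),
    "images": torch.zeros(1, 3, 1120, 1120),       # dummy preprocessed image (assumed shape)
}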

api/templates/glm.py

Lines changed: 68 additions & 3 deletions

@@ -26,7 +26,7 @@
 from api.templates.utils import apply_stopping_strings
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer, PreTrainedModel
+    from transformers import PreTrainedTokenizer, PreTrainedModel, BatchEncoding
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
@@ -412,7 +412,7 @@ def _convert_messages_to_ids(
         max_tokens: Optional[int] = 256,
         max_window_size: Optional[int] = 6144,
         **kwargs,
-    ) -> List[int]:
+    ) -> Union[List[int], BatchEncoding]:
         messages = process_chatglm_messages(messages, tools)
         query, role = messages[-1]["content"], messages[-1]["role"]
         return self.tokenizer.build_chat_input(
@@ -489,7 +489,7 @@ def _convert_messages_to_ids(
         max_tokens: Optional[int] = 256,
         max_window_size: Optional[int] = 6144,
         **kwargs,
-    ) -> List[int]:
+    ) -> Union[List[int], BatchEncoding]:
         messages = process_chatglm_messages_v4(messages, tools)
         return self.tokenizer.apply_chat_template(
             messages,
@@ -534,3 +534,68 @@ def tool_call(**kwargs):
             "content": content
         }
         return output, content
+
+
+@register_template("glm-4v")
+class GLM4VChatTemplate(ChatTemplate):
+    stop = ["<|endoftext|>", "<user>", "<|observation|>"]
+    stop_token_ids = [151329, 151336, 151338]
+
+    def _convert_messages_to_ids(
+        self,
+        messages: List[ChatCompletionMessageParam],
+        system: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_tokens: Optional[int] = 256,
+        max_window_size: Optional[int] = 6144,
+        **kwargs,
+    ) -> Union[List[int], "BatchEncoding"]:
+        _messages = []
+        for message in messages:
+            if isinstance(message["content"], str):
+                _content, image = message["content"], None
+            else:
+                _content, image = None, None
+                for c in message["content"]:
+                    if isinstance(c, dict) and "type" in c:
+                        if c["type"] == "text":
+                            _content = c["text"]
+
+                        if c["type"] == "image_url":
+                            if (
+                                isinstance(c["image_url"], dict)
+                                and "url" in c["image_url"]
+                            ):
+                                image = self._load_image(image_url=c["image_url"]["url"])
+                            else:
+                                image = self._load_image(image_url=c["image_url"])
+
+            msg = {"role": message["role"], "content": _content}
+            if image is not None:
+                msg["image"] = image
+            _messages.append(msg)
+
+        return self.tokenizer.apply_chat_template(
+            _messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+        )
+
+    @staticmethod
+    def _load_image(image_url: str):
+        from PIL import Image
+        from io import BytesIO
+
+        if image_url.startswith("data:"):
+            import base64
+
+            image_bytes = base64.b64decode(image_url.split(",")[1])
+        else:
+            import urllib.request
+
+            with urllib.request.urlopen(image_url) as f:
+                image_bytes = f.read()
+
+        return Image.open(BytesIO(image_bytes)).convert("RGB")
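
In short, the new template flattens each OpenAI-style message into the shape GLM-4V's tokenizer expects: the text part becomes a plain `content` string and the optional image becomes a PIL image under an `image` key. A sketch with dummy data, not from the commit; the in-memory image stands in for one fetched by `_load_image`:

from PIL import Image

pil_image = Image.new("RGB", (8, 8))  # stands in for the image _load_image returns

# An OpenAI-style input message...
openai_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
    ],
}

# ...is flattened by _convert_messages_to_ids into:
glm4v_message = {
    "role": "user",
    "content": "What is in this picture?",
    "image": pil_image,
}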

api/templates/qwen.py

Lines changed: 2 additions & 2 deletions

@@ -24,7 +24,7 @@
 from api.templates.registry import register_template
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, BatchEncoding
 
 
 TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
@@ -277,7 +277,7 @@ def _convert_messages_to_ids(
         max_tokens: Optional[int] = 256,
         max_window_size: Optional[int] = 6144,
         **kwargs,
-    ) -> List[int]:
+    ) -> Union[List[int], BatchEncoding]:
         return build_qwen_chat_input(
             self.tokenizer,
             messages,

api/templates/stream.py

Lines changed: 14 additions & 4 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 import time
 import uuid
 from threading import Thread
@@ -25,7 +26,7 @@ def generate_stream(
     tokenizer: "PreTrainedTokenizer",
     params: Dict[str, Any],
 ):
-    input_ids = params.get("input_ids")
+    inputs = params.get("inputs")
     functions = params.get("functions")
     model_name = params.get("model", "llm")
     temperature = float(params.get("temperature", 1.0))
@@ -39,10 +40,8 @@ def generate_stream(
         stop_token_ids.append(tokenizer.eos_token_id)
     stop_strings = params.get("stop", [])
 
-    input_echo_len = len(input_ids)
-    device = model.device
+    device = next(model.parameters()).device
     generation_kwargs = dict(
-        input_ids=torch.tensor([input_ids], device=device),
        do_sample=True,
         temperature=temperature,
         top_p=top_p,
@@ -55,6 +54,14 @@ def generate_stream(
         generation_kwargs["do_sample"] = False
         generation_kwargs.pop("top_k")
 
+    if isinstance(inputs, dict):
+        inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
+        generation_kwargs.update(inputs)
+        input_echo_len = len(inputs["input_ids"][0])
+    else:
+        generation_kwargs["input_ids"] = torch.tensor([inputs], device=device)
+        input_echo_len = len(inputs)
+
     streamer = TextIteratorStreamer(
         tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
     )
@@ -114,3 +121,6 @@ def generate_stream(
             "total_tokens": input_echo_len + i,
         },
     }
+
+    gc.collect()
+    torch.cuda.empty_cache()
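
The new branch is what lets one streaming loop serve both input shapes: dict-like multimodal inputs are moved to the model's device and merged into the generation kwargs wholesale, while a bare token-id list is wrapped into a `[1, seq_len]` batch tensor as before (`next(model.parameters()).device` replaces `model.device`, presumably to stay robust when `DEVICE_MAP=auto` places the model). A standalone sketch of the dispatch, with no real model involved:

import torch

def build_generation_kwargs(inputs, device):
    """Sketch only: mirrors the dispatch added by this commit."""
    kwargs = {}
    if isinstance(inputs, dict):
        # Multimodal path: move every tensor to the target device and pass
        # the whole dict (input_ids, attention_mask, image tensors, ...).
        tensors = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
        kwargs.update(tensors)
        input_echo_len = len(tensors["input_ids"][0])
    else:
        # Text path: wrap the flat token-id list into a batch of one.
        kwargs["input_ids"] = torch.tensor([inputs], device=device)
        input_echo_len = len(inputs)
    return kwargs, input_echo_len

# Dummy usage:
kwargs, n = build_generation_kwargs([1, 2, 3], torch.device("cpu"))
assert n == 3 and kwargs["input_ids"].shape == (1, 3)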

docs/SCRIPT.md

Lines changed: 7 additions & 6 deletions

@@ -108,12 +108,13 @@ python server.py

 ### GLM series

-| Model    | Example environment variables |
-|----------|-------------------------------|
-| chatglm  | `MODEL_NAME=chatglm` `MODEL_PATH=THUDM/chatglm-6b` `PROMPT_NAME=chatglm` `DEVICE_MAP=cuda:0` |
-| chatglm2 | `MODEL_NAME=chatglm2` `MODEL_PATH=THUDM/chatglm2-6b` `PROMPT_NAME=chatglm2` `DEVICE_MAP=cuda:0` |
-| chatglm3 | `MODEL_NAME=chatglm3` `MODEL_PATH=THUDM/chatglm3-6b` `PROMPT_NAME=chatglm3` `DEVICE_MAP=cuda:0` |
-| glm4     | `MODEL_NAME=chatglm4` `MODEL_PATH=THUDM/glm-4-9b-chat` `PROMPT_NAME=chatglm4` `DEVICE_MAP=cuda:0` |
+| Model     | Example environment variables |
+|-----------|-------------------------------|
+| chatglm   | `MODEL_NAME=chatglm` `MODEL_PATH=THUDM/chatglm-6b` `PROMPT_NAME=chatglm` `DEVICE_MAP=cuda:0` |
+| chatglm2  | `MODEL_NAME=chatglm2` `MODEL_PATH=THUDM/chatglm2-6b` `PROMPT_NAME=chatglm2` `DEVICE_MAP=cuda:0` |
+| chatglm3  | `MODEL_NAME=chatglm3` `MODEL_PATH=THUDM/chatglm3-6b` `PROMPT_NAME=chatglm3` `DEVICE_MAP=cuda:0` |
+| glm4-chat | `MODEL_NAME=chatglm4` `MODEL_PATH=THUDM/glm-4-9b-chat` `PROMPT_NAME=chatglm4` `DEVICE_MAP=cuda:0` |
+| glm-4v    | `MODEL_NAME=glm-4v` `MODEL_PATH=THUDM/glm-4v-9b` `PROMPT_NAME=glm-4v` `DEVICE_MAP=auto` `DTYPE=bfloat16` |

 ### BAICHUAN series
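
The glm-4v row matches the variables mentioned in the README entry. A minimal sketch (not from the docs) of setting them from Python rather than the shell, e.g. before importing the server module:

import os

# Environment for serving glm-4v, taken from the table above.
os.environ.update({
    "MODEL_NAME": "glm-4v",
    "MODEL_PATH": "THUDM/glm-4v-9b",
    "PROMPT_NAME": "glm-4v",
    "DEVICE_MAP": "auto",
    "DTYPE": "bfloat16",
})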

tests/glm4v.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="EMPTY",
+    base_url="http://192.168.0.59:7891/v1/",
+)
+
+stream = client.chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "这张图片是什么地方?"  # "What place is this picture of?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        # Either a remote URL or a base64 "data:" URL
+                        "url": "http://djclub.cdn.bcebos.com/uploads/images/pageimg/20230325/64-230325205T52.jpg"
+                    }
+                }
+            ]
+        }
+    ],
+    model="glm-4v-9b",
+    stream=True,
+)
+for part in stream:
+    print(part.choices[0].delta.content or "", end="", flush=True)
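
Since `GLM4VChatTemplate._load_image` also accepts base64 `data:` URLs, a local image can be sent by inlining it. A sketch with a hypothetical local file path:

import base64

# Encode a local image as a data: URL the server-side template can decode.
with open("photo.jpg", "rb") as f:  # hypothetical local file
    b64 = base64.b64encode(f.read()).decode()

image_part = {
    "type": "image_url",
    "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
}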
