
Commit b74fdba

Merge pull request xusenlinzy#164 from xusenlinzy/dev
Dev
2 parents: 5473edf + c592d9d


43 files changed: +880 −1946 lines

README.md

Lines changed: 51 additions & 86 deletions
@@ -20,7 +20,7 @@
 
 ## 📢 News
 
-+ 【2023.11.08】 The `dev` branch now supports `openai=1.1.0`
++ 【2023.11.09】 The `dev` branch now supports `openai=1.2.0`
 
 
 + 【2023.11.03】 Added `function call` support for the `chatglm3` and `qwen` models, in both streaming and non-streaming modes; [tool-use example](https://github.com/xusenlinzy/api-for-open-llm/tree/master/examples/chatglm3/tool_using.py); the web `demo` is integrated into [streamlit-demo](./streamlit-demo)
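The `function call` feature announced above is exposed through the OpenAI-compatible endpoint. Below is a minimal sketch of invoking it with the openai v1 client; the server address, the model name, and the `get_weather` schema are placeholder assumptions, not from this commit — see the linked tool_using.py for the project's own example.

```python
from openai import OpenAI

# Placeholder address; point this at your own deployment.
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")

# Hypothetical function schema in the OpenAI function-calling format.
functions = [
    {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]

response = client.chat.completions.create(
    model="chatglm3",  # assumed model name; use whatever your server registers
    messages=[{"role": "user", "content": "北京今天天气怎么样?"}],
    functions=functions,
)
# When the model chooses to call a tool, the message carries a
# function_call payload instead of plain content.
print(response.choices[0].message.function_call)
```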
@@ -148,30 +148,45 @@ streamlit run streamlit_app.py
 
 ![img.png](images/demo.png)
 
-### [openai](https://github.com/openai/openai-python)
+### [openai v1.1.0](https://github.com/openai/openai-python)
 
 <details>
 <summary>👉 Chat Completions</summary>
 
 ```python
-import openai
+from openai import OpenAI
 
-openai.api_base = "http://192.168.0.xx:80/v1"
-
-# Enter any non-empty API key to pass the client library's check.
-openai.api_key = "xxx"
+client = OpenAI(
+    api_key="EMPTY",
+    base_url="http://192.168.20.59:7891/v1/",
+)
 
-# Enter any non-empty model name to pass the client library's check.
-completion = openai.ChatCompletion.create(
-    model="chatglm-6b",
+# Chat completion API
+chat_completion = client.chat.completions.create(
     messages=[
-        {"role": "user", "content": "你好"},
+        {
+            "role": "user",
+            "content": "你好",
+        }
     ],
-    stream=False,
+    model="gpt-3.5-turbo",
 )
-
-print(completion.choices[0].message.content)
-# Hello 👋! I am the AI assistant ChatGLM-6B. Nice to meet you; feel free to ask me anything.
+print(chat_completion)
+# Hello 👋! I am the AI assistant ChatGLM3-6B. Nice to meet you; feel free to ask me anything.
+
+
+# stream = client.chat.completions.create(
+#     messages=[
+#         {
+#             "role": "user",
+#             "content": "感冒了怎么办",
+#         }
+#     ],
+#     model="gpt-3.5-turbo",
+#     stream=True,
+# )
+# for part in stream:
+#     print(part.choices[0].delta.content or "", end="", flush=True)
 ```
 
 </details>
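The streaming variant is left commented out in the README. For reference, here it is as a standalone, runnable sketch, assuming the same placeholder server address:

```python
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://192.168.20.59:7891/v1/")

# Each streamed chunk carries an incremental delta rather than a full message.
stream = client.chat.completions.create(
    messages=[{"role": "user", "content": "感冒了怎么办"}],
    model="gpt-3.5-turbo",
    stream=True,
)
for part in stream:
    print(part.choices[0].delta.content or "", end="", flush=True)
```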
@@ -180,17 +195,20 @@ print(completion.choices[0].message.content)
 <summary>👉 Completions</summary>
 
 ```python
-import openai
-
-openai.api_base = "http://192.168.0.xx:80/v1"
+from openai import OpenAI
 
-# Enter any non-empty API key to pass the client library's check.
-openai.api_key = "xxx"
+client = OpenAI(
+    api_key="EMPTY",
+    base_url="http://192.168.20.59:7891/v1/",
+)
 
-# Enter any non-empty model name to pass the client library's check.
-completion = openai.Completion.create(prompt="你好", model="chatglm-6b")
-
-print(completion.choices[0].text)
+# Completion API
+completion = client.completions.create(
+    model="gpt-3.5-turbo",
+    prompt="你好",
+)
+print(completion)
 # Hello 👋! I am the AI assistant ChatGLM-6B. Nice to meet you; feel free to ask me anything.
 ```
 
@@ -200,82 +218,29 @@ print(completion.choices[0].text)
 <summary>👉 Embeddings</summary>
 
 ```python
-import openai
-
-openai.api_base = "http://192.168.0.xx:80/v1"
+from openai import OpenAI
 
-# Enter any non-empty API key to pass the client library's check.
-openai.api_key = "xxx"
+client = OpenAI(
+    api_key="EMPTY",
+    base_url="http://192.168.20.59:7891/v1/",
+)
 
 # compute the embedding of the text
-embedding = openai.Embedding.create(
-    input="什么是chatgpt?",
-    model="text2vec-large-chinese"
+embedding = client.embeddings.create(
+    input="你好",
+    model="text-embedding-ada-002"
 )
-
-print(embedding['data'][0]['embedding'])
+print(embedding)
 ```
 
 </details>
 
-### [langchain](https://github.com/hwchase17/langchain)
-
-<details>
-<summary>👉 Chat Completions</summary>
-
-```python
-import os
-
-os.environ["OPENAI_API_BASE"] = "http://192.168.0.xx:80/v1"
-os.environ["OPENAI_API_KEY"] = "xxx"
-
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import HumanMessage
-
-chat = ChatOpenAI()
-print(chat([HumanMessage(content="你好")]))
-# content='Hello 👋! I am the AI assistant ChatGLM-6B. Nice to meet you; feel free to ask me anything.' additional_kwargs={}
-```
-</details>
-
-<details>
-<summary>👉 Completions</summary>
-
-```python
-import os
-
-os.environ["OPENAI_API_BASE"] = "http://192.168.0.xx:80/v1"
-os.environ["OPENAI_API_KEY"] = "xxx"
-
-from langchain.llms import OpenAI
-
-llm = OpenAI()
-print(llm("你好"))
-# Hello 👋! I am the AI assistant ChatGLM-6B. Nice to meet you; feel free to ask me anything.
-```
-</details>
-
-<details>
-<summary>👉 Embeddings</summary>
-
-```python
-import os
-
-os.environ["OPENAI_API_BASE"] = "http://192.168.0.xx:80/v1"
-os.environ["OPENAI_API_KEY"] = "xxx"
-
-from langchain.embeddings import OpenAIEmbeddings
-
-embeddings = OpenAIEmbeddings()
-query_result = embeddings.embed_query("什么是chatgpt?")
-print(query_result)
-```
-</details>
 
 ### Compatible projects
 
-**By changing the `OPENAI_API_BASE` environment variable set above, most `chatgpt` applications and front-end/back-end projects can connect seamlessly!**
+**By changing the `OPENAI_API_BASE` environment variable, most `chatgpt` applications and front-end/back-end projects can connect seamlessly!**
 
 + [ChatGPT-Next-Web: One-Click to deploy well-designed ChatGPT web UI on Vercel](https://github.com/Yidadaa/ChatGPT-Next-Web)
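A hedged illustration of the plug-in claim above: with the openai v1 client the server address can also be supplied through environment variables, so an unmodified application can be pointed at this server. The address below is a placeholder; `OPENAI_API_BASE` is the variable read by openai 0.x and langchain, while openai >= 1.0 reads `OPENAI_BASE_URL` instead.

```python
import os

# Placeholder server address; adjust to your deployment.
os.environ["OPENAI_API_BASE"] = "http://192.168.20.59:7891/v1"   # openai 0.x / langchain
os.environ["OPENAI_BASE_URL"] = "http://192.168.20.59:7891/v1"   # openai >= 1.0
os.environ["OPENAI_API_KEY"] = "EMPTY"  # any non-empty key passes the client's check

from openai import OpenAI

client = OpenAI()  # picks up OPENAI_BASE_URL / OPENAI_API_KEY from the environment
print(client.base_url)
```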

api/apapter/template.py

Lines changed: 4 additions & 8 deletions
@@ -1,7 +1,7 @@
 from functools import lru_cache
-from typing import List, Optional, Union, Dict
+from typing import List, Optional, Dict
 
-from api.utils.protocol import ChatMessage
+from openai.types.chat import ChatCompletionMessageParam
 
 
 @lru_cache
@@ -32,14 +32,14 @@ def match(self, name) -> bool:
 
     def apply_chat_template(
         self,
-        conversation: List[Union[Dict[str, str], ChatMessage]],
+        conversation: List[ChatCompletionMessageParam],
         add_generation_prompt: bool = True,
     ) -> str:
         """
         Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a prompt.
 
         Args:
-            conversation (List[Union[Dict[str, str], ChatMessage]]): A Conversation object or list of dicts
+            conversation (List[ChatCompletionMessageParam]): A Conversation object or list of dicts
                 with "role" and "content" keys, representing the chat history so far.
             add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
                 the start of an assistant message. This is useful when you want to generate a response from the model.
@@ -49,10 +49,6 @@ def apply_chat_template(
         Returns:
             `str`: A prompt, which is ready to pass to the tokenizer.
         """
-
-        if isinstance(conversation[0], ChatMessage):
-            conversation = [c.dict(exclude_none=True) for c in conversation]
-
 
         # Compilation function uses a cache to avoid recompiling the same template
         compiled_template = _compile_jinja_template(self.template)
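The patched `apply_chat_template` now consumes plain `ChatCompletionMessageParam` dicts directly, with no `ChatMessage`-to-dict conversion step. Here is a self-contained sketch of the cached Jinja rendering it performs; the template string is hypothetical (the real templates live on the repo's adapter classes, and `_compile_jinja_template` is the repo's own helper):

```python
from functools import lru_cache

from jinja2 import Template

# Hypothetical chat template standing in for self.template.
CHAT_TEMPLATE = (
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant:{% endif %}"
)


@lru_cache
def compile_jinja_template(template: str) -> Template:
    # Caching avoids recompiling the same template string, as in the patch.
    return Template(template)


# Messages are plain dicts, matching openai's ChatCompletionMessageParam.
conversation = [{"role": "user", "content": "你好"}]
prompt = compile_jinja_template(CHAT_TEMPLATE).render(
    messages=conversation, add_generation_prompt=True
)
print(prompt)
# user: 你好
# assistant:
```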

api/generation/baichuan.py

Lines changed: 7 additions & 6 deletions
@@ -1,14 +1,15 @@
 from typing import List
 
-from transformers import PreTrainedTokenizer
+from openai.types.chat import ChatCompletionMessageParam
 
 from api.generation.utils import parse_messages
-from api.utils.protocol import Role, ChatMessage
+from api.utils.protocol import Role
+from transformers import PreTrainedTokenizer
 
 
 def build_baichuan_chat_input(
     tokenizer: PreTrainedTokenizer,
-    messages: List[ChatMessage],
+    messages: List[ChatCompletionMessageParam],
     context_len: int = 4096,
     max_new_tokens: int = 256
 ) -> List[int]:
@@ -22,11 +23,11 @@ def build_baichuan_chat_input(
     for r in rounds[::-1]:
         round_tokens = []
         for message in r:
-            if message.role == Role.USER:
+            if message["role"] == Role.USER:
                 round_tokens.append(195)
             else:
                 round_tokens.append(196)
-            round_tokens.extend(tokenizer.encode(message.content))
+            round_tokens.extend(tokenizer.encode(message["content"]))
 
         if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
             history_tokens = round_tokens + history_tokens  # concat left
@@ -35,7 +36,7 @@ def build_baichuan_chat_input(
             break
 
     input_tokens = system_tokens + history_tokens
-    if messages[-1].role != Role.ASSISTANT:
+    if messages[-1]["role"] != Role.ASSISTANT:
         input_tokens.append(196)
 
     return input_tokens[-max_input_tokens:]  # truncate left
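The change above swaps attribute access (`message.role`) for dict access (`message["role"]`), since messages now arrive as plain dicts. A toy sketch of the role-token logic under that convention; the fake tokenizer is a stand-in, and the IDs 195/196 mirror the Baichuan user/assistant markers shown in the diff:

```python
from typing import Dict, List


class FakeTokenizer:
    """Toy stand-in; the real function takes a transformers PreTrainedTokenizer."""

    def encode(self, text: str) -> List[int]:
        return [ord(c) for c in text]


def encode_round(messages: List[Dict[str, str]]) -> List[int]:
    tokenizer = FakeTokenizer()
    tokens: List[int] = []
    for message in messages:
        # Dict access replaces the old ChatMessage attribute access.
        tokens.append(195 if message["role"] == "user" else 196)
        tokens.extend(tokenizer.encode(message["content"]))
    return tokens


print(encode_round([{"role": "user", "content": "hi"}]))
# [195, 104, 105]
```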

api/generation/chatglm.py

Lines changed: 8 additions & 7 deletions
@@ -5,10 +5,11 @@
 
 import torch
 from loguru import logger
+from openai.types.chat import ChatCompletionMessageParam
 from transformers.generation.logits_process import LogitsProcessor
 
 from api.generation.utils import apply_stopping_strings
-from api.utils.protocol import Role, ChatMessage
+from api.utils.protocol import Role
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
@@ -77,7 +78,7 @@ def generate_stream_chatglm(
     context_len=2048,
     stream_interval=2,
 ):
-    prompt = params["prompt"]
+    prompt = params["messages"]
     temperature = float(params.get("temperature", 1.0))
     repetition_penalty = float(params.get("repetition_penalty", 1.0))
     top_p = float(params.get("top_p", 1.0))
@@ -147,8 +148,8 @@ def generate_stream_chatglm_v3(
     context_len=2048,
     stream_interval=2,
 ):
-    prompt: List[ChatMessage] = params["prompt"]
-    functions = params["functions"]
+    prompt: List[ChatCompletionMessageParam] = params["prompt"]
+    functions = params.get("functions", None)
     temperature = float(params.get("temperature", 1.0))
     repetition_penalty = float(params.get("repetition_penalty", 1.0))
     top_p = float(params.get("top_p", 1.0))
@@ -225,7 +226,7 @@ def generate_stream_chatglm_v3(
     torch.cuda.empty_cache()
 
 
-def process_chatglm_messages(messages: List[ChatMessage], functions: Union[dict, List[dict]] = None) -> List[dict]:
+def process_chatglm_messages(messages: List[ChatCompletionMessageParam], functions: Union[dict, List[dict]] = None) -> List[dict]:
     _messages = messages
     messages = []
 
@@ -239,10 +240,10 @@ def process_chatglm_messages(messages: List[ChatMessage], functions: Union[dict,
     )
 
     for m in _messages:
-        role, content, func_call = m.role, m.content, m.function_call
+        role, content = m["role"], m["content"]
+        func_call = m.get("function_call", None)
         if role == Role.FUNCTION:
             messages.append({"role": "observation", "content": content})
-
        elif role == Role.ASSISTANT and func_call is not None:
             for response in content.split("<|assistant|>"):
                 metadata, sub_content = response.split("\n", maxsplit=1)
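Two patterns from this patch are worth isolating: optional fields (`functions`, `function_call`) are now read with `.get()` so plain dicts without those keys cannot raise `KeyError`, and tool results arriving under the `function` role are remapped to ChatGLM3's `observation` role. A simplified, self-contained sketch; the system-prompt text is a placeholder, and the real code additionally splits multi-part assistant replies on `<|assistant|>`:

```python
from typing import Dict, List, Optional, Union


def process_messages(
    messages: List[Dict[str, str]],
    functions: Optional[Union[dict, List[dict]]] = None,
) -> List[dict]:
    out: List[dict] = []
    if functions is not None:
        # ChatGLM3 receives tool definitions through a system message.
        out.append({
            "role": "system",
            "content": "Answer the question using the tools provided.",  # placeholder
            "tools": functions,
        })
    for m in messages:
        # .get() tolerates dicts that carry no "function_call" key.
        func_call = m.get("function_call", None)
        if m["role"] == "function":
            # Tool output maps to ChatGLM3's "observation" role.
            out.append({"role": "observation", "content": m["content"]})
        elif m["role"] == "assistant" and func_call is not None:
            # Simplified: the real code splits the reply on "<|assistant|>".
            out.append({"role": "assistant", "content": m["content"]})
        else:
            out.append({"role": m["role"], "content": m["content"]})
    return out


print(process_messages([{"role": "user", "content": "你好"}]))
# [{'role': 'user', 'content': '你好'}]
```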
