Skip to content

Commit a00046b

Browse files
author
xusenlin
committed
Fix bug
1 parent c882289 commit a00046b

File tree

6 files changed

+67
-20
lines changed

6 files changed

+67
-20
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
## 📢 新闻
2222

23-
+ 【2023.11.09】 `dev` 分支已经支持 `openai=1.2.0` 版本
23+
+ 【2023.11.09】 目前分支需要安装 `openai>=1.2.3` 版本
2424

2525

2626
+ 【2023.11.03】 支持 `chatglm3``qwen` 模型的 `function call` 调用功能,同时支持流式和非流式模式, [工具使用示例](https://github.com/xusenlinzy/api-for-open-llm/tree/master/examples/chatglm3/tool_using.py), 网页 `demo` 已经集成到 [streamlit-demo](./streamlit-demo)

api/generation/chatglm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def generate_stream_chatglm(
8282
context_len=2048,
8383
stream_interval=2,
8484
):
85-
prompt = params["messages"]
85+
prompt = params["prompt"]
8686
temperature = float(params.get("temperature", 1.0))
8787
repetition_penalty = float(params.get("repetition_penalty", 1.0))
8888
top_p = float(params.get("top_p", 1.0))

api/routes/embedding.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
import base64
2+
13
import numpy as np
24
import tiktoken
35
from fastapi import APIRouter, Depends
4-
from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage
5-
from openai.types.embedding import Embedding
6+
from openai.types.create_embedding_response import Usage
67

78
from api.config import config
89
from api.models import EMBEDDED_MODEL
910
from api.routes.utils import check_api_key
10-
from api.utils.protocol import EmbeddingCreateParams
11+
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
1112

1213
embedding_router = APIRouter()
1314

@@ -19,16 +20,16 @@ async def create_embeddings(request: EmbeddingCreateParams, model_name: str = No
1920
if request.model is None:
2021
request.model = model_name
2122

22-
inputs = request.input
23-
if isinstance(inputs, str):
24-
inputs = [inputs]
25-
elif isinstance(inputs, list):
26-
if isinstance(inputs[0], int):
23+
request.input = request.input
24+
if isinstance(request.input, str):
25+
request.input = [request.input]
26+
elif isinstance(request.input, list):
27+
if isinstance(request.input[0], int):
2728
decoding = tiktoken.model.encoding_for_model(request.model)
28-
inputs = [decoding.decode(inputs)]
29-
elif isinstance(inputs[0], list):
29+
request.input = [decoding.decode(request.input)]
30+
elif isinstance(request.input[0], list):
3031
decoding = tiktoken.model.encoding_for_model(request.model)
31-
inputs = [decoding.decode(text) for text in inputs]
32+
request.input = [decoding.decode(text) for text in request.input]
3233

3334
# https://huggingface.co/BAAI/bge-large-zh
3435
if EMBEDDED_MODEL is not None:
@@ -38,12 +39,11 @@ async def create_embeddings(request: EmbeddingCreateParams, model_name: str = No
3839
instruction = "为这个句子生成表示以用于检索相关文章:"
3940
elif "en" in config.EMBEDDING_NAME.lower():
4041
instruction = "Represent this sentence for searching relevant passages: "
41-
inputs = [instruction + q for q in inputs]
42+
request.inputs = [instruction + q for q in request.input]
4243

4344
data, total_tokens = [], 0
4445
batches = [
45-
inputs[i: min(i + 1024, len(inputs))]
46-
for i in range(0, len(inputs), 1024)
46+
request.input[i: i + 1024] for i in range(0, len(request.input), 1024)
4747
]
4848
for num_batch, batch in enumerate(batches):
4949
token_num = sum([len(i) for i in batch])
@@ -54,10 +54,14 @@ async def create_embeddings(request: EmbeddingCreateParams, model_name: str = No
5454
zeros = np.zeros((bs, config.EMBEDDING_SIZE - dim))
5555
vecs = np.c_[vecs, zeros]
5656

57-
vecs = vecs.tolist()
57+
if request.encoding_format == "base64":
58+
vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
59+
else:
60+
vecs = vecs.tolist()
61+
5862
for i, embed in enumerate(vecs):
5963
data.append(
60-
Embedding(index=num_batch * 1024 + i, embedding=embed, object="embedding")
64+
Embedding(index=num_batch * 1024 + i, object="embedding", embedding=embed)
6165
)
6266

6367
total_tokens += token_num

api/utils/protocol.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from enum import Enum
2-
from typing import Optional, Dict, List, Union, Literal
2+
from typing import Optional, Dict, List, Union, Literal, Any
33

44
from openai.types.chat import (
55
ChatCompletionMessageParam,
66
ChatCompletionToolChoiceOptionParam,
77
ChatCompletionToolParam,
88
)
99
from openai.types.chat.completion_create_params import FunctionCall, ResponseFormat
10+
from openai.types.create_embedding_response import Usage
1011
from pydantic import BaseModel
1112

1213

@@ -358,3 +359,32 @@ class EmbeddingCreateParams(BaseModel):
358359
and detect abuse.
359360
[Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
360361
"""
362+
363+
364+
class Embedding(BaseModel):
365+
embedding: Any
366+
"""The embedding vector, which is a list of floats.
367+
368+
The length of vector depends on the model as listed in the
369+
[embedding guide](https://platform.openai.com/docs/guides/embeddings).
370+
"""
371+
372+
index: int
373+
"""The index of the embedding in the list of embeddings."""
374+
375+
object: Literal["embedding"]
376+
"""The object type, which is always "embedding"."""
377+
378+
379+
class CreateEmbeddingResponse(BaseModel):
380+
data: List[Embedding]
381+
"""The list of embeddings generated by the model."""
382+
383+
model: str
384+
"""The name of the model used to generate the embedding."""
385+
386+
object: Literal["list"]
387+
"""The object type, which is always "list"."""
388+
389+
usage: Usage
390+
"""The usage information for the request."""

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
openai==1.2.0
1+
openai>=1.2.3
22
bitsandbytes
33
fastapi==0.95.1
44
typing-inspect==0.8.0

tests/langchain_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from langchain.chat_models import ChatOpenAI
2+
from langchain.embeddings import OpenAIEmbeddings
3+
from langchain.schema import HumanMessage
4+
5+
text = "你好"
6+
messages = [HumanMessage(content=text)]
7+
8+
llm = ChatOpenAI(openai_api_key="xxx", openai_api_base="http://192.168.20.59:7891/v1")
9+
10+
print(llm(messages))
11+
12+
embedding = OpenAIEmbeddings(openai_api_key="xxx", openai_api_base="http://192.168.20.59:7891/v1")
13+
print(embedding.embed_documents(["你好"]))

0 commit comments

Comments
 (0)