Skip to content

Commit 211f416

Browse files
chzphoenixcuihzcrazywoola
authored
feat:add wenxin rerank (langgenius#9431)
Co-authored-by: cuihz <cuihz@knowbox.cn> Co-authored-by: crazywoola <427733928@qq.com>
1 parent b90ad58 commit 211f416

File tree

6 files changed

+178
-0
lines changed

6 files changed

+178
-0
lines changed

api/core/model_runtime/model_providers/wenxin/_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class _CommonWenxin:
120120
"bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en",
121121
"bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh",
122122
"tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k",
123+
"bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base",
123124
}
124125

125126
function_calling_supports = [

api/core/model_runtime/model_providers/wenxin/rerank/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
model: bce-reranker-base_v1
2+
model_type: rerank
3+
model_properties:
4+
context_size: 4096
5+
pricing:
6+
input: '0.0005'
7+
unit: '0.001'
8+
currency: RMB
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from typing import Optional
2+
3+
import httpx
4+
5+
from core.model_runtime.entities.common_entities import I18nObject
6+
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
7+
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
8+
from core.model_runtime.errors.invoke import (
9+
InvokeAuthorizationError,
10+
InvokeBadRequestError,
11+
InvokeConnectionError,
12+
InvokeError,
13+
InvokeRateLimitError,
14+
InvokeServerUnavailableError,
15+
)
16+
from core.model_runtime.errors.validate import CredentialsValidateFailedError
17+
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
18+
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
19+
20+
21+
class WenxinRerank(_CommonWenxin):
22+
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
23+
access_token = self._get_access_token()
24+
url = f"{self.api_bases[model]}?access_token={access_token}"
25+
26+
try:
27+
response = httpx.post(
28+
url,
29+
json={"model": model, "query": query, "documents": docs, "top_n": top_n},
30+
headers={"Content-Type": "application/json"},
31+
)
32+
response.raise_for_status()
33+
return response.json()
34+
except httpx.HTTPStatusError as e:
35+
raise InvokeServerUnavailableError(str(e))
36+
37+
38+
class WenxinRerankModel(RerankModel):
39+
"""
40+
Model class for wenxin rerank model.
41+
"""
42+
43+
def _invoke(
44+
self,
45+
model: str,
46+
credentials: dict,
47+
query: str,
48+
docs: list[str],
49+
score_threshold: Optional[float] = None,
50+
top_n: Optional[int] = None,
51+
user: Optional[str] = None,
52+
) -> RerankResult:
53+
"""
54+
Invoke rerank model
55+
56+
:param model: model name
57+
:param credentials: model credentials
58+
:param query: search query
59+
:param docs: docs for reranking
60+
:param score_threshold: score threshold
61+
:param top_n: top n documents to return
62+
:param user: unique user id
63+
:return: rerank result
64+
"""
65+
if len(docs) == 0:
66+
return RerankResult(model=model, docs=[])
67+
68+
api_key = credentials["api_key"]
69+
secret_key = credentials["secret_key"]
70+
71+
wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key)
72+
73+
try:
74+
results = wenxin_rerank.rerank(model, query, docs, top_n)
75+
76+
rerank_documents = []
77+
for result in results["results"]:
78+
index = result["index"]
79+
if "document" in result:
80+
text = result["document"]
81+
else:
82+
# llama.cpp rerank maynot return original documents
83+
text = docs[index]
84+
85+
rerank_document = RerankDocument(
86+
index=index,
87+
text=text,
88+
score=result["relevance_score"],
89+
)
90+
91+
if score_threshold is None or result["relevance_score"] >= score_threshold:
92+
rerank_documents.append(rerank_document)
93+
94+
return RerankResult(model=model, docs=rerank_documents)
95+
except httpx.HTTPStatusError as e:
96+
raise InvokeServerUnavailableError(str(e))
97+
98+
def validate_credentials(self, model: str, credentials: dict) -> None:
99+
"""
100+
Validate model credentials
101+
102+
:param model: model name
103+
:param credentials: model credentials
104+
:return:
105+
"""
106+
try:
107+
self._invoke(
108+
model=model,
109+
credentials=credentials,
110+
query="What is the capital of the United States?",
111+
docs=[
112+
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
113+
"Census, Carson City had a population of 55,274.",
114+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
115+
"are a political division controlled by the United States. Its capital is Saipan.",
116+
],
117+
score_threshold=0.8,
118+
)
119+
except Exception as ex:
120+
raise CredentialsValidateFailedError(str(ex))
121+
122+
@property
123+
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
124+
"""
125+
Map model invoke error to unified error
126+
"""
127+
return {
128+
InvokeConnectionError: [httpx.ConnectError],
129+
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
130+
InvokeRateLimitError: [],
131+
InvokeAuthorizationError: [httpx.HTTPStatusError],
132+
InvokeBadRequestError: [httpx.RequestError],
133+
}
134+
135+
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
136+
"""
137+
generate custom model entities from credentials
138+
"""
139+
entity = AIModelEntity(
140+
model=model,
141+
label=I18nObject(en_US=model),
142+
model_type=ModelType.RERANK,
143+
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
144+
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
145+
)
146+
147+
return entity

api/core/model_runtime/model_providers/wenxin/wenxin.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ help:
1818
supported_model_types:
1919
- llm
2020
- text-embedding
21+
- rerank
2122
configurate_methods:
2223
- predefined-model
2324
provider_credential_schema:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
from time import sleep
3+
4+
from core.model_runtime.entities.rerank_entities import RerankResult
5+
from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel
6+
7+
8+
def test_invoke_bce_reranker_base_v1():
9+
sleep(3)
10+
model = WenxinRerankModel()
11+
12+
response = model.invoke(
13+
model="bce-reranker-base_v1",
14+
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
15+
query="What is Deep Learning?",
16+
docs=["Deep Learning is ...", "My Book is ..."],
17+
user="abc-123",
18+
)
19+
20+
assert isinstance(response, RerankResult)
21+
assert len(response.docs) == 2

0 commit comments

Comments
 (0)