import logging
import argparse
import requests
+
+import httpx
+
from kserve import Model, ModelServer, model_server

+
logger = logging.getLogger(__name__)

+PREDICTOR_URL_FORMAT = "http://{0}/v1/models/vi/chat/completions"

class Transformer(Model):
    def __init__(self, name: str, predictor_host: str, protocol: str,
                 use_ssl: bool, vectorstore_name: str = "vectorstore"):
        super().__init__(name)
+        # KServe specific arguments
        self.name = name
        self.predictor_host = predictor_host
        self.protocol = protocol
@@ -22,7 +28,9 @@ def __init__(self, name: str, predictor_host: str, protocol: str,
        self.vectorstore_url = self._build_vectorstore_url()

    def _get_namespace(self):
-        return open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r").read().strip()
+        return (open(
+            "/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r")
+            .read().strip())

    def _build_vectorstore_url(self):
        domain_name = "svc.cluster.local"
@@ -34,58 +42,35 @@ def _build_vectorstore_url(self):
        url = f"http://{svc}/v1/models/{model_name}:predict"
        return url

+    @property
+    def _http_client(self):
+        # getattr: the attribute may never have been set in __init__
+        if getattr(self, "_http_client_instance", None) is None:
+            self._http_client_instance = httpx.AsyncClient(verify=False)  # no Authorization header needed
+        return self._http_client_instance
+
    def preprocess(self, request: dict, headers: dict) -> dict:
        data = request["instances"][0]
        query = data["input"]
-        num_docs = data.get("num_docs", 4)
-        system_message = data.get("system", "You are an AI assistant.")
-        instruction = data.get("instruction", "Answer the question using the context below.")
-
        logger.info(f"Received question: {query}")
+        num_docs = data.get("num_docs", 4)
        context = data.get("context", None)
-
-        # If no context is provided, retrieve documents from the vector store
-        if not context:
-            payload = {"instances": [{"input": query, "num_docs": num_docs}]}
-            logger.info(f"Retrieving relevant documents from: {self.vectorstore_url}")
-
-            response = requests.post(self.vectorstore_url, json=payload, verify=False)
-            response_data = response.json()
-
-            if response.status_code == 200 and "predictions" in response_data:
-                context = "\n".join(response_data["predictions"])
-            else:
-                context = "No relevant documents found."
-
-            logger.info(f"Retrieved Context:\n{context}")
-
-        # 🔥 FIX: Ensure correct predictor URL
-        predictor_url = f"http://{self.predictor_host}.svc.cluster.local/v1/chat/completions"
-        logger.info(f"Sending request to LLM predictor at {predictor_url}")
-
-        llm_payload = {
-            "model": "meta/llama-2-7b-chat",
-            "messages": [
-                {"role": "system", "content": system_message},
-                {"role": "user", "content": f"{instruction}\n\nContext: {context}\n\nQuestion: {query}"}
-            ],
-            "temperature": data.get("temperature", 0.5),
-            "top_p": data.get("top_p", 1),
-            "max_tokens": int(data.get("max_tokens", 256)),
-            "stream": False
-        }
-
-        llm_response = requests.post(predictor_url, json=llm_payload, verify=False)
-
-        if llm_response.status_code == 200:
-            result = llm_response.json()["choices"][0]["message"]["content"]
-            logger.info(f"LLM Response: {result}")
-            return {"predictions": [result]}
+        if context:
+            logger.info(f"Received context: {context}")
+            logger.info("Skipping retrieval step...")
+            return {"instances": [data]}
        else:
-            error_message = f"Error calling LLM predictor: {llm_response.status_code} - {llm_response.text}"
-            logger.error(error_message)
-            return {"predictions": [error_message]}
-
+            payload = {"instances": [{"input": query, "num_docs": num_docs}]}
+            logger.info(
+                f"Retrieving relevant docs from: {self.vectorstore_url}")
+
+            response = requests.post(
+                self.vectorstore_url, json=payload,
+                verify=False)
+            response = response.json()
+            context = "\n".join(response["predictions"])
+            logger.info(f"Received documents: {context}")
+            return {"instances": [{**data, "context": context}]}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(parents=[model_server.parser])
@@ -97,7 +82,7 @@ def preprocess(self, request: dict, headers: dict) -> dict:
    parser.add_argument(
        "--model_name", help="The name that the model is served under.")
    parser.add_argument(
-        "--use_ssl", help="Use SSL for connecting to the predictor",
+        "--use_ssl", help="Use ssl for connecting to the predictor",
        action='store_true')
    parser.add_argument("--vectorstore_name", default="vectorstore",
                        required=False,
@@ -109,4 +94,4 @@ def preprocess(self, request: dict, headers: dict) -> dict:
    model = Transformer(
        args.model_name, args.predictor_host, args.protocol, args.use_ssl,
        args.vectorstore_name)
-    ModelServer().start([model])
+    ModelServer().start([model])
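
For reference, a minimal sketch of a client call that exercises the reworked preprocess path once the transformer is deployed; the host, port, and served model name below are placeholder assumptions, not values taken from this commit:

    import requests

    # Hypothetical endpoint: substitute the transformer service's host/port
    # and the name passed via --model_name.
    URL = "http://localhost:8080/v1/models/rag-transformer:predict"

    # Omitting "context" makes preprocess fetch num_docs documents from the
    # vector store and merge them into the instance; supplying a "context"
    # key skips the retrieval step entirely.
    payload = {"instances": [{"input": "What is this transformer for?", "num_docs": 4}]}

    resp = requests.post(URL, json=payload, timeout=60)
    resp.raise_for_status()
    print(resp.json())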