
Commit 81b24f7

fixing the completions error
1 parent 447ea1c commit 81b24f7

File tree

5 files changed: +74 -102 lines changed


dockerfiles/transformers/llm_transformer_nvidia/src/.ipynb_checkpoints/model-checkpoint.py

Lines changed: 34 additions & 49 deletions
@@ -2,15 +2,21 @@
 import logging
 import argparse
 import requests
+
+import httpx
+
 from kserve import Model, ModelServer, model_server
 
+
 logger = logging.getLogger(__name__)
 
+PREDICTOR_URL_FORMAT = "http://{0}/v1/models/vi/chat/completions"
 
 class Transformer(Model):
     def __init__(self, name: str, predictor_host: str, protocol: str,
                  use_ssl: bool, vectorstore_name: str = "vectorstore"):
         super().__init__(name)
+        # KServe specific arguments
         self.name = name
         self.predictor_host = predictor_host
         self.protocol = protocol
@@ -22,7 +28,9 @@ def __init__(self, name: str, predictor_host: str, protocol: str,
         self.vectorstore_url = self._build_vectorstore_url()
 
     def _get_namespace(self):
-        return open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r").read().strip()
+        return (open(
+            "/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r")
+            .read())
 
     def _build_vectorstore_url(self):
         domain_name = "svc.cluster.local"
@@ -34,58 +42,35 @@ def _build_vectorstore_url(self):
         url = f"http://{svc}/v1/models/{model_name}:predict"
         return url
 
+    @property
+    def _http_client(self):
+        if self._http_client_instance is None:
+            # No Authorization header needed
+            self._http_client_instance = httpx.AsyncClient(verify=False) # Removed headers argument
+        return self._http_client_instance
+
     def preprocess(self, request: dict, headers: dict) -> dict:
         data = request["instances"][0]
         query = data["input"]
-        num_docs = data.get("num_docs", 4)
-        system_message = data.get("system", "You are an AI assistant.")
-        instruction = data.get("instruction", "Answer the question using the context below.")
-
         logger.info(f"Received question: {query}")
+        num_docs = data.get("num_docs", 4)
         context = data.get("context", None)
-
-        # If no context is provided, retrieve documents from the vector store
-        if not context:
-            payload = {"instances": [{"input": query, "num_docs": num_docs}]}
-            logger.info(f"Retrieving relevant documents from: {self.vectorstore_url}")
-
-            response = requests.post(self.vectorstore_url, json=payload, verify=False)
-            response_data = response.json()
-
-            if response.status_code == 200 and "predictions" in response_data:
-                context = "\n".join(response_data["predictions"])
-            else:
-                context = "No relevant documents found."
-
-            logger.info(f"Retrieved Context:\n{context}")
-
-        # 🔥 FIX: Ensure correct predictor URL
-        predictor_url = f"http://{self.predictor_host}.svc.cluster.local/v1/chat/completions"
-        logger.info(f"Sending request to LLM predictor at {predictor_url}")
-
-        llm_payload = {
-            "model": "meta/llama-2-7b-chat",
-            "messages": [
-                {"role": "system", "content": system_message},
-                {"role": "user", "content": f"{instruction}\n\nContext: {context}\n\nQuestion: {query}"}
-            ],
-            "temperature": data.get("temperature", 0.5),
-            "top_p": data.get("top_p", 1),
-            "max_tokens": int(data.get("max_tokens", 256)),
-            "stream": False
-        }
-
-        llm_response = requests.post(predictor_url, json=llm_payload, verify=False)
-
-        if llm_response.status_code == 200:
-            result = llm_response.json()["choices"][0]["message"]["content"]
-            logger.info(f"LLM Response: {result}")
-            return {"predictions": [result]}
+        if context:
+            logger.info(f"Received context: {context}")
+            logger.info(f"Skipping retrieval step...")
+            return {"instances": [data]}
         else:
-            error_message = f"Error calling LLM predictor: {llm_response.status_code} - {llm_response.text}"
-            logger.error(error_message)
-            return {"predictions": [error_message]}
-
+            payload = {"instances":[{"input": query, "num_docs": num_docs}]}
+            logger.info(
+                f"Receiving relevant docs from: {self.vectorstore_url}")
+
+            response = requests.post(
+                self.vectorstore_url, json=payload,
+                verify=False)
+            response = json.loads(response.text)
+            context = "\n".join(response["predictions"])
+            logger.info(f"Received documents: {context}")
+            return {"instances": [{**data, **{"context": context}}]}
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(parents=[model_server.parser])
@@ -97,7 +82,7 @@ def preprocess(self, request: dict, headers: dict) -> dict:
     parser.add_argument(
         "--model_name", help="The name that the model is served under.")
     parser.add_argument(
-        "--use_ssl", help="Use SSL for connecting to the predictor",
+        "--use_ssl", help="Use ssl for connecting to the predictor",
        action='store_true')
    parser.add_argument("--vectorstore_name", default="vectorstore",
                        required=False,
@@ -109,4 +94,4 @@ def preprocess(self, request: dict, headers: dict) -> dict:
    model = Transformer(
        args.model_name, args.predictor_host, args.protocol, args.use_ssl,
        args.vectorstore_name)
-    ModelServer().start([model])
+    ModelServer().start([model])
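
Aside: the reworked preprocess above no longer calls the LLM itself; it only gathers context and forwards the enriched instances to the predictor. A minimal sketch of exercising both branches from a client, assuming a hypothetical in-cluster transformer URL (the real host depends on the InferenceService name and namespace):

    # Sketch only: TRANSFORMER_URL is a placeholder, not a value from this repo.
    import requests

    TRANSFORMER_URL = "http://llm.default.svc.cluster.local/v1/models/llm:predict"

    # Branch 1: the caller supplies "context", so preprocess skips retrieval.
    with_context = {"instances": [{
        "input": "What does the transformer do?",
        "context": "The transformer fetches documents and forwards them to the LLM.",
    }]}

    # Branch 2: no "context" key, so preprocess queries the vector store for
    # num_docs documents and injects the joined text before forwarding.
    without_context = {"instances": [{"input": "What does the transformer do?",
                                      "num_docs": 4}]}

    for payload in (with_context, without_context):
        resp = requests.post(TRANSFORMER_URL, json=payload, timeout=600)
        print(resp.status_code, resp.json())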

dockerfiles/transformers/llm_transformer_nvidia/src/model.py

Lines changed: 34 additions & 49 deletions
@@ -2,15 +2,21 @@
 import logging
 import argparse
 import requests
+
+import httpx
+
 from kserve import Model, ModelServer, model_server
 
+
 logger = logging.getLogger(__name__)
 
+PREDICTOR_URL_FORMAT = "http://{0}/v1/models/vi/chat/completions"
 
 class Transformer(Model):
     def __init__(self, name: str, predictor_host: str, protocol: str,
                  use_ssl: bool, vectorstore_name: str = "vectorstore"):
         super().__init__(name)
+        # KServe specific arguments
         self.name = name
         self.predictor_host = predictor_host
         self.protocol = protocol
@@ -22,7 +28,9 @@ def __init__(self, name: str, predictor_host: str, protocol: str,
         self.vectorstore_url = self._build_vectorstore_url()
 
     def _get_namespace(self):
-        return open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r").read().strip()
+        return (open(
+            "/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r")
+            .read())
 
     def _build_vectorstore_url(self):
         domain_name = "svc.cluster.local"
@@ -34,58 +42,35 @@ def _build_vectorstore_url(self):
         url = f"http://{svc}/v1/models/{model_name}:predict"
         return url
 
+    @property
+    def _http_client(self):
+        if self._http_client_instance is None:
+            # No Authorization header needed
+            self._http_client_instance = httpx.AsyncClient(verify=False) # Removed headers argument
+        return self._http_client_instance
+
     def preprocess(self, request: dict, headers: dict) -> dict:
         data = request["instances"][0]
         query = data["input"]
-        num_docs = data.get("num_docs", 4)
-        system_message = data.get("system", "You are an AI assistant.")
-        instruction = data.get("instruction", "Answer the question using the context below.")
-
         logger.info(f"Received question: {query}")
+        num_docs = data.get("num_docs", 4)
         context = data.get("context", None)
-
-        # If no context is provided, retrieve documents from the vector store
-        if not context:
-            payload = {"instances": [{"input": query, "num_docs": num_docs}]}
-            logger.info(f"Retrieving relevant documents from: {self.vectorstore_url}")
-
-            response = requests.post(self.vectorstore_url, json=payload, verify=False)
-            response_data = response.json()
-
-            if response.status_code == 200 and "predictions" in response_data:
-                context = "\n".join(response_data["predictions"])
-            else:
-                context = "No relevant documents found."
-
-            logger.info(f"Retrieved Context:\n{context}")
-
-        # 🔥 FIX: Ensure correct predictor URL
-        predictor_url = f"http://{self.predictor_host}.svc.cluster.local/v1/chat/completions"
-        logger.info(f"Sending request to LLM predictor at {predictor_url}")
-
-        llm_payload = {
-            "model": "meta/llama-2-7b-chat",
-            "messages": [
-                {"role": "system", "content": system_message},
-                {"role": "user", "content": f"{instruction}\n\nContext: {context}\n\nQuestion: {query}"}
-            ],
-            "temperature": data.get("temperature", 0.5),
-            "top_p": data.get("top_p", 1),
-            "max_tokens": int(data.get("max_tokens", 256)),
-            "stream": False
-        }
-
-        llm_response = requests.post(predictor_url, json=llm_payload, verify=False)
-
-        if llm_response.status_code == 200:
-            result = llm_response.json()["choices"][0]["message"]["content"]
-            logger.info(f"LLM Response: {result}")
-            return {"predictions": [result]}
+        if context:
+            logger.info(f"Received context: {context}")
+            logger.info(f"Skipping retrieval step...")
+            return {"instances": [data]}
         else:
-            error_message = f"Error calling LLM predictor: {llm_response.status_code} - {llm_response.text}"
-            logger.error(error_message)
-            return {"predictions": [error_message]}
-
+            payload = {"instances":[{"input": query, "num_docs": num_docs}]}
+            logger.info(
+                f"Receiving relevant docs from: {self.vectorstore_url}")
+
+            response = requests.post(
+                self.vectorstore_url, json=payload,
+                verify=False)
+            response = json.loads(response.text)
+            context = "\n".join(response["predictions"])
+            logger.info(f"Received documents: {context}")
+            return {"instances": [{**data, **{"context": context}}]}
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(parents=[model_server.parser])
@@ -97,7 +82,7 @@ def preprocess(self, request: dict, headers: dict) -> dict:
     parser.add_argument(
         "--model_name", help="The name that the model is served under.")
     parser.add_argument(
-        "--use_ssl", help="Use SSL for connecting to the predictor",
+        "--use_ssl", help="Use ssl for connecting to the predictor",
        action='store_true')
    parser.add_argument("--vectorstore_name", default="vectorstore",
                        required=False,
@@ -109,4 +94,4 @@ def preprocess(self, request: dict, headers: dict) -> dict:
    model = Transformer(
        args.model_name, args.predictor_host, args.protocol, args.use_ssl,
        args.vectorstore_name)
-    ModelServer().start([model])
+    ModelServer().start([model])
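
One thing worth flagging in both copies of this file: the new else branch calls json.loads, but the import hunk only adds httpx, so json must already be imported elsewhere in the file for this to run (or response.json() could be used instead). A minimal standalone sketch of the retrieval step with the import made explicit; fetch_context is an illustrative name, not a function from this repo:

    # Sketch of the retrieval branch with an explicit json import.
    import json

    import requests

    def fetch_context(vectorstore_url: str, query: str, num_docs: int = 4) -> str:
        # Ask the vector store for the top num_docs documents for the query.
        payload = {"instances": [{"input": query, "num_docs": num_docs}]}
        response = requests.post(vectorstore_url, json=payload, verify=False)
        # Equivalent to the diff's json.loads(response.text); response.json()
        # would avoid the extra import.
        predictions = json.loads(response.text)["predictions"]
        return "\n".join(predictions)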

manifests/Inference/GPU/.ipynb_checkpoints/llm_inference_nvidia-checkpoint.yml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ spec:
   transformer:
     timeout: 600
     containers:
-      - image: chasechristensen/transformer-nvidia-nim:0.3
+      - image: chasechristensen/transformer-nvidia-nim:0.4
        imagePullPolicy: Always
        resources:
          requests:

manifests/Inference/GPU/llm_inference_nvidia.yml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ spec:
   transformer:
     timeout: 600
     containers:
-      - image: chasechristensen/transformer-nvidia-nim:0.3
+      - image: chasechristensen/transformer-nvidia-nim:0.4
        imagePullPolicy: Always
        resources:
          requests:

notebooks/Tiledb_doc_prep.ipynb

Lines changed: 4 additions & 2 deletions
@@ -213,7 +213,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 215,
+    "execution_count": 216,
     "id": "f65b0d34-eeb4-44d3-be3e-98d033db8a58",
     "metadata": {},
     "outputs": [
@@ -248,7 +248,9 @@
     "id": "6d4b3e9c-9998-4049-875a-98c84f244ef7",
     "metadata": {},
     "outputs": [],
-    "source": []
+    "source": [
+     "this works "
+    ]
    },
    {
     "cell_type": "code",
