Skip to content

Commit

Permalink
Add support NeMo Retriever Text Reranking NIM in O-RAN chatbot (NVIDI…
Browse files Browse the repository at this point in the history
…A#187)

* Add support for NeMo Retriever Text Reranking NIM in oran chatbot

* Add default reranker and NIM reranker configurations for oran chatbot
  • Loading branch information
sduttanv authored Sep 9, 2024
1 parent 9852997 commit c4240f0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
62 changes: 38 additions & 24 deletions community/oran-chatbot-multimodal/Multimodal_Assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from retriever.retriever import Retriever, get_relevant_docs, get_relevant_docs_mq
from utils.feedback import feedback_kwargs

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIARerank
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
Expand Down Expand Up @@ -400,52 +400,66 @@ def load_config(cfg_arg):
if rag_type == 1:
augmented_queries = augment_multiple_query(transformed_query["text"])
queries = [transformed_query["text"]] + augmented_queries[2:]
print("Queries are = ", queries)
# print("Queries are = ", queries)
retrieved_documents = []
retrieved_metadatas = []
relevant_docs = []
for query in queries:
ret_docs,cons,srcs = get_relevant_docs(CORE_DIR, query)
for doc in ret_docs:
retrieved_documents.append(doc.page_content)
retrieved_metadatas.append(doc.metadata['source'])
relevant_docs.append(doc)
print("length of retrieved docs: ", len(retrieved_documents))
#Remove all duplicated documents and retain the original metadata
unique_documents = []
unique_documents_metadata = []
for document,source in zip(retrieved_documents,retrieved_metadatas):
unique_relevant_documents = []
for idx, (document,source) in enumerate(zip(retrieved_documents,retrieved_metadatas)):
if document not in unique_documents:
unique_documents.append(document)
unique_documents_metadata.append(source)
unique_relevant_documents.append(relevant_docs[idx])

if len(retrieved_documents) == 0:
context = ""
print("not context found context")
else:
print("length of unique docs: ", len(unique_documents))
#Instantiate the cross-encoder model and get scores for each retrieved document
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # ('BAAI/bge-reranker-large')('cross-encoder/ms-marco-MiniLM-L-6-v2')
pairs = [[prompt, doc] for doc in unique_documents]
scores = cross_encoder.predict(pairs)
#Sort the scores from highest to least
order_ids = np.argsort(scores)[::-1]
# print(order_ids)
#Instantiate the re-ranker model and get scores for each retrieved document
new_updated_documents = []
new_updated_sources = []
#Get the top 6 scores
if len(order_ids)>=10:
for i in range(10):
new_updated_documents.append(unique_documents[order_ids[i]])
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
if not config_yaml['Reranker_NIM']:
print("\n\nReranking with Cross-encoder model: ", config_yaml['reranker_model'])
cross_encoder = CrossEncoder(config_yaml['reranker_model'])
pairs = [[prompt, doc] for doc in unique_documents]
scores = cross_encoder.predict(pairs)
#Sort the scores from highest to least
order_ids = np.argsort(scores)[::-1]
#Get the top 10 scores
if len(order_ids)>=10:
for i in range(10):
new_updated_documents.append(unique_documents[order_ids[i]])
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
else:
for i in range(len(order_ids)):
new_updated_documents.append(unique_documents[order_ids[i]])
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
else:
for i in range(len(order_ids)):
new_updated_documents.append(unique_documents[order_ids[i]])
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
print("\n\nReranking with Retriever Text Reranking NIM model: ", config_yaml["reranker_model_name"])
# Initialize and connect to the running NeMo Retriever Text Reranking NIM
reranker = NVIDIARerank(model=config_yaml["reranker_model_name"],
base_url=config_yaml["reranker_api_endpoint_url"], top_n=10)
reranked_chunks = reranker.compress_documents(query=transformed_query["text"], documents=unique_relevant_documents)
for chunks in reranked_chunks:
metadata = chunks.metadata
page_content = chunks.page_content
new_updated_documents.append(page_content)
new_updated_sources.append(metadata['source'])

print(new_updated_sources)
print(len(new_updated_documents))
print("Reranking of completed for ", len(new_updated_documents), " chunks")

context = ""
# sources = ""
sources = {}
for doc in new_updated_documents:
context += doc + "\n\n"
Expand All @@ -455,7 +469,7 @@ def load_config(cfg_arg):
sources[src] = {"doc_content": sources[src]["doc_content"]+"\n\n"+new_updated_documents[i], "doc_metadata": src}
else:
sources[src] = {"doc_content": new_updated_documents[i], "doc_metadata": src}
print("length of source docs: ", len(sources))
print("Length of unique source docs: ", len(sources))
#Send the top 10 results along with the query to LLM

if rag_type == 2:
Expand Down Expand Up @@ -486,7 +500,7 @@ def load_config(cfg_arg):

print("length of unique docs: ", len(unique_documents))
#Instantiate the cross-encoder model and get scores for each retrieved document
cross_encoder = CrossEncoder('BAAI/bge-reranker-large') #('cross-encoder/ms-marco-MiniLM-L-6-v2')
cross_encoder = CrossEncoder(config_yaml['reranker_model'])
pairs = [[prompt, doc] for doc in unique_documents]
scores = cross_encoder.predict(pairs)
#Sort the scores from highest to least
Expand Down Expand Up @@ -544,7 +558,7 @@ def load_config(cfg_arg):

print("length of unique docs: ", len(unique_documents))
#Instantiate the cross-encoder model and get scores for each retrieved document
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') #('BAAI/bge-reranker-large')
cross_encoder = CrossEncoder(config_yaml['reranker_model'])
pairs = [[prompt, doc] for doc in unique_documents]
scores = cross_encoder.predict(pairs)
#Sort the scores from highest to least
Expand Down
6 changes: 5 additions & 1 deletion community/oran-chatbot-multimodal/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
nvidia_api_key: "nvapi--***"
## Set these to required models endpoints from NVIDIA NGC
llm_model: "mistralai/mixtral-8x7b-instruct-v0.1"
# Augmentation_model:
embedding_model: "nvidia/nv-embedqa-e5-v5"
reranker_model: "cross-encoder/ms-marco-MiniLM-L-6-v2"

NIM: false
nim_model_name: "meta/llama3-8b-instruct"
Expand All @@ -17,4 +17,8 @@ nrem_model_name: "nvidia/nv-embedqa-e5-v5"
nrem_api_endpoint_url: "http://localhost:8001/v1"
nrem_truncate: "END"

Reranker_NIM: false
reranker_model_name: "nvidia/nv-rerankqa-mistral-4b-v3"
reranker_api_endpoint_url: "http://localhost:8000/v1"

file_delete_password: "oranpwd"

0 comments on commit c4240f0

Please sign in to comment.