Skip to content

Hybrid search #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ async def post_processing(uri=Form(None), userName=Form(None), password=Form(Non
logger.log_struct(josn_obj)
logging.info(f'Updated KNN Graph')
if "create_fulltext_index" in tasks:
await asyncio.to_thread(create_fulltext, uri=uri, username=userName, password=password, database=database)
await asyncio.to_thread(create_fulltext, uri=uri, username=userName, password=password, database=database,type="entities")
await asyncio.to_thread(create_fulltext, uri=uri, username=userName, password=password, database=database,type="keyword")
josn_obj = {'api_name': 'post_processing/create_fulltext_index', 'db_url': uri, 'logging_time': formatted_time(datetime.now(timezone.utc))}
logger.log_struct(josn_obj)
logging.info(f'Full Text index created')
Expand Down
46 changes: 34 additions & 12 deletions backend/src/QA_integration_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,37 @@
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)


def get_neo4j_retriever(graph, retrieval_query,document_names,index_name="vector", search_k=CHAT_SEARCH_KWARG_K, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD):
def get_neo4j_retriever(graph, retrieval_query,document_names,mode,index_name="vector",keyword_index="keyword", search_k=CHAT_SEARCH_KWARG_K, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD):
try:
neo_db = Neo4jVector.from_existing_index(
embedding=EMBEDDING_FUNCTION,
index_name=index_name,
retrieval_query=retrieval_query,
graph=graph
)
logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
if mode == "hybrid":
# neo_db = Neo4jVector.from_existing_graph(
# embedding=EMBEDDING_FUNCTION,
# index_name=index_name,
# retrieval_query=retrieval_query,
# graph=graph,
# search_type="hybrid",
# node_label="Chunk",
# embedding_node_property="embedding",
# text_node_properties=["text"]
# # keyword_index_name=keyword_index
# )
neo_db = Neo4jVector.from_existing_index(
embedding=EMBEDDING_FUNCTION,
index_name=index_name,
retrieval_query=retrieval_query,
graph=graph,
search_type="hybrid",
keyword_index_name=keyword_index
)
logging.info(f"Successfully retrieved Neo4jVector index '{index_name}' and keyword index '{keyword_index}'")
else:
neo_db = Neo4jVector.from_existing_index(
embedding=EMBEDDING_FUNCTION,
index_name=index_name,
retrieval_query=retrieval_query,
graph=graph
)
logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
document_names= list(map(str.strip, json.loads(document_names)))
if document_names:
retriever = neo_db.as_retriever(search_type="similarity_score_threshold",search_kwargs={'k': search_k, "score_threshold": score_threshold,'filter':{'fileName': {'$in': document_names}}})
Expand Down Expand Up @@ -232,13 +254,13 @@ def clear_chat_history(graph,session_id):
"user": "chatbot"
}

def setup_chat(model, graph, session_id, document_names,retrieval_query):
def setup_chat(model, graph, document_names,retrieval_query,mode):
start_time = time.time()
if model in ["diffbot"]:
model = "openai-gpt-4o"
llm,model_name = get_llm(model)
logging.info(f"Model called in chat {model} and model version is {model_name}")
retriever = get_neo4j_retriever(graph=graph,retrieval_query=retrieval_query,document_names=document_names)
retriever = get_neo4j_retriever(graph=graph,retrieval_query=retrieval_query,document_names=document_names,mode=mode)
doc_retriever = create_document_retriever_chain(llm, retriever)
chat_setup_time = time.time() - start_time
logging.info(f"Chat setup completed in {chat_setup_time:.2f} seconds")
Expand Down Expand Up @@ -357,10 +379,10 @@ def QA_RAG(graph, model, question, document_names,session_id, mode):
else:
retrieval_query = VECTOR_GRAPH_SEARCH_QUERY.format(no_of_entites=VECTOR_GRAPH_SEARCH_ENTITY_LIMIT)

llm, doc_retriever, model_version = setup_chat(model, graph, session_id, document_names,retrieval_query)
llm, doc_retriever, model_version = setup_chat(model, graph, document_names,retrieval_query,mode)

docs = retrieve_documents(doc_retriever, messages)

if docs:
content, result, total_tokens = process_documents(docs, question, messages, llm,model)
else:
Expand Down
40 changes: 26 additions & 14 deletions backend/src/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX entities FOR (n{labels_str}) ON EACH [n.id, n.description];"
FILTER_LABELS = ["Chunk","Document"]

def create_fulltext(uri, username, password, database):

HYBRID_SEARCH_INDEX_DROP_QUERY = "DROP INDEX keyword IF EXISTS;"
HYBRID_SEARCH_FULL_TEXT_QUERY = "CREATE FULLTEXT INDEX keyword FOR (n:Chunk) ON EACH [n.text]"

def create_fulltext(uri, username, password, database,type):
start_time = time.time()
logging.info("Starting the process of creating a full-text index.")

Expand All @@ -26,28 +30,37 @@ def create_fulltext(uri, username, password, database):
with driver.session() as session:
try:
start_step = time.time()
session.run(DROP_INDEX_QUERY)
if type == "entities":
drop_query = DROP_INDEX_QUERY
else:
drop_query = HYBRID_SEARCH_INDEX_DROP_QUERY
session.run(drop_query)
logging.info(f"Dropped existing index (if any) in {time.time() - start_step:.2f} seconds.")
except Exception as e:
logging.error(f"Failed to drop index: {e}")
return
try:
start_step = time.time()
result = session.run(LABELS_QUERY)
labels = [record["label"] for record in result]

for label in FILTER_LABELS:
if label in labels:
labels.remove(label)

labels_str = ":" + "|".join([f"`{label}`" for label in labels])
logging.info(f"Fetched labels in {time.time() - start_step:.2f} seconds.")
if type == "entities":
start_step = time.time()
result = session.run(LABELS_QUERY)
labels = [record["label"] for record in result]

for label in FILTER_LABELS:
if label in labels:
labels.remove(label)

labels_str = ":" + "|".join([f"`{label}`" for label in labels])
logging.info(f"Fetched labels in {time.time() - start_step:.2f} seconds.")
except Exception as e:
logging.error(f"Failed to fetch labels: {e}")
return
try:
start_step = time.time()
session.run(FULL_TEXT_QUERY.format(labels_str=labels_str))
if type == "entities":
fulltext_query = FULL_TEXT_QUERY.format(labels_str=labels_str)
else:
fulltext_query = HYBRID_SEARCH_FULL_TEXT_QUERY
session.run(fulltext_query)
logging.info(f"Created full-text index in {time.time() - start_step:.2f} seconds.")
except Exception as e:
logging.error(f"Failed to create full-text index: {e}")
Expand All @@ -59,7 +72,6 @@ def create_fulltext(uri, username, password, database):
logging.info("Driver closed.")
logging.info(f"Process completed in {time.time() - start_time:.2f} seconds.")


def create_entity_embedding(graph:Neo4jGraph):
rows = fetch_entities_for_embedding(graph)
for i in range(0, len(rows), 1000):
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/utils/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export const defaultLLM = llms?.includes('openai-gpt-3.5')
? 'gemini-1.0-pro'
: 'diffbot';
export const chatModes =
process.env?.CHAT_MODES?.trim() != '' ? process.env.CHAT_MODES?.split(',') : ['vector', 'graph', 'graph+vector'];
process.env?.CHAT_MODES?.trim() != '' ? process.env.CHAT_MODES?.split(',') : ['vector', 'graph', 'graph+vector','hybrid'];
export const chunkSize = process.env.CHUNK_SIZE ? parseInt(process.env.CHUNK_SIZE) : 1 * 1024 * 1024;
export const timeperpage = process.env.TIME_PER_PAGE ? parseInt(process.env.TIME_PER_PAGE) : 50;
export const timePerByte = 0.2;
Expand Down