
Light GraphRAG #4585

Merged
merged 17 commits on Jan 22, 2025
Refactor search part for graphrag
KevinHuSh committed Jan 21, 2025
commit 68ae8936405d20f3320d6f159626ad78c41de3c3
36 changes: 35 additions & 1 deletion api/apps/kb_app.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json

from flask import request
from flask_login import login_required, current_user

@@ -270,4 +272,36 @@ def rename_tags(kb_id):
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
return get_json_result(data=True)


@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
def knowledge_graph(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
req = {
"kb_id": [kb_id],
"knowledge_graph_kwd": ["graph"]
}
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:1]:
ty = sres.field[id]["knowledge_graph_kwd"]
try:
content_json = json.loads(sres.field[id]["content_with_weight"])
except Exception:
continue

obj[ty] = content_json

if "nodes" in obj["graph"]:
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
obj["graph"]["edges"] = sorted(obj["graph"]["edges"], key=lambda x: x.get("weight", 0), reverse=True)[:128]
return get_json_result(data=obj)
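
A quick way to sanity-check the new endpoint from outside (illustrative only, not part of the commit; the base URL, port, and auth header are assumptions about a typical deployment):

import requests

BASE_URL = "http://localhost:9380/v1/kb"   # assumed API prefix; adjust to your deployment
KB_ID = "<your_kb_id>"

resp = requests.get(
    f"{BASE_URL}/{KB_ID}/knowledge_graph",
    headers={"Authorization": "<your session token>"},  # assumed auth scheme
)
payload = resp.json().get("data", {})

# "graph" carries at most 256 nodes (highest pagerank first) and 128 edges
# (highest weight first); "mind_map" is whatever mind-map JSON was indexed.
print(len(payload.get("graph", {}).get("nodes", [])), "nodes")
print(len(payload.get("graph", {}).get("edges", [])), "edges")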
2 changes: 1 addition & 1 deletion api/db/init_data.py
@@ -133,7 +133,7 @@ def init_llm_factory():
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
## insert openai two embedding models to the current openai user.
# print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
47 changes: 24 additions & 23 deletions api/db/services/document_service.py
@@ -107,6 +107,11 @@ def insert(cls, doc):
def remove_document(cls, doc, tenant_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "exist": "knowledge_graph_kwd", "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "exists": "knowledge_graph_kwd", "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
return cls.delete_by_id(doc.id)

@classmethod
@@ -365,6 +370,12 @@ def begin2parse(cls, docid):
@classmethod
@DB.connection_context()
def update_progress(cls):
MSG = {
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
"graphrag": "Start Graph Extraction",
"graph_resolution": "Start Graph Resolution",
"graph_community": "Start Graph Community Reports Generation"
}
docs = cls.get_unfinished_docs()
for d in docs:
try:
@@ -390,31 +401,27 @@ def update_progress(cls):
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
" raptor") < 0:
queue_raptor_o_graphrag_tasks(d, "raptor")
m = "\n".join(sorted(msg))
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ RAPTOR -------")
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and d["progress_msg"].lower().find(" graphrag ") < 0:
queue_raptor_o_graphrag_tasks(d, "graphrag")
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ GraphRAG -------")
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("resolution") \
and d["progress_msg"].lower().find(" graph resolution ") < 0:
queue_raptor_o_graphrag_tasks(d, "graph_resolution")
and m.find(MSG["graph_resolution"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ Graph Resolution -------")
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("community") \
and d["progress_msg"].lower().find(" graph community ") < 0:
queue_raptor_o_graphrag_tasks(d, "graph_community")
and m.find(MSG["graph_community"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ Graph Community Detection-------")
else:
status = TaskStatus.DONE.value

msg = "\n".join(msg)
msg = "\n".join(sorted(msg))
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
@@ -446,32 +453,26 @@ def do_cancel(cls, doc_id):
return False


def queue_raptor_o_graphrag_tasks(doc, ty="raptor"):
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
hasher.update(str(chunking_config[field]).encode("utf-8"))

msg = {
"raptor": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
"graphrag": "Start to do Graph Extraction",
"graph_resolution": "Start to do Graph Resolution",
"graph_community": "Start to do Graph Community Detection"
}

def new_task():
nonlocal doc
return {
"id": get_uuid(),
"doc_id": doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"progress_msg": msg[ty]
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
}

task = new_task()
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
task["task_type"] = ty
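
For context, the digest written above is what later lets task_service decide whether a queued RAPTOR/GraphRAG task can reuse earlier results. A standalone sketch of the same hashing idea (illustrative, not part of the commit; the sample config is invented):

import xxhash

def task_digest(chunking_config: dict, task: dict, task_type: str) -> str:
    # Hash the chunking config field by field in a stable (sorted) order,
    # then the task's identifying fields, then the task type itself.
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(task_type.encode("utf-8"))
    return hasher.hexdigest()

cfg = {"chunk_token_num": 128, "layout_recognize": True}             # invented config
tsk = {"doc_id": "doc-1", "from_page": 100000000, "to_page": 100000000}
assert task_digest(cfg, tsk, "raptor") != task_digest(cfg, tsk, "graphrag")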
22 changes: 17 additions & 5 deletions api/db/services/task_service.py
@@ -183,7 +183,7 @@ def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
@@ -194,7 +194,7 @@ def update_progress(cls, id, info):
with DB.lock("update_progress", -1):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
@@ -243,6 +243,10 @@ def new_task():
for task in parse_task_array:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
if field == "parser_config":
for k in ["raptor", "graphrag"]:
if k in chunking_config[field]:
del chunking_config[field][k]
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
@@ -278,18 +282,26 @@ def new_task():
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
idx = bisect.bisect_left(prev_tasks, (task.get("from_page", 0), task.get("digest", "")),
key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
idx = 0
while idx < len(prev_tasks):
prev_task = prev_tasks[idx]
if prev_task.get("from_page", 0) == task.get("from_page", 0) \
and prev_task.get("digest", 0) == task.get("digest", ""):
break
        idx += 1  # advance to the next candidate; otherwise the scan cannot terminate

if idx >= len(prev_tasks):
return 0
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
return 0
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
if "from_page" in task and "to_page" in task:
if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
else:
task["progress_msg"] = ""
task["progress_msg"] += "reused previous task's chunks."
task["progress_msg"] = " ".join(
[datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
prev_task["chunk_ids"] = ""

return len(task["chunk_ids"].split())
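
The parser_config pruning in the hunk above keeps the raptor/graphrag toggles out of the digest, so turning GraphRAG on or off later does not force every chunk to be re-parsed. A minimal sketch of that effect in isolation (illustrative, not part of the commit; config values are invented):

import xxhash

def config_digest(chunking_config: dict) -> str:
    cfg = dict(chunking_config)
    # Drop the post-processing toggles before hashing, mirroring the diff,
    # so they cannot invalidate reuse of already-parsed chunks.
    cfg["parser_config"] = {k: v for k, v in cfg.get("parser_config", {}).items()
                            if k not in ("raptor", "graphrag")}
    hasher = xxhash.xxh64()
    for field in sorted(cfg.keys()):
        hasher.update(str(cfg[field]).encode("utf-8"))
    return hasher.hexdigest()

a = {"parser_id": "naive", "parser_config": {"chunk_token_num": 128}}
b = {"parser_id": "naive", "parser_config": {"chunk_token_num": 128, "graphrag": {"use_graphrag": True}}}
assert config_digest(a) == config_digest(b)   # toggling GraphRAG leaves the digest unchanged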
3 changes: 0 additions & 3 deletions graphrag/entity_resolution.py
@@ -166,9 +166,6 @@ def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = No
))
graph.remove_node(remove_node)

for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

return EntityResolutionResult(
graph=graph,
removed_entities=removed_entities
9 changes: 4 additions & 5 deletions graphrag/general/community_reports_extractor.py
@@ -55,10 +55,8 @@ def __init__(
self._max_report_length = max_report_length or 1500

def __call__(self, graph: nx.Graph, callback: Callable | None = None):
for n in graph.nodes:
if graph.nodes[n].get("weight"):
continue
graph.nodes[n]["weight"] = 1
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

communities: dict[str, dict[str, list]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
@@ -67,10 +65,11 @@ def __call__(self, graph: nx.Graph, callback: Callable | None = None):
over, token_count = 0, 0
st = timer()
for level, comm in communities.items():
logging.info(f"Level {level}: Community: {len(comm.keys())}")
for cm_id, ents in comm.items():
weight = ents["weight"]
ents = ents["nodes"]
ent_df = pd.DataFrame(self._get_entity_(ents))#[{"entity": n, **graph.nodes[n]} for n in ents])
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
ent_df["entity"] = ent_df["entity_name"]
del ent_df["entity_name"]
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
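
The degree-based rank assignment now lives here instead of entity_resolution.py. On its own it is plain networkx bookkeeping; a tiny sketch with an invented graph (not part of the commit):

import networkx as nx

g = nx.Graph()
g.add_edges_from([("alice", "acme"), ("alice", "bob"), ("bob", "acme"), ("carol", "acme")])

# Rank every entity by its degree, as the extractor now does before
# handing the graph to Leiden community detection.
for node, degree in g.degree:
    g.nodes[node]["rank"] = int(degree)

print(sorted(g.nodes(data="rank"), key=lambda x: -x[1]))
# [('acme', 3), ('alice', 2), ('bob', 2), ('carol', 1)]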
2 changes: 1 addition & 1 deletion graphrag/general/extractor.py
@@ -27,7 +27,7 @@

GRAPH_FIELD_SEP = "<SEP>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
ENTITY_EXTRACTION_MAX_GLEANINGS = 2


class Extractor:
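
Raising ENTITY_EXTRACTION_MAX_GLEANINGS from 1 to 2 allows one more "gleaning" pass, i.e. one more follow-up prompt asking the model for entities it missed. A generic sketch of such a loop with a stand-in LLM callable (illustrative only, not the project's actual Extractor):

MAX_GLEANINGS = 2   # mirrors the new constant

def extract_with_gleanings(chat, chunk_text):
    # First pass, then up to MAX_GLEANINGS follow-up requests for missed entities.
    history = [{"role": "user", "content": f"Extract entities from:\n{chunk_text}"}]
    entities = chat(history)
    for _ in range(MAX_GLEANINGS):
        history.append({"role": "assistant", "content": "\n".join(entities)})
        history.append({"role": "user", "content": "Some entities may have been missed. Add the rest."})
        more = chat(history)
        if not more:
            break
        entities.extend(more)
    return entities

def fake_chat(history):             # stand-in for the real LLM call
    return ["RAGFlow", "GraphRAG"] if len(history) == 1 else []

print(extract_with_gleanings(fake_chat, "RAGFlow ships a GraphRAG pipeline."))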
49 changes: 29 additions & 20 deletions graphrag/general/index.py
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
from functools import reduce, partial
import networkx as nx

from api import settings
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
chunk_id
chunk_id, update_nodes_pagerank_nhop_neighbour
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock

@@ -55,7 +53,7 @@ def __init__(self,
ents, rels = ext(chunks, callback)
self.graph = nx.Graph()
for en in ents:
self.graph.add_node(en["entity_name"])#, entity_type=en["entity_type"], description=en["description"])
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])

for rel in rels:
self.graph.add_edge(
@@ -70,7 +68,7 @@ def __init__(self,
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph])

update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
set_graph(tenant_id, kb_id, self.graph)
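
update_nodes_pagerank_nhop_neighbour itself is not shown in this diff; the core idea it presumably builds on is computing PageRank and storing the score on each node, which is what the kb_app endpoint later sorts on to keep the top 256 nodes. A minimal sketch of that piece (illustrative, not the real helper, which by its name also propagates scores to n-hop neighbours in the doc store):

import networkx as nx

def assign_pagerank(graph: nx.Graph) -> None:
    # Store each node's PageRank score on the node itself so later consumers
    # (e.g. the knowledge_graph endpoint) can sort by it.
    for node, score in nx.pagerank(graph).items():
        graph.nodes[node]["pagerank"] = score

g = nx.karate_club_graph()
assign_pagerank(g)
print(sorted(g.nodes(data="pagerank"), key=lambda x: -x[1])[:5])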


@@ -105,6 +103,7 @@ def __init__(self,
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if callback:
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
set_graph(tenant_id, kb_id, self.graph)

settings.docStoreConn.delete({
@@ -117,6 +116,11 @@ def __init__(self,
"kb_id": kb_id,
"to_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
"knowledge_graph_kwd": "entity",
"kb_id": kb_id,
"entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)


class WithCommunity(Dealer):
@@ -154,29 +158,34 @@ def __init__(self,
set_graph(tenant_id, kb_id, self.graph)

if callback:
callback(msg="Graph community extraction is done. Indexing {} reports.".format(cr.structured_output))
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))

settings.docStoreConn.delete({
"knowledge_graph_kwd": "community_report",
"kb_id": kb_id
}, search.index_name(tenant_id), kb_id)

for community, desc in zip(cr.structured_output, cr.output):
for stru, rep in zip(self.community_structure, self.community_reports):
obj = {
"report": rep,
"evidences": "\n".join([f["explanation"] for f in stru["findings"]])
}
chunk = {
"title_tks": rag_tokenizer.tokenize(community["title"]),
"content_with_weight": desc,
"content_ltks": rag_tokenizer.tokenize(desc),
"docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]),
"content_with_weight": json.dumps(obj, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
"knowledge_graph_kwd": "community_report",
"weight_flt": community["weight"],
"entities_kwd": community["entities"],
"important_kwd": community["entities"],
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": kb_id
}
try:
ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
except Exception as e:
logging.exception(f"Fail to embed entity relation: {e}")
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
#try:
# ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
#except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
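
Community reports are now stored as a JSON object under content_with_weight instead of raw text. A small sketch of building and reading back such a chunk (illustrative, not part of the commit; the community structure below is invented, field names follow the diff):

import json

stru = {                                   # invented example of one community
    "title": "Cluster 3: Payment processors",
    "weight": 12.0,
    "entities": ["Stripe", "Adyen"],
    "findings": [{"summary": "Shared contracts",
                  "explanation": "Both firms appear in the same supplier contracts."}],
}
report = "Stripe and Adyen form a tightly connected payments community."

obj = {"report": report,
       "evidences": "\n".join(f["explanation"] for f in stru["findings"])}
chunk = {
    "docnm_kwd": stru["title"],
    "content_with_weight": json.dumps(obj, ensure_ascii=False),
    "knowledge_graph_kwd": "community_report",
    "weight_flt": stru["weight"],
    "entities_kwd": stru["entities"],
}

# On the retrieval side the report and its evidence come back with one json.loads:
decoded = json.loads(chunk["content_with_weight"])
print(decoded["report"])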

2 changes: 1 addition & 1 deletion graphrag/light/graph_extractor.py
@@ -37,7 +37,7 @@ def __init__(
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
example_number: int = 3,
example_number: int = 2,
max_gleanings: int | None = None,
):
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)