Light GraphRAG #4585

Merged: 17 commits, merged on Jan 22, 2025
Refactor knowledge graph.
KevinHuSh committed Jan 17, 2025
commit 9d63c4a92c03ceb6b9f0cd8180a782dddf7883b9
2 changes: 1 addition & 1 deletion api/apps/chunk_app.py
@@ -155,7 +155,7 @@ def set():
                 r"[\n\t]",
                 req["content_with_weight"]) if len(t) > 1]
         q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-        d = beAdoc(d, arr[0], arr[1], not any(
+        d = beAdoc(d, q, a, not any(
             [rag_tokenizer.is_chinese(t) for t in q + a]))

     v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
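The one-line fix above matters for multi-line Q&A chunks: the old call passed the raw `arr[0]` and `arr[1]`, i.e. an un-stripped question and only the first answer line, while `q` and `a` are the prefix-stripped question and the full joined answer. A minimal sketch of the intended splitting, using a simplified stand-in for the real `rmPrefix` helper (the prefix regex below is illustrative, not RAGFlow's actual pattern):

```python
import re

def rmPrefix(txt):
    # Simplified stand-in: strip a leading "Question:"/"Answer:" style
    # marker (or a Chinese equivalent) from one side of a Q&A chunk.
    return re.sub(r"^(question|answer|问题|答案)[:：\s]+", "", txt.strip(), flags=re.IGNORECASE)

content = "Question: What is GraphRAG?\nAnswer: A graph-augmented\nretrieval pipeline."
arr = [t for t in re.split(r"[\n\t]", content) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
print(q)  # What is GraphRAG?
print(a)  # A graph-augmented\nretrieval pipeline.  -- both answer lines survive
```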
10 changes: 8 additions & 2 deletions graphrag/general/community_reports_extractor.py
@@ -55,6 +55,11 @@ def __init__(
         self._max_report_length = max_report_length or 1500

     def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+        for n in graph.nodes:
+            if graph.nodes[n].get("weight"):
+                continue
+            graph.nodes[n]["weight"] = 1
+
         communities: dict[str, dict[str, list]] = leiden.run(graph, {})
         total = sum([len(comm.items()) for _, comm in communities.items()])
         res_str = []
@@ -64,10 +69,11 @@ def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         for level, comm in communities.items():
             for cm_id, ents in comm.items():
                 weight = ents["weight"]
-                ent_df = pd.DataFrame(self._get_entity_(ents["nodes"]))  # [{"entity": n, **graph.nodes[n]} for n in ents]
+                ents = ents["nodes"]
+                ent_df = pd.DataFrame(self._get_entity_(ents))  # [{"entity": n, **graph.nodes[n]} for n in ents]
                 ent_df["entity"] = ent_df["entity_name"]
                 del ent_df["entity_name"]
-                rela_df = pd.DataFrame(self._get_relation_(ents, ents, 10000))
+                rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
                 rela_df["source"] = rela_df["src_id"]
                 rela_df["target"] = rela_df["tgt_id"]
                 del rela_df["src_id"]
94 changes: 53 additions & 41 deletions graphrag/general/index.py
@@ -28,6 +28,7 @@
 from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
     chunk_id
 from rag.nlp import rag_tokenizer, search
+from rag.utils.redis_conn import RedisDistributedLock


class Dealer:
@@ -62,12 +63,13 @@ def __init__(self,
                 #description=rel["description"]
             )

-        old_graph = get_graph(tenant_id, kb_id)
-        if old_graph is not None:
-            logging.info("Merge with an exiting graph...................")
-            self.graph = reduce(graph_merge, [old_graph, self.graph])
+        with RedisDistributedLock(kb_id, 60*60):
+            old_graph = get_graph(tenant_id, kb_id)
+            if old_graph is not None:
+                logging.info("Merge with an exiting graph...................")
+                self.graph = reduce(graph_merge, [old_graph, self.graph])

-        set_graph(self.graph)
+            set_graph(tenant_id, kb_id, self.graph)


class WithResolution(Dealer):
@@ -78,31 +80,35 @@ def __init__(self,
         _, tenant = TenantService.get_by_id(tenant_id)
         self.llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
-        self.graph = get_graph(tenant_id, kb_id)
-        if not self.graph:
-            if callback:
-                callback(-1, msg="Faild to fetch the graph.")
-            return
-        if callback:
-            callback(msg="Fetch the existing graph.")
-        er = EntityResolution(self.llm_bdl,
-                              get_entity=partial(get_entity, tenant_id, kb_id),
-                              set_entity=partial(set_entity, tenant_id, kb_id),
-                              get_relation=partial(get_relation, tenant_id, kb_id),
-                              set_relation=partial(set_relation, tenant_id, kb_id))
-        reso = er(self.graph)
-        self.graph = reso.graph
-        logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-        if callback:
-            callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-        set_graph(tenant_id, kb_id, self.graph)
+
+        with RedisDistributedLock(kb_id, 60*60):
+            self.graph = get_graph(tenant_id, kb_id)
+            if not self.graph:
+                logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
+                if callback:
+                    callback(-1, msg="Faild to fetch the graph.")
+                return
+
+            if callback:
+                callback(msg="Fetch the existing graph.")
+            er = EntityResolution(self.llm_bdl,
+                                  get_entity=partial(get_entity, tenant_id, kb_id),
+                                  set_entity=partial(set_entity, tenant_id, kb_id),
+                                  get_relation=partial(get_relation, tenant_id, kb_id),
+                                  set_relation=partial(set_relation, tenant_id, kb_id))
+            reso = er(self.graph)
+            self.graph = reso.graph
+            logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
+            if callback:
+                callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
+            set_graph(tenant_id, kb_id, self.graph)
-        settings.retrievaler.delete({
+        settings.docStoreConn.delete({
             "knowledge_graph_kwd": "relation",
             "kb_id": kb_id,
             "from_entity_kwd": reso.removed_entities
         }, search.index_name(tenant_id), kb_id)
-        settings.retrievaler.delete({
+        settings.docStoreConn.delete({
             "knowledge_graph_kwd": "relation",
             "kb_id": kb_id,
             "to_entity_kwd": reso.removed_entities
@@ -115,29 +121,36 @@ def __init__(self,
                  kb_id: str,
                  callback=None
                  ):
+
         self.community_structure = None
         self.community_reports = None
         _, tenant = TenantService.get_by_id(tenant_id)
         self.llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
-        self.graph = get_graph(tenant_id, kb_id)
-        if not self.graph:
-            if callback:
-                callback(-1, msg="Faild to fetch the graph.")
-            return
-        if callback:
-            callback(msg="Fetch the existing graph.")
-
-        cr = CommunityReportsExtractor(self.llm_bdl,
-                                       get_entity=partial(get_entity, tenant_id, kb_id),
-                                       set_entity=partial(set_entity, tenant_id, kb_id),
-                                       get_relation=partial(get_relation, tenant_id, kb_id),
-                                       set_relation=partial(set_relation, tenant_id, kb_id))
-        cr = cr(self.graph, callback=callback)
-        self.community_structure = cr.structured_output
-        self.community_reports = cr.output
+        with RedisDistributedLock(kb_id, 60*60):
+            self.graph = get_graph(tenant_id, kb_id)
+            if not self.graph:
+                logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
+                if callback:
+                    callback(-1, msg="Faild to fetch the graph.")
+                return
+            if callback:
+                callback(msg="Fetch the existing graph.")
+
+            cr = CommunityReportsExtractor(self.llm_bdl,
+                                           get_entity=partial(get_entity, tenant_id, kb_id),
+                                           set_entity=partial(set_entity, tenant_id, kb_id),
+                                           get_relation=partial(get_relation, tenant_id, kb_id),
+                                           set_relation=partial(set_relation, tenant_id, kb_id))
+            cr = cr(self.graph, callback=callback)
+            self.community_structure = cr.structured_output
+            self.community_reports = cr.output
+            set_graph(tenant_id, kb_id, self.graph)

         if callback:
             callback(msg="Graph community extraction is done. Indexing {} reports.".format(cr.structured_output))

-        settings.retrievaler.delete({
+        settings.docStoreConn.delete({
             "knowledge_graph_kwd": "community_report",
             "kb_id": kb_id
         }, search.index_name(tenant_id), kb_id)
@@ -156,4 +169,3 @@ def __init__(self,
             chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
             settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))

-        set_graph(tenant_id, kb_id, self.graph)
2 changes: 2 additions & 0 deletions graphrag/general/leiden.py
@@ -129,6 +129,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
         if not weights:
             continue
         max_weight = max(weights)
+        if max_weight == 0:
+            continue
         for _, comm in result.items():
             comm["weight"] /= max_weight
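The two added lines guard against a level whose communities all carry zero accumulated weight; dividing by `max_weight == 0` previously raised a `ZeroDivisionError`. A self-contained sketch of the normalization with the guard (the `levels` dict is fabricated test data):

```python
# Fabricated community weights for two levels; level 1 is all-zero.
levels = {
    0: {"c0": {"weight": 4.0}, "c1": {"weight": 2.0}},
    1: {"c0": {"weight": 0.0}, "c1": {"weight": 0.0}},
}
for level, result in levels.items():
    weights = [comm["weight"] for comm in result.values()]
    if not weights:
        continue
    max_weight = max(weights)
    if max_weight == 0:  # the new guard: skip instead of dividing by zero
        continue
    for comm in result.values():
        comm["weight"] /= max_weight

print(levels[0])  # {'c0': {'weight': 1.0}, 'c1': {'weight': 0.5}}
print(levels[1])  # untouched: {'c0': {'weight': 0.0}, 'c1': {'weight': 0.0}}
```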
65 changes: 0 additions & 65 deletions graphrag/light/index.py

This file was deleted.

6 changes: 3 additions & 3 deletions graphrag/utils.py
@@ -159,7 +159,7 @@ def graph_merge(g1, g2):

     for source, target, attr in g1.edges(data=True):
         if g.has_edge(source, target):
-            g[source][target].update({"weight": attr["weight"]+1})
+            g[source][target].update({"weight": attr.get("weight", 0)+1})
             continue
         g.add_edge(source, target)#, **attr)

@@ -341,14 +341,14 @@ def get_graph(tenant_id, kb_id):
     res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
     for id in res.ids:
         try:
-            return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]))
+            return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
         except Exception:
             continue


 def set_graph(tenant_id, kb_id, graph):
     chunk = {
-        "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False,
+        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                           indent=2),
         "knowledge_graph_kwd": "graph",
         "kb_id": kb_id
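Both the reader and the writer now pass `edges="edges"`, so the JSON payload stores its edge list under an `"edges"` key rather than networkx's default `"links"`, and the two sides stay consistent. A round-trip sketch, assuming a networkx release new enough to accept the `edges=` keyword on `node_link_data`/`node_link_graph`:

```python
import json

import networkx as nx
from networkx.readwrite import json_graph

g = nx.Graph()
g.add_edge("A", "B", weight=2)

# Serialize with the edge list under "edges", as set_graph now does.
payload = json.dumps(json_graph.node_link_data(g, edges="edges"),
                     ensure_ascii=False, indent=2)

# Deserializing must name the same key, or every edge is silently lost.
g2 = json_graph.node_link_graph(json.loads(payload), edges="edges")
assert g2.has_edge("A", "B") and g2["A"]["B"]["weight"] == 2
```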
19 changes: 4 additions & 15 deletions rag/svr/task_executor.py
@@ -446,20 +446,6 @@ def run_graphrag(row, callback=None):
     )


-def run_graphrag_community(row, callback=None):
-    chunks = []
-    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
-                                             fields=["content_with_weight"]):
-        chunks.append((d["id"], d["content_with_weight"]))
-
-    WithCommunity(
-        row["tenant_id"], str(row["kb_id"]), chunks,
-        row["parser_config"]["graphrag"]["language"],
-        row["parser_config"]["graphrag"]["entity_types"],
-        callback
-    )
-
-
 def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
@@ -546,7 +532,10 @@ def do_handle_task(task):
     elif task.get("task_type", "") == "graph_community":
         start_ts = timer()
         try:
-            run_graphrag_community(task, progress_callback)
+            WithCommunity(
+                task["tenant_id"], str(task["kb_id"]),
+                progress_callback
+            )
             progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(timer() - start_ts))
         except TaskCanceledException:
             raise
5 changes: 2 additions & 3 deletions rag/utils/es_conn.py
@@ -350,7 +350,7 @@ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId
                 if (not isinstance(k, str) or not v) and k != "available_int":
                     continue
                 if isinstance(v, str):
-                    v = re.sub(r"[']", " ", v)
+                    v = re.sub(r"(['\n\r]|\\n)", " ", v)
                     scripts.append(f"ctx._source.{k}='{v}';")
                 elif isinstance(v, int):
                     scripts.append(f"ctx._source.{k}={v};")
@@ -362,7 +362,6 @@ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId
             ubq = UpdateByQuery(
                 index=indexName).using(
                 self.es).query(bqry)
-            print("\n".join(scripts), "\n==============================\n")
             ubq = ubq.script(source="".join(scripts), params=params)
             ubq = ubq.params(refresh=True)
             ubq = ubq.params(slices=5)
@@ -373,7 +372,7 @@ def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId
                 _ = ubq.execute()
                 return True
             except Exception as e:
-                logger.error("ESConnection.update got exception: " + str(e))
+                logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
                 if re.search(r"(timeout|connection|conflict)", str(e).lower()):
                     continue
                 break
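For context on the widened regex: `update` inlines each string value into a Painless `ctx._source.<field>='<value>';` statement, so a single quote, a raw newline or carriage return, or a literal `\n` sequence in the value would terminate or corrupt the generated script. A sketch of the sanitizing step in isolation:

```python
import re

def to_script_line(k: str, v: str) -> str:
    # Mirrors the widened sanitizer: quotes, raw newlines/CRs, and
    # literal "\n" sequences would break the generated single-quoted
    # Painless statement, so replace them all with spaces.
    v = re.sub(r"(['\n\r]|\\n)", " ", v)
    return f"ctx._source.{k}='{v}';"

print(to_script_line("title", "it's a\nmulti-line\\nvalue"))
# ctx._source.title='it s a multi-line value';
```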
27 changes: 27 additions & 0 deletions rag/utils/redis_conn.py
@@ -1,5 +1,7 @@
 import logging
 import json
+import time
+import uuid

 import valkey as redis
 from rag import settings
@@ -262,3 +264,28 @@ def queue_info(self, queue, group_name) -> dict | None:


 REDIS_CONN = RedisDB()
+
+
+class RedisDistributedLock:
+    def __init__(self, lock_key, timeout=10):
+        self.lock_key = lock_key
+        # A unique value so only the owner can release the lock.
+        self.lock_value = str(uuid.uuid4())
+        self.timeout = timeout
+
+    def acquire_lock(self):
+        # Poll SETNX until the key is free or the timeout elapses.
+        end_time = time.time() + self.timeout
+        while time.time() < end_time:
+            if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
+                return True
+            time.sleep(0.1)
+        return False
+
+    def release_lock(self):
+        # Delete the key only if this instance still owns it.
+        if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
+            REDIS_CONN.REDIS.delete(self.lock_key)
+
+    def __enter__(self):
+        self.acquire_lock()
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        self.release_lock()
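A usage sketch mirroring how graphrag/general/index.py wraps its read-merge-write cycle in this lock; `get_graph`, `graph_merge`, and `set_graph` are the helpers from `graphrag.utils`, and `tenant_id`/`kb_id`/`graph` are assumed to be in scope:

```python
from functools import reduce

# Serialize concurrent graph updates per knowledge base: whichever
# worker holds the kb-level lock reads, merges, and writes the graph.
with RedisDistributedLock(kb_id, timeout=60 * 60):
    old_graph = get_graph(tenant_id, kb_id)        # read under the lock
    if old_graph is not None:
        graph = reduce(graph_merge, [old_graph, graph])
    set_graph(tenant_id, kb_id, graph)             # write before releasing
```

Note that `__enter__` discards the boolean from `acquire_lock`, so if the lock is not acquired within `timeout` the block still runs unguarded; a caller that must not proceed without the lock should call `acquire_lock()` itself and check the result.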