Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260308082306195595.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add Top-K and min co-occurrence filters to NLP edge extraction to prevent O(N^2) relationship explosion"
}
2 changes: 2 additions & 0 deletions packages/graphrag/graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ class ExtractGraphNLPDefaults:
text_analyzer: TextAnalyzerDefaults = field(default_factory=TextAnalyzerDefaults)
concurrent_requests: int = 25
async_mode: AsyncType = AsyncType.Threaded
max_entities_per_chunk: int = 0
min_co_occurrence: int = 1


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,16 @@ class ExtractGraphNLPConfig(BaseModel):
description="The async mode to use.",
default=graphrag_config_defaults.extract_graph_nlp.async_mode,
)
max_entities_per_chunk: int = Field(
description="Maximum number of noun-phrase entities to retain per text chunk "
"when building co-occurrence edges. Entities are ranked by global frequency "
"and only the top-K are paired, reducing edges from O(N^2) to O(K^2). "
"Set to 0 to disable (keep all entities).",
default=graphrag_config_defaults.extract_graph_nlp.max_entities_per_chunk,
)
min_co_occurrence: int = Field(
description="Minimum number of text units in which an edge must co-occur "
"to be retained. Edges appearing in fewer text units are discarded as "
"likely coincidental. Set to 1 to disable filtering.",
default=graphrag_config_defaults.extract_graph_nlp.min_co_occurrence,
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ async def build_noun_graph(
text_analyzer: BaseNounPhraseExtractor,
normalize_edge_weights: bool,
cache: Cache,
max_entities_per_chunk: int = 0,
min_co_occurrence: int = 1,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Build a noun graph from text units."""
title_to_ids = await _extract_nodes(
Expand All @@ -49,6 +51,8 @@ async def build_noun_graph(
title_to_ids,
nodes_df=nodes_df,
normalize_edge_weights=normalize_edge_weights,
max_entities_per_chunk=max_entities_per_chunk,
min_co_occurrence=min_co_occurrence,
)
return (nodes_df, edges_df)

Expand Down Expand Up @@ -100,27 +104,52 @@ def _extract_edges(
title_to_ids: dict[str, list[str]],
nodes_df: pd.DataFrame,
normalize_edge_weights: bool = True,
max_entities_per_chunk: int = 0,
min_co_occurrence: int = 1,
) -> pd.DataFrame:
"""Build co-occurrence edges between noun phrases.

Nodes that appear in the same text unit are connected.

Two optional filters reduce O(N^2) edge explosion in
entity-dense corpora (e.g. scientific/technical text):

* ``max_entities_per_chunk`` – When > 0, only the K most
globally-frequent entities per text unit are paired,
capping edges at C(K,2) instead of C(N,2).
* ``min_co_occurrence`` – When > 1, edges that appear in
fewer than this many text units are discarded, removing
coincidental co-occurrences.

Returns edges with schema [source, target, weight, text_unit_ids].
"""
if not title_to_ids:
return pd.DataFrame(
columns=["source", "target", "weight", "text_unit_ids"],
)

entity_freq: dict[str, int] = {
t: len(ids) for t, ids in title_to_ids.items()
}

text_unit_to_titles: dict[str, list[str]] = defaultdict(list)
for title, tu_ids in title_to_ids.items():
for tu_id in tu_ids:
text_unit_to_titles[tu_id].append(title)

edge_map: dict[tuple[str, str], list[str]] = defaultdict(list)
for tu_id, titles in text_unit_to_titles.items():
if len(titles) < 2:
unique_titles = sorted(set(titles))
if len(unique_titles) < 2:
continue
for pair in combinations(sorted(set(titles)), 2):
if max_entities_per_chunk > 0 and len(unique_titles) > max_entities_per_chunk:
unique_titles = sorted(
unique_titles,
key=lambda t: entity_freq.get(t, 0),
reverse=True,
)[:max_entities_per_chunk]
unique_titles.sort()
for pair in combinations(unique_titles, 2):
edge_map[pair].append(tu_id)

records = [
Expand All @@ -131,7 +160,17 @@ def _extract_edges(
"text_unit_ids": tu_ids,
}
for (src, tgt), tu_ids in edge_map.items()
if len(tu_ids) >= min_co_occurrence
]

if len(records) < len(edge_map):
logger.info(
"Edge co-occurrence filter: %d -> %d edges (min_co_occurrence=%d)",
len(edge_map),
len(records),
min_co_occurrence,
)

edges_df = pd.DataFrame(
records,
columns=["source", "target", "weight", "text_unit_ids"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ async def run_workflow(
relationships_table=relationships_table,
text_analyzer=text_analyzer,
normalize_edge_weights=(config.extract_graph_nlp.normalize_edge_weights),
max_entities_per_chunk=(
config.extract_graph_nlp.max_entities_per_chunk
),
min_co_occurrence=config.extract_graph_nlp.min_co_occurrence,
)

logger.info("Workflow completed: extract_graph_nlp")
Expand All @@ -65,13 +69,17 @@ async def extract_graph_nlp(
relationships_table: Table,
text_analyzer: BaseNounPhraseExtractor,
normalize_edge_weights: bool,
max_entities_per_chunk: int = 0,
min_co_occurrence: int = 1,
) -> dict[str, list[dict[str, Any]]]:
"""Extract noun-phrase graph and stream results to output tables."""
extracted_nodes, extracted_edges = await build_noun_graph(
text_units_table,
text_analyzer=text_analyzer,
normalize_edge_weights=normalize_edge_weights,
cache=cache,
max_entities_per_chunk=max_entities_per_chunk,
min_co_occurrence=min_co_occurrence,
)

if len(extracted_nodes) == 0:
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ def assert_extract_graph_nlp_configs(
assert actual.normalize_edge_weights == expected.normalize_edge_weights
assert_text_analyzer_configs(actual.text_analyzer, expected.text_analyzer)
assert actual.concurrent_requests == expected.concurrent_requests
assert actual.max_entities_per_chunk == expected.max_entities_per_chunk
assert actual.min_co_occurrence == expected.min_co_occurrence


def assert_prune_graph_configs(
Expand Down
Loading
Loading