generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 76
Open
Description
Summary
This issue documents the performance impact of #101 switching shortest-path traversal to BFS for unit-weight edges in BYOKG-RAG.
On the demo graph (freebase_tiny_kg.csv), the change reduces mean latency and improves throughput by ~46% with identical retrieval output.
Benchmark setup
- Component: byokg_rag.PathRetriever.retrieve
- Graph: byokg-rag/data/freebase_tiny_kg.csv
- Workload: 100 random (source, answer) node pairs
- Iterations: 200 (after warmup)
Results
| Branch | mean (ms) | p50 (ms) | p95 (ms) | max (ms) | throughput (ops/s) | avg_context_items |
|---|---|---|---|---|---|---|
| main | 17.748 | 10.830 | 59.905 | 72.548 | 56.34 | 0.74 |
| PR | 12.134 | 7.021 | 45.945 | 62.159 | 82.41 | 0.74 |
- Mean latency ↓ 31.6%
- p50 ↓ 35.1%
- p95 ↓ 23.3%
- Throughput ↑ 46.3%
- Retrieved context unchanged (`avg_context_items` identical)
These results are consistent with expectations: for unit-weight edges, BFS avoids the priority-queue overhead of Dijkstra while preserving shortest-path correctness.
Reproduction
$ git checkout <PR-branch>
$ python byokg-rag/src/graphrag_toolkit/byokg_rag/bench_byokg.py \
--csv examples/byokg-rag/data/freebase_tiny_kg.csv \
--iters 200 --pairs 100
$ git checkout main
$ python byokg-rag/src/graphrag_toolkit/byokg_rag/bench_byokg.py \
--csv examples/byokg-rag/data/freebase_tiny_kg.csv \
--iters 200 --pairs 100
Benchmark script
byokg-rag/src/graphrag_toolkit/byokg_rag/bench_byokg.py
from __future__ import annotations
import argparse
import random
import statistics
import time
from dataclasses import dataclass
from typing import List, Sequence, Tuple
from graphrag_toolkit.byokg_rag.graphstore import LocalKGStore
from graphrag_toolkit.byokg_rag.graph_retrievers import (
GTraversal,
PathRetriever,
PathVerbalizer,
TripletGVerbalizer,
)
@dataclass
class Stats:
    """Latency summary for one benchmarked retriever call pattern.

    All latency fields are in milliseconds; throughput is in ops/second.
    Produced by compute_stats() from a list of per-call durations.
    """

    # Number of measured iterations (excludes warmup).
    n: int
    # Arithmetic mean latency.
    mean_ms: float
    # Median (50th percentile) latency.
    p50_ms: float
    # 95th percentile latency.
    p95_ms: float
    # Fastest observed call.
    min_ms: float
    # Slowest observed call.
    max_ms: float
    # Throughput: iterations divided by total wall time.
    qps: float
def percentile(values: Sequence[float], p: float) -> float:
    """Return the p-quantile (p in [0, 1]) of *values* with linear interpolation.

    Matches the "linear" interpolation method: the rank p*(n-1) is split into
    an integer part and a fractional part, and the two neighboring order
    statistics are blended. Returns NaN for an empty input.
    """
    if not values:
        return float("nan")
    ordered = sorted(values)
    rank = p * (len(ordered) - 1)
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    if lo == hi:
        return ordered[lo]
    frac = rank - lo
    # Linear blend between the two bracketing order statistics.
    return ordered[lo] + (ordered[hi] - ordered[lo]) * frac
def timed(fn, *args, **kwargs) -> Tuple[float, object]:
    """Invoke fn(*args, **kwargs) and return (elapsed_seconds, result).

    Uses time.perf_counter(), the highest-resolution monotonic clock,
    so the measurement is unaffected by system clock adjustments.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return elapsed, result
def compute_stats(durations_s: List[float]) -> Stats:
    """Summarize per-call durations (in seconds) into a Stats record.

    Latencies are reported in milliseconds; throughput (qps) is the number
    of calls divided by total wall time, or +inf when total time is zero.
    """
    samples_ms = [1000.0 * d for d in durations_s]
    elapsed_s = sum(durations_s)
    throughput = (len(samples_ms) / elapsed_s) if elapsed_s > 0 else float("inf")
    return Stats(
        n=len(samples_ms),
        mean_ms=statistics.mean(samples_ms),
        p50_ms=percentile(samples_ms, 0.50),
        p95_ms=percentile(samples_ms, 0.95),
        min_ms=min(samples_ms),
        max_ms=max(samples_ms),
        qps=throughput,
    )
def main() -> None:
    """Benchmark BYOKG-RAG retrievers against a CSV-backed knowledge graph.

    Loads the graph, samples random (source, answer) node pairs, then times
    PathRetriever.retrieve (and optionally GraphScoringRetriever.retrieve)
    over warmup + measured iterations, printing latency/throughput stats.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="Path to KG CSV (e.g., data/freebase_tiny_kg.csv)")
    ap.add_argument("--pairs", type=int, default=200, help="How many (entity, candidate-answer) pairs to sample")
    ap.add_argument("--iters", type=int, default=200, help="How many benchmark iterations to run")
    ap.add_argument("--warmup", type=int, default=20, help="Warmup iterations (excluded from stats)")
    ap.add_argument("--seed", type=int, default=7)
    # Metapaths: use simple patterns to exercise shortest-path / BFS logic.
    # You can add more patterns that match your schema.
    ap.add_argument(
        "--metapaths",
        nargs="*",
        default=["*"],
        help="Metapaths like: Person->bornIn->City City->locatedIn->Country (space-separated tokens per path not supported here)",
    )
    ap.add_argument("--scoring", action="store_true", help="Also benchmark GraphScoringRetriever (requires reranker model)")
    ap.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Device for reranker if --scoring")
    args = ap.parse_args()
    # Seed so the sampled workload is reproducible across branches/runs.
    random.seed(args.seed)
    # -------- Load graph --------
    graph_store = LocalKGStore()
    graph_store.read_from_csv(args.csv)
    nodes = list(graph_store.nodes())
    if len(nodes) < 10:
        raise RuntimeError(f"Graph has too few nodes: {len(nodes)}")
    # -------- Build traversal + retrievers --------
    traversal = GTraversal(graph_store)
    path_retriever = PathRetriever(
        graph_traversal=traversal,
        path_verbalizer=PathVerbalizer(),
    )
    scoring_retriever = None
    if args.scoring:
        # Imported lazily: pulls in a reranker model and is only needed
        # when --scoring is requested.
        from graphrag_toolkit.byokg_rag.graph_retrievers import GraphScoringRetriever, LocalGReranker
        reranker = LocalGReranker(
            model_name="BAAI/bge-reranker-v2-m3",
            topk=10,
            device=args.device,
        )
        scoring_retriever = GraphScoringRetriever(
            graph_traversal=traversal,
            graph_verbalizer=TripletGVerbalizer(),
            graph_reranker=reranker,
        )
    # "*" or any token without "->" is passed through as a one-element path;
    # otherwise split "A->b->C" into its hop tokens.
    tokenized_metapaths: List[List[str]] = []
    for p in args.metapaths:
        if p == "*" or "->" not in p:
            tokenized_metapaths.append([p])
        else:
            tokenized_metapaths.append([t.strip() for t in p.split("->") if t.strip()])
    # Sample (source, answer) pairs; self-pairs are skipped, so the final
    # workload may contain slightly fewer than --pairs entries.
    workload: List[Tuple[List[str], List[str]]] = []
    for _ in range(args.pairs):
        src = random.choice(nodes)
        ans = random.choice(nodes)
        if src == ans:
            continue
        workload.append(([src], [ans]))
    if not workload:
        raise RuntimeError("Failed to build workload pairs.")
    # -------- Benchmark PathRetriever --------
    # Warmup iterations populate caches / JIT-ish effects; excluded from stats.
    for i in range(args.warmup):
        srcs, answers = workload[i % len(workload)]
        path_retriever.retrieve(srcs, tokenized_metapaths, answers)
    durations = []
    total_ctx = 0
    for i in range(args.iters):
        srcs, answers = workload[i % len(workload)]
        dt, ctx = timed(path_retriever.retrieve, srcs, tokenized_metapaths, answers)
        durations.append(dt)
        # Best-effort: ctx is assumed sized (len()-able); if not, skip the
        # context count rather than aborting the benchmark. TODO confirm
        # the retrieve() return type against the library.
        try:
            total_ctx += len(ctx)
        except Exception:
            pass
    s = compute_stats(durations)
    print("\n=== PathRetriever.retrieve ===")
    print(f"n={s.n} mean={s.mean_ms:.3f} ms p50={s.p50_ms:.3f} ms p95={s.p95_ms:.3f} ms")
    print(f"min={s.min_ms:.3f} ms max={s.max_ms:.3f} ms throughput={s.qps:.2f} ops/s")
    print(f"avg_context_items={(total_ctx / s.n):.2f}")
    # -------- Optionally benchmark GraphScoringRetriever --------
    if scoring_retriever is not None:
        for i in range(args.warmup):
            srcs, _answers = workload[i % len(workload)]
            scoring_retriever.retrieve("benchmark query", srcs)
        durations2 = []
        total_ctx2 = 0
        for i in range(args.iters):
            srcs, _answers = workload[i % len(workload)]
            dt, ctx = timed(scoring_retriever.retrieve, "benchmark query", srcs)
            durations2.append(dt)
            total_ctx2 += len(ctx)
        s2 = compute_stats(durations2)
        print("\n=== GraphScoringRetriever.retrieve ===")
        print(f"n={s2.n} mean={s2.mean_ms:.3f} ms p50={s2.p50_ms:.3f} ms p95={s2.p95_ms:.3f} ms")
        print(f"min={s2.min_ms:.3f} ms max={s2.max_ms:.3f} ms throughput={s2.qps:.2f} ops/s")
        print(f"avg_context_items={(total_ctx2 / s2.n):.2f}")
# Script entry point. (Fixes scrape artifact: page chrome text was fused
# onto the main() call line, which would be a SyntaxError.)
if __name__ == "__main__":
    main()
Metadata
Metadata
Assignees
Labels
No labels