Skip to content

Performance improvement: BFS-based shortest path in BYOKG PathRetriever #110

@hongjun7

Description

@hongjun7

Summary

This issue documents the performance impact of #101 switching shortest-path traversal to BFS for unit-weight edges in BYOKG-RAG.
On the demo graph (freebase_tiny_kg.csv), the change reduces mean latency by ~32% and improves throughput by ~46%, with identical retrieval output.

Benchmark setup

  • Component: byokg_rag.PathRetriever.retrieve
  • Graph: examples/byokg-rag/data/freebase_tiny_kg.csv (same path used in the reproduction commands below)
  • Workload: 100 random (source, answer) node pairs
  • Iterations: 200 (after warmup)

Results

Branch mean (ms) p50 (ms) p95 (ms) max (ms) throughput (ops/s) avg_context_items
main 17.748 10.830 59.905 72.548 56.34 0.74
PR 12.134 7.021 45.945 62.159 82.41 0.74
  • Mean latency ↓ 31.6%
  • p50 ↓ 35.1%
  • p95 ↓ 23.3%
  • Throughput ↑ 46.3%
  • Retrieved context unchanged (avg_context_items identical)

These results are consistent with expectations: for unit-weight edges, BFS avoids the priority-queue overhead of Dijkstra while preserving shortest-path correctness.

Reproduction

$ git checkout <PR-branch>
$ python byokg-rag/src/graphrag_toolkit/byokg_rag/bench_byokg.py \
    --csv examples/byokg-rag/data/freebase_tiny_kg.csv \
    --iters 200 --pairs 100

$ git checkout main
$ python byokg-rag/src/graphrag_toolkit/byokg_rag/bench_byokg.py \
    --csv examples/byokg-rag/data/freebase_tiny_kg.csv \
    --iters 200 --pairs 100

Benchmark script

byokg-rag/src/graphrag_toolkit/byokg_rag/bench_byokg.py
from __future__ import annotations

import argparse
import random
import statistics
import time
from dataclasses import dataclass
from typing import List, Sequence, Tuple

from graphrag_toolkit.byokg_rag.graphstore import LocalKGStore
from graphrag_toolkit.byokg_rag.graph_retrievers import (
    GTraversal,
    PathRetriever,
    PathVerbalizer,
    TripletGVerbalizer,
)


@dataclass
class Stats:
    """Summary latency statistics for one benchmarked callable.

    All latency fields are in milliseconds; throughput is in operations
    per second. Produced by ``compute_stats`` from per-call durations.
    """

    n: int  # number of timed iterations aggregated
    mean_ms: float  # arithmetic mean latency
    p50_ms: float  # median latency (50th percentile)
    p95_ms: float  # tail latency (95th percentile)
    min_ms: float  # fastest observed call
    max_ms: float  # slowest observed call
    qps: float  # throughput in ops/s (n divided by total wall time)


def percentile(values: Sequence[float], p: float) -> float:
    """Return the *p*-th percentile of *values* (``p`` in ``[0, 1]``).

    Uses linear interpolation between the two nearest order statistics.
    An empty sequence yields NaN.
    """
    if not values:
        return float("nan")
    ordered = sorted(values)
    # Fractional rank into the sorted data: 0 -> first element, 1 -> last.
    rank = (len(ordered) - 1) * p
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    if lo == hi:
        return ordered[lo]
    # Interpolate between the bracketing samples by the fractional part.
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)


def timed(fn, *args, **kwargs) -> Tuple[float, object]:
    """Invoke ``fn(*args, **kwargs)`` and return ``(elapsed_seconds, result)``.

    Uses ``time.perf_counter`` so the measurement is monotonic and
    high-resolution.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return elapsed, result


def compute_stats(durations_s: List[float]) -> Stats:
    """Aggregate per-call durations (seconds) into a ``Stats`` record (ms).

    Args:
        durations_s: Wall-clock duration of each benchmarked call, in seconds.

    Returns:
        ``Stats`` with count, mean/p50/p95/min/max latency in milliseconds,
        and throughput in ops/s. An empty input yields ``n=0`` with NaN
        latency fields instead of raising (``statistics.mean`` and the
        ``min``/``max`` builtins reject empty data, which would crash a
        run invoked with ``--iters 0``).
    """
    if not durations_s:
        nan = float("nan")
        return Stats(n=0, mean_ms=nan, p50_ms=nan, p95_ms=nan,
                     min_ms=nan, max_ms=nan, qps=float("inf"))
    ms = [d * 1000.0 for d in durations_s]
    total_s = sum(durations_s)
    return Stats(
        n=len(ms),
        mean_ms=statistics.mean(ms),
        p50_ms=percentile(ms, 0.50),
        p95_ms=percentile(ms, 0.95),
        min_ms=min(ms),
        max_ms=max(ms),
        # Guard against a zero total (sub-resolution timings) to avoid ZeroDivisionError.
        qps=(len(ms) / total_s) if total_s > 0 else float("inf"),
    )


def main() -> None:
    """Benchmark BYOKG retrievers on a CSV-backed knowledge graph.

    Loads the graph given by ``--csv``, builds a ``PathRetriever`` (and
    optionally a ``GraphScoringRetriever`` when ``--scoring`` is passed),
    samples random (source, answer) node pairs, then times repeated
    ``retrieve`` calls and prints latency/throughput statistics.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="Path to KG CSV (e.g., data/freebase_tiny_kg.csv)")
    ap.add_argument("--pairs", type=int, default=200, help="How many (entity, candidate-answer) pairs to sample")
    ap.add_argument("--iters", type=int, default=200, help="How many benchmark iterations to run")
    ap.add_argument("--warmup", type=int, default=20, help="Warmup iterations (excluded from stats)")
    ap.add_argument("--seed", type=int, default=7)

    # Metapaths: use simple patterns to exercise shortest-path / BFS logic.
    # You can add more patterns that match your schema.
    ap.add_argument(
        "--metapaths",
        nargs="*",
        default=["*"],
        help="Metapaths like: Person->bornIn->City City->locatedIn->Country (space-separated tokens per path not supported here)",
    )

    ap.add_argument("--scoring", action="store_true", help="Also benchmark GraphScoringRetriever (requires reranker model)")
    ap.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Device for reranker if --scoring")

    args = ap.parse_args()
    # Seed once so the sampled workload is reproducible across branches/runs.
    random.seed(args.seed)

    # -------- Load graph --------
    graph_store = LocalKGStore()
    graph_store.read_from_csv(args.csv)
    nodes = list(graph_store.nodes())

    if len(nodes) < 10:
        raise RuntimeError(f"Graph has too few nodes: {len(nodes)}")

    # -------- Build traversal + retrievers --------
    traversal = GTraversal(graph_store)

    path_retriever = PathRetriever(
        graph_traversal=traversal,
        path_verbalizer=PathVerbalizer(),
    )

    scoring_retriever = None
    if args.scoring:
        # Imported lazily: the reranker pulls in a heavy model dependency
        # that is only needed when --scoring is requested.
        from graphrag_toolkit.byokg_rag.graph_retrievers import GraphScoringRetriever, LocalGReranker

        reranker = LocalGReranker(
            model_name="BAAI/bge-reranker-v2-m3",
            topk=10,
            device=args.device,
        )
        scoring_retriever = GraphScoringRetriever(
            graph_traversal=traversal,
            graph_verbalizer=TripletGVerbalizer(),
            graph_reranker=reranker,
        )

    # Tokenize metapath strings: "A->r->B" becomes ["A", "r", "B"];
    # "*" (or anything without "->") is passed through as a single token.
    tokenized_metapaths: List[List[str]] = []
    for p in args.metapaths:
        if p == "*" or "->" not in p:
            tokenized_metapaths.append([p])
        else:
            tokenized_metapaths.append([t.strip() for t in p.split("->") if t.strip()])

    # Sample (source, answer) pairs. Pairs where src == ans are dropped,
    # so the workload may end up slightly smaller than --pairs.
    workload: List[Tuple[List[str], List[str]]] = []
    for _ in range(args.pairs):
        src = random.choice(nodes)
        ans = random.choice(nodes)
        if src == ans:
            continue
        workload.append(([src], [ans]))
    if not workload:
        raise RuntimeError("Failed to build workload pairs.")

    # -------- Benchmark PathRetriever --------
    # Warmup iterations populate any lazy caches and are excluded from stats.
    for i in range(args.warmup):
        srcs, answers = workload[i % len(workload)]
        path_retriever.retrieve(srcs, tokenized_metapaths, answers)

    durations = []
    total_ctx = 0
    for i in range(args.iters):
        srcs, answers = workload[i % len(workload)]
        dt, ctx = timed(path_retriever.retrieve, srcs, tokenized_metapaths, answers)
        durations.append(dt)
        try:
            # Best-effort context count; retrieve() may return a non-sized
            # object, in which case the average simply under-counts.
            total_ctx += len(ctx)
        except Exception:
            pass

    s = compute_stats(durations)
    print("\n=== PathRetriever.retrieve ===")
    print(f"n={s.n}  mean={s.mean_ms:.3f} ms  p50={s.p50_ms:.3f} ms  p95={s.p95_ms:.3f} ms")
    print(f"min={s.min_ms:.3f} ms  max={s.max_ms:.3f} ms  throughput={s.qps:.2f} ops/s")
    print(f"avg_context_items={(total_ctx / s.n):.2f}")

    # -------- Optional: benchmark GraphScoringRetriever --------
    if scoring_retriever is not None:
        for i in range(args.warmup):
            srcs, _answers = workload[i % len(workload)]
            scoring_retriever.retrieve("benchmark query", srcs)

        durations2 = []
        total_ctx2 = 0
        for i in range(args.iters):
            srcs, _answers = workload[i % len(workload)]
            dt, ctx = timed(scoring_retriever.retrieve, "benchmark query", srcs)
            durations2.append(dt)
            total_ctx2 += len(ctx)

        s2 = compute_stats(durations2)
        print("\n=== GraphScoringRetriever.retrieve ===")
        print(f"n={s2.n}  mean={s2.mean_ms:.3f} ms  p50={s2.p50_ms:.3f} ms  p95={s2.p95_ms:.3f} ms")
        print(f"min={s2.min_ms:.3f} ms  max={s2.max_ms:.3f} ms  throughput={s2.qps:.2f} ops/s")
        print(f"avg_context_items={(total_ctx2 / s2.n):.2f}")


if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions