Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion rust/src/metrics/pilot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub enum InterventionPoint {
Backtrack,
/// Evaluating content sufficiency.
Evaluate,
/// Binary pruning for wide nodes.
Prune,
}

/// Helper to store f64 as u64 bits for atomic operations.
Expand Down Expand Up @@ -87,7 +89,7 @@ impl PilotMetrics {
InterventionPoint::Start => {
self.start_guidance_calls.fetch_add(1, Ordering::Relaxed);
}
InterventionPoint::Fork => {
InterventionPoint::Fork | InterventionPoint::Prune => {
self.fork_decisions.fetch_add(1, Ordering::Relaxed);
}
InterventionPoint::Backtrack => {
Expand Down
179 changes: 177 additions & 2 deletions rust/src/retrieval/pilot/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ pub struct PilotConfig {
pub guide_at_backtrack: bool,
/// Optional path to custom prompt templates.
pub prompt_template_path: Option<String>,
/// Pre-filtering configuration for reducing candidates before Pilot.
pub prefilter: PrefilterConfig,
/// Binary pruning configuration for quick relevance filtering.
pub prune: PruneConfig,
}

impl Default for PilotConfig {
Expand All @@ -38,6 +42,8 @@ impl Default for PilotConfig {
guide_at_start: true,
guide_at_backtrack: true,
prompt_template_path: None,
prefilter: PrefilterConfig::default(),
prune: PruneConfig::default(),
}
}
}
Expand All @@ -51,7 +57,7 @@ impl PilotConfig {
}
}

/// Create a high-quality config (more LLM calls).
/// Create a high-quality config (more LLM calls, generous pre-filter).
pub fn high_quality() -> Self {
Self {
mode: PilotMode::Aggressive,
Expand All @@ -71,10 +77,20 @@ impl PilotConfig {
guide_at_start: true,
guide_at_backtrack: true,
prompt_template_path: None,
prefilter: PrefilterConfig {
threshold: 20,
max_to_pilot: 20,
enabled: true,
},
prune: PruneConfig {
enabled: true,
threshold: 25,
min_keep: 5,
},
}
}

/// Create a low-cost config (fewer LLM calls).
/// Create a low-cost config (fewer LLM calls, aggressive pre-filter).
pub fn low_cost() -> Self {
Self {
mode: PilotMode::Conservative,
Expand All @@ -94,13 +110,33 @@ impl PilotConfig {
guide_at_start: false,
guide_at_backtrack: true,
prompt_template_path: None,
prefilter: PrefilterConfig {
threshold: 8,
max_to_pilot: 8,
enabled: true,
},
prune: PruneConfig {
enabled: true,
threshold: 12,
min_keep: 2,
},
}
}

/// Create a pure algorithm config (no LLM calls).
///
/// The mode is `PilotMode::AlgorithmOnly`, and both the pre-filter and the
/// binary-prune stages are disabled (their numeric limits are still filled
/// in). Every other field is taken from `Default::default()`.
pub fn algorithm_only() -> Self {
    // Pre-filter stage: switched off.
    let prefilter = PrefilterConfig {
        enabled: false,
        threshold: 15,
        max_to_pilot: 15,
    };

    // Binary-prune stage: switched off.
    let prune = PruneConfig {
        enabled: false,
        threshold: 20,
        min_keep: 3,
    };

    Self {
        mode: PilotMode::AlgorithmOnly,
        prefilter,
        prune,
        ..Default::default()
    }
}
Expand Down Expand Up @@ -228,6 +264,88 @@ impl InterventionConfig {
}
}

/// Configuration for NodeScorer-based pre-filtering before Pilot scoring.
///
/// When a node has many children, sending all of them to the LLM is
/// wasteful. Pre-filtering uses the cheap NodeScorer (keyword/BM25) to
/// narrow the candidate set before the expensive Pilot (LLM) scoring.
///
/// The decision is made by `should_prefilter`, which requires `enabled`
/// to be set and the candidate count to be strictly greater than
/// `threshold`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefilterConfig {
/// Candidate count above which pre-filtering is triggered.
///
/// The comparison is strict: pre-filtering happens only when
/// `candidates.len()` exceeds this value, so a list of exactly
/// `threshold` candidates is passed through unfiltered.
/// Default: 15.
pub threshold: usize,

/// Maximum number of candidates passed to Pilot after pre-filtering.
///
/// NodeScorer's top-N are kept; the rest get NodeScorer-only scores.
/// Default: 15.
pub max_to_pilot: usize,

/// Whether pre-filtering is enabled.
///
/// When `false`, `should_prefilter` returns `false` for every count.
/// Default: true.
pub enabled: bool,
}

impl Default for PrefilterConfig {
fn default() -> Self {
Self {
threshold: 15,
max_to_pilot: 15,
enabled: true,
}
}
}

impl PrefilterConfig {
    /// Decide whether a candidate list of the given size should be pre-filtered.
    ///
    /// Returns `true` only when pre-filtering is enabled and
    /// `candidate_count` is strictly greater than `threshold`.
    pub fn should_prefilter(&self, candidate_count: usize) -> bool {
        if !self.enabled {
            return false;
        }
        candidate_count > self.threshold
    }
}

/// Configuration for binary pruning before full Pilot scoring.
///
/// Binary pruning is evaluated after NodeScorer pre-filtering (the step
/// referred to as "P2", presumably the one configured by
/// `PrefilterConfig`). If the remaining candidates still exceed
/// `threshold`, a lightweight LLM call asks "which are relevant?" before
/// the full scoring call. This reduces the number of candidates that
/// receive expensive detailed scoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruneConfig {
/// Whether binary pruning is enabled.
///
/// When `false`, `should_prune` returns `false` for every count.
/// Default: true.
pub enabled: bool,

// NOTE(review): in every preset visible in this file the pre-filter caps
// the candidate count at or below its own threshold (default 15,
// high-quality 20, low-cost 8), while the prune threshold is higher
// (20, 25, 12). With those values the post-pre-filter count can never
// exceed `threshold`, so pruning would not trigger unless pre-filtering
// is disabled or the values are overridden — confirm this is intended.
/// Trigger threshold — binary prune activates when the candidate
/// count (after P2 pre-filtering) exceeds this value. The comparison
/// is strict: exactly `threshold` candidates does not trigger pruning.
/// Default: 20.
pub threshold: usize,

/// Minimum candidates to keep after pruning, even if LLM says
/// fewer are relevant. Prevents over-aggressive pruning.
/// Default: 3.
pub min_keep: usize,
}

impl Default for PruneConfig {
fn default() -> Self {
Self {
enabled: true,
threshold: 20,
min_keep: 3,
}
}
}

impl PruneConfig {
    /// Decide whether binary pruning should run for a candidate list of the
    /// given size.
    ///
    /// Returns `true` only when pruning is enabled and `candidate_count` is
    /// strictly greater than `threshold`.
    pub fn should_prune(&self, candidate_count: usize) -> bool {
        if !self.enabled {
            return false;
        }
        candidate_count > self.threshold
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -269,11 +387,68 @@ mod tests {
fn test_pilot_config_presets() {
    // High-quality preset: aggressive mode, pre-filter on with threshold 20.
    let hq = PilotConfig::high_quality();
    assert_eq!(hq.mode, PilotMode::Aggressive);
    assert_eq!(hq.prefilter.threshold, 20);
    assert!(hq.prefilter.enabled);

    // Low-cost preset: conservative mode, pre-filter on with threshold 8.
    let cheap = PilotConfig::low_cost();
    assert_eq!(cheap.mode, PilotMode::Conservative);
    assert_eq!(cheap.prefilter.threshold, 8);
    assert!(cheap.prefilter.enabled);

    // Algorithm-only preset: pre-filter is switched off.
    let algorithmic = PilotConfig::algorithm_only();
    assert_eq!(algorithmic.mode, PilotMode::AlgorithmOnly);
    assert!(!algorithmic.prefilter.enabled);
}

#[test]
fn test_prefilter_config_default() {
    // Defaults: enabled, trigger threshold and Pilot cap both 15.
    let defaults = PrefilterConfig::default();
    assert_eq!(defaults.threshold, 15);
    assert_eq!(defaults.max_to_pilot, 15);
    assert!(defaults.enabled);
}

#[test]
fn test_prefilter_should_prefilter() {
    let active = PrefilterConfig::default();

    // (candidate count, expected decision) around the default threshold of 15:
    // below, exactly at, and just above.
    for (count, expected) in [(10, false), (15, false), (16, true)] {
        assert_eq!(active.should_prefilter(count), expected);
    }

    // A disabled config never pre-filters, however many candidates there are.
    let switched_off = PrefilterConfig {
        enabled: false,
        ..Default::default()
    };
    assert!(!switched_off.should_prefilter(100));
}

#[test]
fn test_prune_config_default() {
    // Defaults: enabled, trigger threshold 20, keep at least 3.
    let defaults = PruneConfig::default();
    assert_eq!(defaults.threshold, 20);
    assert_eq!(defaults.min_keep, 3);
    assert!(defaults.enabled);
}

#[test]
fn test_prune_should_prune() {
    let active = PruneConfig::default();

    // (candidate count, expected decision) around the default threshold of 20:
    // below, exactly at, and just above.
    for (count, expected) in [(15, false), (20, false), (21, true)] {
        assert_eq!(active.should_prune(count), expected);
    }

    // A disabled config never prunes, however many candidates there are.
    let switched_off = PruneConfig {
        enabled: false,
        ..Default::default()
    };
    assert!(!switched_off.should_prune(100));
}

#[test]
fn test_pilot_config_presets_prune() {
    // High-quality preset: pruning on, threshold 25.
    let hq = PilotConfig::high_quality();
    assert_eq!(hq.prune.threshold, 25);
    assert!(hq.prune.enabled);

    // Low-cost preset: pruning on, threshold 12.
    let cheap = PilotConfig::low_cost();
    assert_eq!(cheap.prune.threshold, 12);
    assert!(cheap.prune.enabled);

    // Algorithm-only preset: pruning is switched off.
    let algorithmic = PilotConfig::algorithm_only();
    assert!(!algorithmic.prune.enabled);
}
}
3 changes: 3 additions & 0 deletions rust/src/retrieval/pilot/decision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ pub enum InterventionPoint {
Backtrack,
/// Evaluating a specific node for relevance.
Evaluate,
/// Binary pruning — quick yes/no relevance filter for wide nodes.
Prune,
}

impl InterventionPoint {
Expand All @@ -230,6 +232,7 @@ impl InterventionPoint {
Self::Fork => "fork",
Self::Backtrack => "backtrack",
Self::Evaluate => "evaluate",
Self::Prune => "prune",
}
}
}
Expand Down
82 changes: 79 additions & 3 deletions rust/src/retrieval/pilot/decision_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ pub struct ScoredCandidate {
/// from the Pilot. Use this when the search algorithm needs to
/// record why each path step was taken (e.g., for beam search
/// reasoning history).
///
/// # Pre-filtering
///
/// When a node has many children (exceeding `prefilter.threshold`),
/// NodeScorer pre-filters candidates before sending to Pilot. This
/// reduces LLM token cost and latency. Candidates filtered out still
/// receive NodeScorer-only scores in the final merge, so no results
/// are lost.
pub async fn score_candidates_detailed(
tree: &DocumentTree,
candidates: &[NodeId],
Expand Down Expand Up @@ -139,20 +147,88 @@ pub async fn score_candidates_detailed(
// Determine parent node (last in path) for cache key
let parent = path.last().copied().unwrap_or(tree.root());

// === PRE-FILTERING ===
// When candidates exceed the threshold, use NodeScorer to narrow
// the set before sending to Pilot (LLM). Filtered-out candidates
// still get NodeScorer-only scores in the final merge below.
let prefilter_cfg = &p.config().prefilter;
let pilot_candidates: Vec<NodeId> = if prefilter_cfg.should_prefilter(candidates.len()) {
let scorer = NodeScorer::new(ScoringContext::new(query));
let mut sorted = scorer.score_and_sort(tree, candidates);
let pilot_max = prefilter_cfg.max_to_pilot.min(candidates.len());
sorted.truncate(pilot_max);
let ids: Vec<NodeId> = sorted.into_iter().map(|(id, _)| id).collect();
tracing::debug!(
"Pre-filtered: {} candidates -> {} to Pilot (threshold={})",
candidates.len(),
ids.len(),
prefilter_cfg.threshold,
);
ids
} else {
candidates.to_vec()
};

// === BINARY PRUNING ===
// After P2 pre-filtering, if candidates still exceed the prune
// threshold, ask Pilot for a quick yes/no filter before the
// expensive full-scoring call.
let prune_cfg = &p.config().prune;
let pilot_candidates = if prune_cfg.should_prune(pilot_candidates.len()) {
let mut prune_state = SearchState::new(
tree, query, path, &pilot_candidates, visited,
);
prune_state.step_reasons = step_reasons;

if let Some(relevant_ids) = p.binary_prune(&prune_state).await {
let relevant_set: HashSet<NodeId> = relevant_ids.iter().copied().collect();
let mut pruned: Vec<NodeId> = pilot_candidates
.iter()
.filter(|id| relevant_set.contains(id))
.copied()
.collect();

// Enforce min_keep to prevent over-aggressive pruning
if pruned.len() < prune_cfg.min_keep {
// Back-fill, in pilot_candidates order, with candidates the prune step did NOT mark relevant
for id in &pilot_candidates {
if pruned.len() >= prune_cfg.min_keep {
break;
}
if !relevant_set.contains(id) {
pruned.push(*id);
}
}
}

tracing::debug!(
"Binary prune: {} candidates -> {} relevant (min_keep={})",
pilot_candidates.len(),
pruned.len(),
prune_cfg.min_keep,
);
pruned
} else {
pilot_candidates
}
} else {
pilot_candidates
};

// Check cache first
let decision = if let Some(c) = cache {
if let Some(cached) = c.get(query, parent).await {
tracing::trace!("Pilot cache hit for parent={:?}", parent);
cached
} else {
let mut state = SearchState::new(tree, query, path, candidates, visited);
let mut state = SearchState::new(tree, query, path, &pilot_candidates, visited);
state.step_reasons = step_reasons;
let d = p.decide(&state).await;
c.put(query, parent, &d).await;
d
}
} else {
let mut state = SearchState::new(tree, query, path, candidates, visited);
let mut state = SearchState::new(tree, query, path, &pilot_candidates, visited);
state.step_reasons = step_reasons;
p.decide(&state).await
};
Expand All @@ -163,7 +239,7 @@ pub async fn score_candidates_detailed(
pilot_data.insert(ranked.node_id, (ranked.score, ranked.reason.clone()));
}

// Compute NodeScorer fallback scores
// Compute NodeScorer fallback scores for ALL original candidates
let scorer_weight = 1.0 - pilot_weight;
let confidence = decision.confidence;
let effective_pilot = pilot_weight * confidence;
Expand Down
Loading