Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions rust/src/retrieval/pilot/complexity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ const USER_PROMPT: &str = include_str!("prompts/user_complexity.txt");
/// Detect query complexity using LLM.
///
/// Returns `None` if the LLM call fails (caller should fall back to heuristic).
pub async fn detect_with_llm(
client: &LlmClient,
query: &str,
) -> Option<QueryComplexity> {
pub async fn detect_with_llm(client: &LlmClient, query: &str) -> Option<QueryComplexity> {
let user = USER_PROMPT.replace("{query}", query);

let resp: ComplexityResponse = client
Expand Down
14 changes: 10 additions & 4 deletions rust/src/retrieval/pilot/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,12 @@ mod tests {
let cfg = PrefilterConfig::default();
assert!(!cfg.should_prefilter(15)); // at threshold
assert!(!cfg.should_prefilter(10)); // below
assert!(cfg.should_prefilter(16)); // above
assert!(cfg.should_prefilter(16)); // above

let disabled = PrefilterConfig { enabled: false, ..Default::default() };
let disabled = PrefilterConfig {
enabled: false,
..Default::default()
};
assert!(!disabled.should_prefilter(100));
}

Expand All @@ -432,9 +435,12 @@ mod tests {
let cfg = PruneConfig::default();
assert!(!cfg.should_prune(20)); // at threshold
assert!(!cfg.should_prune(15)); // below
assert!(cfg.should_prune(21)); // above
assert!(cfg.should_prune(21)); // above

let disabled = PruneConfig { enabled: false, ..Default::default() };
let disabled = PruneConfig {
enabled: false,
..Default::default()
};
assert!(!disabled.should_prune(100));
}

Expand Down
35 changes: 27 additions & 8 deletions rust/src/retrieval/pilot/decision_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,15 @@ pub async fn score_candidates(
step_reasons: Option<&[Option<String>]>,
) -> Vec<(NodeId, f32)> {
let scored = score_candidates_detailed(
tree, candidates, query, pilot, path, visited, pilot_weight, cache, step_reasons,
tree,
candidates,
query,
pilot,
path,
visited,
pilot_weight,
cache,
step_reasons,
)
.await;
scored.into_iter().map(|s| (s.node_id, s.score)).collect()
Expand Down Expand Up @@ -175,9 +183,7 @@ pub async fn score_candidates_detailed(
// expensive full-scoring call.
let prune_cfg = &p.config().prune;
let pilot_candidates = if prune_cfg.should_prune(pilot_candidates.len()) {
let mut prune_state = SearchState::new(
tree, query, path, &pilot_candidates, visited,
);
let mut prune_state = SearchState::new(tree, query, path, &pilot_candidates, visited);
prune_state.step_reasons = step_reasons;

if let Some(relevant_ids) = p.binary_prune(&prune_state).await {
Expand Down Expand Up @@ -250,7 +256,8 @@ pub async fn score_candidates_detailed(
.iter()
.map(|&node_id| {
let algo_score = scorer.score(tree, node_id);
let (p_score, reason) = pilot_data.get(&node_id)
let (p_score, reason) = pilot_data
.get(&node_id)
.map(|(s, r)| (*s, r.clone()))
.unwrap_or((0.0, None));

Expand All @@ -261,11 +268,19 @@ pub async fn score_candidates_detailed(
algo_score
};

ScoredCandidate { node_id, score: final_score, reason }
ScoredCandidate {
node_id,
score: final_score,
reason,
}
})
.collect();

scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored
}

Expand All @@ -289,7 +304,11 @@ fn score_with_scorer_detailed(
scorer
.score_and_sort(tree, candidates)
.into_iter()
.map(|(node_id, score)| ScoredCandidate { node_id, score, reason: None })
.map(|(node_id, score)| ScoredCandidate {
node_id,
score,
reason: None,
})
.collect()
}

Expand Down
13 changes: 8 additions & 5 deletions rust/src/retrieval/pilot/llm_pilot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,11 +713,14 @@ impl Pilot for LlmPilot {
.iter()
.enumerate()
.filter_map(|(i, &node_id)| {
state.tree.get(node_id).map(|node| super::parser::CandidateInfo {
node_id,
title: node.title.clone(),
index: i,
})
state
.tree
.get(node_id)
.map(|node| super::parser::CandidateInfo {
node_id,
title: node.title.clone(),
index: i,
})
})
.collect();

Expand Down
8 changes: 4 additions & 4 deletions rust/src/retrieval/pilot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ mod metrics;
mod noop;
mod parser;
mod prompts;
mod r#trait;
mod scorer;
mod r#trait;

pub use complexity::detect_with_llm;
pub use config::{PilotConfig, PrefilterConfig, PruneConfig};
pub use config::PilotConfig;
pub use decision::{InterventionPoint, PilotDecision};
pub use decision_scorer::{PilotDecisionCache, ScoredCandidate, score_candidates, score_candidates_detailed};
pub use decision_scorer::{PilotDecisionCache, score_candidates, score_candidates_detailed};
pub use llm_pilot::LlmPilot;
pub use r#trait::{Pilot, SearchState};
pub use scorer::{NodeScorer, ScoringContext};
pub use r#trait::{Pilot, SearchState};
101 changes: 72 additions & 29 deletions rust/src/retrieval/search/beam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use tracing::debug;

use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, score_candidates_detailed};
use super::{SearchConfig, SearchResult, SearchTree};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::{Pilot, SearchState};
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, score_candidates_detailed};

/// Maximum entries in the fallback stack relative to beam width.
const FALLBACK_STACK_MULTIPLIER: usize = 3;
Expand Down Expand Up @@ -91,7 +91,11 @@ impl BeamSearch {
if let Some(min_idx) = fallback_stack
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
.min_by(|(_, a), (_, b)| {
a.score
.partial_cmp(&b.score)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
{
if entry.score > fallback_stack[min_idx].score {
Expand All @@ -114,7 +118,11 @@ impl BeamSearch {
let max_idx = fallback_stack
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
.max_by(|(_, a), (_, b)| {
a.score
.partial_cmp(&b.score)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)?;
Some(fallback_stack.swap_remove(max_idx))
}
Expand Down Expand Up @@ -287,11 +295,8 @@ impl BeamSearch {
.unwrap_or(std::cmp::Ordering::Equal)
});

let mut current_beam: Vec<SearchPath> = sorted_initial
.iter()
.take(beam_width)
.cloned()
.collect();
let mut current_beam: Vec<SearchPath> =
sorted_initial.iter().take(beam_width).cloned().collect();

// Remaining candidates go to fallback stack
for path in sorted_initial.iter().skip(beam_width) {
Expand Down Expand Up @@ -421,16 +426,14 @@ impl BeamSearch {

// Keep top beam_width in the beam, shelve the rest
let mut beam_candidates = next_beam;
let overflow: Vec<SearchPath> = beam_candidates.split_off(beam_width.min(beam_candidates.len()));
let overflow: Vec<SearchPath> =
beam_candidates.split_off(beam_width.min(beam_candidates.len()));

for path in overflow {
let score = path.score;
Self::push_fallback(
&mut fallback_stack,
FallbackEntry {
path,
score,
},
FallbackEntry { path, score },
config.min_score,
config.fallback_score_ratio,
max_fallback_size,
Expand Down Expand Up @@ -566,18 +569,33 @@ mod tests {

BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id0, 0.3), score: 0.3 },
0.1, 0.5, 100,
FallbackEntry {
path: SearchPath::from_node(id0, 0.3),
score: 0.3,
},
0.1,
0.5,
100,
);
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id1, 0.7), score: 0.7 },
0.1, 0.5, 100,
FallbackEntry {
path: SearchPath::from_node(id1, 0.7),
score: 0.7,
},
0.1,
0.5,
100,
);
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id2, 0.5), score: 0.5 },
0.1, 0.5, 100,
FallbackEntry {
path: SearchPath::from_node(id2, 0.5),
score: 0.5,
},
0.1,
0.5,
100,
);

assert_eq!(stack.len(), 3);
Expand All @@ -603,16 +621,26 @@ mod tests {
// Score 0.01 with threshold 0.1 * 0.5 = 0.05 → should be rejected
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id0, 0.01), score: 0.01 },
0.1, 0.5, 100,
FallbackEntry {
path: SearchPath::from_node(id0, 0.01),
score: 0.01,
},
0.1,
0.5,
100,
);
assert_eq!(stack.len(), 0, "Score below threshold should be rejected");

// Score 0.06 with threshold 0.05 → should be accepted
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id1, 0.06), score: 0.06 },
0.1, 0.5, 100,
FallbackEntry {
path: SearchPath::from_node(id1, 0.06),
score: 0.06,
},
0.1,
0.5,
100,
);
assert_eq!(stack.len(), 1, "Score above threshold should be accepted");
}
Expand All @@ -628,21 +656,36 @@ mod tests {
// Fill to capacity (max_size=2)
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id0, 0.3), score: 0.3 },
0.1, 0.5, 2,
FallbackEntry {
path: SearchPath::from_node(id0, 0.3),
score: 0.3,
},
0.1,
0.5,
2,
);
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id1, 0.5), score: 0.5 },
0.1, 0.5, 2,
FallbackEntry {
path: SearchPath::from_node(id1, 0.5),
score: 0.5,
},
0.1,
0.5,
2,
);
assert_eq!(stack.len(), 2);

// Push a higher-score entry → should evict the lowest (0.3)
BeamSearch::push_fallback(
&mut stack,
FallbackEntry { path: SearchPath::from_node(id2, 0.8), score: 0.8 },
0.1, 0.5, 2,
FallbackEntry {
path: SearchPath::from_node(id2, 0.8),
score: 0.8,
},
0.1,
0.5,
2,
);
assert_eq!(stack.len(), 2);

Expand Down
2 changes: 1 addition & 1 deletion rust/src/retrieval/search/greedy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use tracing::debug;

use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates};
use super::{SearchConfig, SearchResult, SearchTree};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::Pilot;
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates};

/// Pure Pilot search — Pilot picks the best child at each layer.
///
Expand Down
2 changes: 1 addition & 1 deletion rust/src/retrieval/search/mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use tracing::debug;

use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, NodeScorer, ScoringContext};
use super::{SearchConfig, SearchResult, SearchTree};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::Pilot;
use crate::retrieval::pilot::{NodeScorer, PilotDecisionCache, ScoringContext, score_candidates};

/// Statistics for a node in MCTS.
#[derive(Debug, Clone, Default)]
Expand Down
3 changes: 1 addition & 2 deletions rust/src/retrieval/stages/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ impl AnalyzeStage {
/// Enable query decomposition and LLM-based complexity detection.
pub fn with_llm_client(mut self, client: crate::llm::LlmClient) -> Self {
// Use LLM client for complexity detection
self.complexity_detector =
ComplexityDetector::with_llm_client(client.clone());
self.complexity_detector = ComplexityDetector::with_llm_client(client.clone());
// Also enable query decomposition
if self.query_decomposer.is_none() {
self.query_decomposer =
Expand Down
7 changes: 6 additions & 1 deletion rust/src/retrieval/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,12 @@ impl SearchPath {

/// Extend the path with a new node and a reason for choosing it.
#[must_use]
pub fn extend_with_reason(&self, node_id: NodeId, score: f32, reason: impl Into<String>) -> Self {
pub fn extend_with_reason(
&self,
node_id: NodeId,
score: f32,
reason: impl Into<String>,
) -> Self {
let mut nodes = self.nodes.clone();
let mut step_reasons = self.step_reasons.clone();
nodes.push(node_id);
Expand Down