Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 8 additions & 27 deletions docs/design/logo-horizontal.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
107 changes: 58 additions & 49 deletions rust/src/retrieval/search/beam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,89 +105,79 @@ impl BeamSearch {

merged
}
}

impl Default for BeamSearch {
fn default() -> Self {
Self::new()
}
}

#[async_trait]
impl SearchTree for BeamSearch {
async fn search(
/// Core beam search logic parameterized by start node.
///
/// This is the shared implementation used by both `search` (starts from root)
/// and `search_from` (starts from an arbitrary node).
async fn search_impl(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
config: &SearchConfig,
pilot: Option<&dyn Pilot>,
start_node: NodeId,
) -> SearchResult {
let mut result = SearchResult::default();
let beam_width = config.beam_width.min(self.beam_width);
let mut visited: HashSet<NodeId> = HashSet::new();

println!("[DEBUG] BeamSearch: query='{}', beam_width={}, min_score={:.2}",
context.query, beam_width, config.min_score);
// Mark start_node as visited so we don't go back up
visited.insert(start_node);

debug!(
"BeamSearch: query='{}', start_node={:?}, beam_width={}, min_score={:.2}",
context.query, start_node, beam_width, config.min_score
);

// Track Pilot interventions
let mut pilot_interventions = 0;

// Initialize with root's children
let root_children = tree.children(tree.root());
println!("[DEBUG] Root has {} children", root_children.len());
// Initialize with start_node's children
let start_children = tree.children(start_node);
debug!("Start node has {} children", start_children.len());

// Check if Pilot wants to guide the start
let initial_candidates = if let Some(p) = pilot {
println!("[DEBUG] BeamSearch: Pilot is available, name={}, guide_at_start={}",
p.name(), p.config().guide_at_start);
debug!(
"BeamSearch: Pilot is available, name={}, guide_at_start={}",
p.name(),
p.config().guide_at_start
);
if p.config().guide_at_start {
println!("[DEBUG] BeamSearch: Calling pilot.guide_start()...");
if let Some(guidance) = p.guide_start(tree, &context.query).await {
debug!(
"Pilot provided start guidance with confidence {}",
guidance.confidence
);
pilot_interventions += 1;
println!("[DEBUG] BeamSearch: Pilot returned guidance! confidence={:.2}, candidates={}",
guidance.confidence, guidance.ranked_candidates.len());

// Use Pilot's ranked order if available
if guidance.has_candidates() {
self.merge_with_pilot_decision(
tree,
&root_children,
&start_children,
&guidance,
&context.query,
)
} else {
println!("[DEBUG] BeamSearch: Guidance has no candidates, using algorithm scoring");
self.score_candidates_with_query(tree, &root_children, &context.query)
self.score_candidates_with_query(tree, &start_children, &context.query)
}
} else {
println!("[DEBUG] BeamSearch: pilot.guide_start() returned None");
self.score_candidates_with_query(tree, &root_children, &context.query)
self.score_candidates_with_query(tree, &start_children, &context.query)
}
} else {
println!("[DEBUG] BeamSearch: guide_at_start=false, skipping Pilot");
self.score_candidates_with_query(tree, &root_children, &context.query)
self.score_candidates_with_query(tree, &start_children, &context.query)
}
} else {
println!("[DEBUG] BeamSearch: No Pilot available");
self.score_candidates_with_query(tree, &root_children, &context.query)
self.score_candidates_with_query(tree, &start_children, &context.query)
};

let mut current_beam: Vec<SearchPath> = initial_candidates
.into_iter()
.map(|(node_id, score)| SearchPath::from_node(node_id, score))
.collect();

// Debug: show initial scores
println!("[DEBUG] Initial {} candidates after scoring", current_beam.len());
for (i, path) in current_beam.iter().enumerate().take(5) {
if let Some(node) = tree.get(path.leaf.unwrap_or(tree.root())) {
println!("[DEBUG] Initial {}: score={:.3}, title='{}'", i, path.score, node.title);
}
}
debug!("Initial {} candidates after scoring", current_beam.len());

// Keep top beam_width
current_beam.truncate(beam_width);
Expand All @@ -207,7 +197,6 @@ impl SearchTree for BeamSearch {

// Check if this is a leaf node
if tree.is_leaf(leaf_id) {
// Add to final results
if path.score >= config.min_score {
result.paths.push(path.clone());
}
Expand All @@ -220,7 +209,6 @@ impl SearchTree for BeamSearch {

// ========== Pilot Intervention Point ==========
let scored_children = if let Some(p) = pilot {
// Build search state for Pilot
let state = SearchState::new(
tree,
&context.query,
Expand All @@ -229,14 +217,12 @@ impl SearchTree for BeamSearch {
&visited,
);

// Check if Pilot wants to intervene
if p.should_intervene(&state) {
trace!(
"Pilot intervening at fork with {} candidates",
children.len()
);

println!("[DEBUG] BEAM SEARCH: Pilot intervening at decision point");
match p.decide(&state).await {
decision => {
pilot_interventions += 1;
Expand All @@ -246,7 +232,6 @@ impl SearchTree for BeamSearch {
std::mem::discriminant(&decision.direction)
);

// Merge algorithm scores with Pilot decision
self.merge_with_pilot_decision(
tree,
&children,
Expand All @@ -256,19 +241,16 @@ impl SearchTree for BeamSearch {
}
}
} else {
// No intervention, use algorithm scoring
self.score_candidates_with_query(tree, &children, &context.query)
}
} else {
// No Pilot, use algorithm scoring
self.score_candidates_with_query(tree, &children, &context.query)
};
// ==============================================

for (child_id, child_score) in scored_children.into_iter().take(beam_width) {
let new_path = path.extend(child_id, child_score);

// Record trace
let child_node = tree.get(child_id);
result.trace.push(NavigationStep {
node_id: format!("{:?}", child_id),
Expand Down Expand Up @@ -296,7 +278,6 @@ impl SearchTree for BeamSearch {

current_beam = next_beam;

// Check if we have enough results
if result.paths.len() >= config.top_k {
break;
}
Expand All @@ -312,9 +293,8 @@ impl SearchTree for BeamSearch {
// Fallback: if no results found, add best candidates regardless of score
if result.paths.is_empty() && config.min_score > 0.0 {
debug!("No results above min_score, adding best candidates as fallback");
// Re-score initial candidates and take top-k
let all_candidates =
self.score_candidates_with_query(tree, &tree.children(tree.root()), &context.query);
self.score_candidates_with_query(tree, &tree.children(start_node), &context.query);
for (node_id, score) in all_candidates.into_iter().take(config.top_k) {
result.paths.push(SearchPath::from_node(node_id, score));
}
Expand All @@ -328,11 +308,40 @@ impl SearchTree for BeamSearch {
});
result.paths.truncate(config.top_k);

// Record Pilot interventions
result.pilot_interventions = pilot_interventions;

result
}
}

impl Default for BeamSearch {
fn default() -> Self {
Self::new()
}
}

#[async_trait]
impl SearchTree for BeamSearch {
async fn search(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
config: &SearchConfig,
pilot: Option<&dyn Pilot>,
) -> SearchResult {
self.search_impl(tree, context, config, pilot, tree.root()).await
}

async fn search_from(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
config: &SearchConfig,
pilot: Option<&dyn Pilot>,
start_node: NodeId,
) -> SearchResult {
self.search_impl(tree, context, config, pilot, start_node).await
}

fn name(&self) -> &'static str {
"beam"
Expand Down
60 changes: 39 additions & 21 deletions rust/src/retrieval/search/greedy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ impl GreedySearch {
let algo_score = scorer.score(tree, node_id);
let pilot_score = pilot_scores.get(&node_id).copied().unwrap_or(0.0);

// Weighted combination
let final_score = if beta > 0.0 {
(alpha * algo_score + beta * pilot_score) / (alpha + beta)
} else {
Expand All @@ -81,33 +80,30 @@ impl GreedySearch {
})
.collect();

// Sort by merged score
merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

merged
}
}

impl Default for GreedySearch {
fn default() -> Self {
Self::new()
}
}

#[async_trait]
impl SearchTree for GreedySearch {
async fn search(
/// Core greedy search logic parameterized by start node.
async fn search_impl(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
config: &SearchConfig,
pilot: Option<&dyn Pilot>,
start_node: NodeId,
) -> SearchResult {
let mut result = SearchResult::default();
let mut current_path = SearchPath::new();
let mut current_node = tree.root();
let mut current_node = start_node;
let mut visited: std::collections::HashSet<NodeId> = std::collections::HashSet::new();

debug!(
"GreedySearch: query='{}', start_node={:?}, max_iterations={}, min_score={:.2}",
context.query, start_node, config.max_iterations, config.min_score
);

// Track Pilot interventions
let mut pilot_interventions = 0;

Expand All @@ -128,7 +124,6 @@ impl SearchTree for GreedySearch {

// ========== Pilot Integration Point ==========
let scored_children = if let Some(p) = pilot {
// Build search state for Pilot
let state = SearchState::new(
tree,
&context.query,
Expand All @@ -137,14 +132,12 @@ impl SearchTree for GreedySearch {
&visited,
);

// Check if Pilot wants to intervene
if p.should_intervene(&state) {
trace!(
"Pilot intervening at greedy decision point with {} candidates",
children.len()
);

println!("[DEBUG] GREEDY SEARCH: Pilot intervening at decision point");
match p.decide(&state).await {
decision => {
pilot_interventions += 1;
Expand All @@ -154,7 +147,6 @@ impl SearchTree for GreedySearch {
std::mem::discriminant(&decision.direction)
);

// Merge algorithm scores with Pilot decision
self.merge_with_pilot_decision(
tree,
&children,
Expand All @@ -164,11 +156,9 @@ impl SearchTree for GreedySearch {
}
}
} else {
// No intervention, use algorithm scoring
self.score_candidates_with_query(tree, &children, &context.query)
}
} else {
// No Pilot, use algorithm scoring
self.score_candidates_with_query(tree, &children, &context.query)
};
// ==============================================
Expand Down Expand Up @@ -205,7 +195,6 @@ impl SearchTree for GreedySearch {
current_node = child_id;
result.nodes_visited += 1;

// Check if we have enough results
if result.paths.len() >= config.top_k {
break;
}
Expand All @@ -219,11 +208,40 @@ impl SearchTree for GreedySearch {
}
}

// Record Pilot interventions
result.pilot_interventions = pilot_interventions;

result
}
}

impl Default for GreedySearch {
fn default() -> Self {
Self::new()
}
}

#[async_trait]
impl SearchTree for GreedySearch {
async fn search(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
config: &SearchConfig,
pilot: Option<&dyn Pilot>,
) -> SearchResult {
self.search_impl(tree, context, config, pilot, tree.root()).await
}

async fn search_from(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
config: &SearchConfig,
pilot: Option<&dyn Pilot>,
start_node: NodeId,
) -> SearchResult {
self.search_impl(tree, context, config, pilot, start_node).await
}

fn name(&self) -> &'static str {
"greedy"
Expand Down
2 changes: 2 additions & 0 deletions rust/src/retrieval/search/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod bm25;
mod greedy;
mod mcts;
mod scorer;
mod toc_navigator;
mod r#trait;

pub use beam::BeamSearch;
Expand All @@ -18,4 +19,5 @@ pub use bm25::{
pub use greedy::GreedySearch;
pub use mcts::MctsSearch;
pub use scorer::{NodeScorer, ScoringContext};
pub use toc_navigator::{SearchCue, ToCNavigator};
pub use r#trait::{SearchConfig, SearchResult, SearchTree};
Loading