Skip to content

Commit 76f9054

Browse files
committed
refactor(rust): improve code formatting and readability across retrieval modules
- Format function parameters and struct initializations with proper line breaks - Reorder imports for better consistency across search modules - Clean up function signature formatting in complexity.rs - Improve test case formatting in config.rs for better readability - Refactor complex expressions into multi-line format for clarity - Move trait module import after scorer module for logical grouping - Update method calls to use consistent formatting style
1 parent b5e2120 commit 76f9054

File tree

10 files changed

+131
-59
lines changed

10 files changed

+131
-59
lines changed

rust/src/retrieval/pilot/complexity.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@ const USER_PROMPT: &str = include_str!("prompts/user_complexity.txt");
2525
/// Detect query complexity using LLM.
2626
///
2727
/// Returns `None` if the LLM call fails (caller should fall back to heuristic).
28-
pub async fn detect_with_llm(
29-
client: &LlmClient,
30-
query: &str,
31-
) -> Option<QueryComplexity> {
28+
pub async fn detect_with_llm(client: &LlmClient, query: &str) -> Option<QueryComplexity> {
3229
let user = USER_PROMPT.replace("{query}", query);
3330

3431
let resp: ComplexityResponse = client

rust/src/retrieval/pilot/config.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,12 @@ mod tests {
413413
let cfg = PrefilterConfig::default();
414414
assert!(!cfg.should_prefilter(15)); // at threshold
415415
assert!(!cfg.should_prefilter(10)); // below
416-
assert!(cfg.should_prefilter(16)); // above
416+
assert!(cfg.should_prefilter(16)); // above
417417

418-
let disabled = PrefilterConfig { enabled: false, ..Default::default() };
418+
let disabled = PrefilterConfig {
419+
enabled: false,
420+
..Default::default()
421+
};
419422
assert!(!disabled.should_prefilter(100));
420423
}
421424

@@ -432,9 +435,12 @@ mod tests {
432435
let cfg = PruneConfig::default();
433436
assert!(!cfg.should_prune(20)); // at threshold
434437
assert!(!cfg.should_prune(15)); // below
435-
assert!(cfg.should_prune(21)); // above
438+
assert!(cfg.should_prune(21)); // above
436439

437-
let disabled = PruneConfig { enabled: false, ..Default::default() };
440+
let disabled = PruneConfig {
441+
enabled: false,
442+
..Default::default()
443+
};
438444
assert!(!disabled.should_prune(100));
439445
}
440446

rust/src/retrieval/pilot/decision_scorer.rs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,15 @@ pub async fn score_candidates(
8989
step_reasons: Option<&[Option<String>]>,
9090
) -> Vec<(NodeId, f32)> {
9191
let scored = score_candidates_detailed(
92-
tree, candidates, query, pilot, path, visited, pilot_weight, cache, step_reasons,
92+
tree,
93+
candidates,
94+
query,
95+
pilot,
96+
path,
97+
visited,
98+
pilot_weight,
99+
cache,
100+
step_reasons,
93101
)
94102
.await;
95103
scored.into_iter().map(|s| (s.node_id, s.score)).collect()
@@ -175,9 +183,7 @@ pub async fn score_candidates_detailed(
175183
// expensive full-scoring call.
176184
let prune_cfg = &p.config().prune;
177185
let pilot_candidates = if prune_cfg.should_prune(pilot_candidates.len()) {
178-
let mut prune_state = SearchState::new(
179-
tree, query, path, &pilot_candidates, visited,
180-
);
186+
let mut prune_state = SearchState::new(tree, query, path, &pilot_candidates, visited);
181187
prune_state.step_reasons = step_reasons;
182188

183189
if let Some(relevant_ids) = p.binary_prune(&prune_state).await {
@@ -250,7 +256,8 @@ pub async fn score_candidates_detailed(
250256
.iter()
251257
.map(|&node_id| {
252258
let algo_score = scorer.score(tree, node_id);
253-
let (p_score, reason) = pilot_data.get(&node_id)
259+
let (p_score, reason) = pilot_data
260+
.get(&node_id)
254261
.map(|(s, r)| (*s, r.clone()))
255262
.unwrap_or((0.0, None));
256263

@@ -261,11 +268,19 @@ pub async fn score_candidates_detailed(
261268
algo_score
262269
};
263270

264-
ScoredCandidate { node_id, score: final_score, reason }
271+
ScoredCandidate {
272+
node_id,
273+
score: final_score,
274+
reason,
275+
}
265276
})
266277
.collect();
267278

268-
scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
279+
scored.sort_by(|a, b| {
280+
b.score
281+
.partial_cmp(&a.score)
282+
.unwrap_or(std::cmp::Ordering::Equal)
283+
});
269284
scored
270285
}
271286

@@ -289,7 +304,11 @@ fn score_with_scorer_detailed(
289304
scorer
290305
.score_and_sort(tree, candidates)
291306
.into_iter()
292-
.map(|(node_id, score)| ScoredCandidate { node_id, score, reason: None })
307+
.map(|(node_id, score)| ScoredCandidate {
308+
node_id,
309+
score,
310+
reason: None,
311+
})
293312
.collect()
294313
}
295314

rust/src/retrieval/pilot/llm_pilot.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -713,11 +713,14 @@ impl Pilot for LlmPilot {
713713
.iter()
714714
.enumerate()
715715
.filter_map(|(i, &node_id)| {
716-
state.tree.get(node_id).map(|node| super::parser::CandidateInfo {
717-
node_id,
718-
title: node.title.clone(),
719-
index: i,
720-
})
716+
state
717+
.tree
718+
.get(node_id)
719+
.map(|node| super::parser::CandidateInfo {
720+
node_id,
721+
title: node.title.clone(),
722+
index: i,
723+
})
721724
})
722725
.collect();
723726

rust/src/retrieval/pilot/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ mod metrics;
4343
mod noop;
4444
mod parser;
4545
mod prompts;
46-
mod r#trait;
4746
mod scorer;
47+
mod r#trait;
4848

4949
pub use complexity::detect_with_llm;
50-
pub use config::{PilotConfig, PrefilterConfig, PruneConfig};
50+
pub use config::PilotConfig;
5151
pub use decision::{InterventionPoint, PilotDecision};
52-
pub use decision_scorer::{PilotDecisionCache, ScoredCandidate, score_candidates, score_candidates_detailed};
52+
pub use decision_scorer::{PilotDecisionCache, score_candidates, score_candidates_detailed};
5353
pub use llm_pilot::LlmPilot;
54-
pub use r#trait::{Pilot, SearchState};
5554
pub use scorer::{NodeScorer, ScoringContext};
55+
pub use r#trait::{Pilot, SearchState};

rust/src/retrieval/search/beam.rs

Lines changed: 72 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ use tracing::debug;
2020

2121
use super::super::RetrievalContext;
2222
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
23-
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, score_candidates_detailed};
2423
use super::{SearchConfig, SearchResult, SearchTree};
2524
use crate::document::{DocumentTree, NodeId};
2625
use crate::retrieval::pilot::{Pilot, SearchState};
26+
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, score_candidates_detailed};
2727

2828
/// Maximum entries in the fallback stack relative to beam width.
2929
const FALLBACK_STACK_MULTIPLIER: usize = 3;
@@ -91,7 +91,11 @@ impl BeamSearch {
9191
if let Some(min_idx) = fallback_stack
9292
.iter()
9393
.enumerate()
94-
.min_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
94+
.min_by(|(_, a), (_, b)| {
95+
a.score
96+
.partial_cmp(&b.score)
97+
.unwrap_or(std::cmp::Ordering::Equal)
98+
})
9599
.map(|(i, _)| i)
96100
{
97101
if entry.score > fallback_stack[min_idx].score {
@@ -114,7 +118,11 @@ impl BeamSearch {
114118
let max_idx = fallback_stack
115119
.iter()
116120
.enumerate()
117-
.max_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
121+
.max_by(|(_, a), (_, b)| {
122+
a.score
123+
.partial_cmp(&b.score)
124+
.unwrap_or(std::cmp::Ordering::Equal)
125+
})
118126
.map(|(i, _)| i)?;
119127
Some(fallback_stack.swap_remove(max_idx))
120128
}
@@ -287,11 +295,8 @@ impl BeamSearch {
287295
.unwrap_or(std::cmp::Ordering::Equal)
288296
});
289297

290-
let mut current_beam: Vec<SearchPath> = sorted_initial
291-
.iter()
292-
.take(beam_width)
293-
.cloned()
294-
.collect();
298+
let mut current_beam: Vec<SearchPath> =
299+
sorted_initial.iter().take(beam_width).cloned().collect();
295300

296301
// Remaining candidates go to fallback stack
297302
for path in sorted_initial.iter().skip(beam_width) {
@@ -421,16 +426,14 @@ impl BeamSearch {
421426

422427
// Keep top beam_width in the beam, shelve the rest
423428
let mut beam_candidates = next_beam;
424-
let overflow: Vec<SearchPath> = beam_candidates.split_off(beam_width.min(beam_candidates.len()));
429+
let overflow: Vec<SearchPath> =
430+
beam_candidates.split_off(beam_width.min(beam_candidates.len()));
425431

426432
for path in overflow {
427433
let score = path.score;
428434
Self::push_fallback(
429435
&mut fallback_stack,
430-
FallbackEntry {
431-
path,
432-
score,
433-
},
436+
FallbackEntry { path, score },
434437
config.min_score,
435438
config.fallback_score_ratio,
436439
max_fallback_size,
@@ -566,18 +569,33 @@ mod tests {
566569

567570
BeamSearch::push_fallback(
568571
&mut stack,
569-
FallbackEntry { path: SearchPath::from_node(id0, 0.3), score: 0.3 },
570-
0.1, 0.5, 100,
572+
FallbackEntry {
573+
path: SearchPath::from_node(id0, 0.3),
574+
score: 0.3,
575+
},
576+
0.1,
577+
0.5,
578+
100,
571579
);
572580
BeamSearch::push_fallback(
573581
&mut stack,
574-
FallbackEntry { path: SearchPath::from_node(id1, 0.7), score: 0.7 },
575-
0.1, 0.5, 100,
582+
FallbackEntry {
583+
path: SearchPath::from_node(id1, 0.7),
584+
score: 0.7,
585+
},
586+
0.1,
587+
0.5,
588+
100,
576589
);
577590
BeamSearch::push_fallback(
578591
&mut stack,
579-
FallbackEntry { path: SearchPath::from_node(id2, 0.5), score: 0.5 },
580-
0.1, 0.5, 100,
592+
FallbackEntry {
593+
path: SearchPath::from_node(id2, 0.5),
594+
score: 0.5,
595+
},
596+
0.1,
597+
0.5,
598+
100,
581599
);
582600

583601
assert_eq!(stack.len(), 3);
@@ -603,16 +621,26 @@ mod tests {
603621
// Score 0.01 with threshold 0.1 * 0.5 = 0.05 → should be rejected
604622
BeamSearch::push_fallback(
605623
&mut stack,
606-
FallbackEntry { path: SearchPath::from_node(id0, 0.01), score: 0.01 },
607-
0.1, 0.5, 100,
624+
FallbackEntry {
625+
path: SearchPath::from_node(id0, 0.01),
626+
score: 0.01,
627+
},
628+
0.1,
629+
0.5,
630+
100,
608631
);
609632
assert_eq!(stack.len(), 0, "Score below threshold should be rejected");
610633

611634
// Score 0.06 with threshold 0.05 → should be accepted
612635
BeamSearch::push_fallback(
613636
&mut stack,
614-
FallbackEntry { path: SearchPath::from_node(id1, 0.06), score: 0.06 },
615-
0.1, 0.5, 100,
637+
FallbackEntry {
638+
path: SearchPath::from_node(id1, 0.06),
639+
score: 0.06,
640+
},
641+
0.1,
642+
0.5,
643+
100,
616644
);
617645
assert_eq!(stack.len(), 1, "Score above threshold should be accepted");
618646
}
@@ -628,21 +656,36 @@ mod tests {
628656
// Fill to capacity (max_size=2)
629657
BeamSearch::push_fallback(
630658
&mut stack,
631-
FallbackEntry { path: SearchPath::from_node(id0, 0.3), score: 0.3 },
632-
0.1, 0.5, 2,
659+
FallbackEntry {
660+
path: SearchPath::from_node(id0, 0.3),
661+
score: 0.3,
662+
},
663+
0.1,
664+
0.5,
665+
2,
633666
);
634667
BeamSearch::push_fallback(
635668
&mut stack,
636-
FallbackEntry { path: SearchPath::from_node(id1, 0.5), score: 0.5 },
637-
0.1, 0.5, 2,
669+
FallbackEntry {
670+
path: SearchPath::from_node(id1, 0.5),
671+
score: 0.5,
672+
},
673+
0.1,
674+
0.5,
675+
2,
638676
);
639677
assert_eq!(stack.len(), 2);
640678

641679
// Push a higher-score entry → should evict the lowest (0.3)
642680
BeamSearch::push_fallback(
643681
&mut stack,
644-
FallbackEntry { path: SearchPath::from_node(id2, 0.8), score: 0.8 },
645-
0.1, 0.5, 2,
682+
FallbackEntry {
683+
path: SearchPath::from_node(id2, 0.8),
684+
score: 0.8,
685+
},
686+
0.1,
687+
0.5,
688+
2,
646689
);
647690
assert_eq!(stack.len(), 2);
648691

rust/src/retrieval/search/greedy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ use tracing::debug;
1313

1414
use super::super::RetrievalContext;
1515
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
16-
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates};
1716
use super::{SearchConfig, SearchResult, SearchTree};
1817
use crate::document::{DocumentTree, NodeId};
1918
use crate::retrieval::pilot::Pilot;
19+
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates};
2020

2121
/// Pure Pilot search — Pilot picks the best child at each layer.
2222
///

rust/src/retrieval/search/mcts.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ use tracing::debug;
2020

2121
use super::super::RetrievalContext;
2222
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
23-
use crate::retrieval::pilot::{PilotDecisionCache, score_candidates, NodeScorer, ScoringContext};
2423
use super::{SearchConfig, SearchResult, SearchTree};
2524
use crate::document::{DocumentTree, NodeId};
2625
use crate::retrieval::pilot::Pilot;
26+
use crate::retrieval::pilot::{NodeScorer, PilotDecisionCache, ScoringContext, score_candidates};
2727

2828
/// Statistics for a node in MCTS.
2929
#[derive(Debug, Clone, Default)]

rust/src/retrieval/stages/analyze.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,7 @@ impl AnalyzeStage {
147147
/// Enable query decomposition and LLM-based complexity detection.
148148
pub fn with_llm_client(mut self, client: crate::llm::LlmClient) -> Self {
149149
// Use LLM client for complexity detection
150-
self.complexity_detector =
151-
ComplexityDetector::with_llm_client(client.clone());
150+
self.complexity_detector = ComplexityDetector::with_llm_client(client.clone());
152151
// Also enable query decomposition
153152
if self.query_decomposer.is_none() {
154153
self.query_decomposer =

rust/src/retrieval/types.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,12 @@ impl SearchPath {
660660

661661
/// Extend the path with a new node and a reason for choosing it.
662662
#[must_use]
663-
pub fn extend_with_reason(&self, node_id: NodeId, score: f32, reason: impl Into<String>) -> Self {
663+
pub fn extend_with_reason(
664+
&self,
665+
node_id: NodeId,
666+
score: f32,
667+
reason: impl Into<String>,
668+
) -> Self {
664669
let mut nodes = self.nodes.clone();
665670
let mut step_reasons = self.step_reasons.clone();
666671
nodes.push(node_id);

0 commit comments

Comments
 (0)