Skip to content

Commit 1de9482

Browse files
committed
feat(retrieval): add LLM-based query complexity detection with heuristic fallback
Add a comprehensive query complexity detection system that uses LLM classification when available, falling back to heuristic rules. Supports both English and Chinese queries with improved word counting for CJK characters. The complexity detector now accepts an optional LLM client for accurate classification while maintaining backward compatibility with rule-based detection.

- Add LLM-based complexity detection using pilot's LLM client
- Implement heuristic fallback with enhanced keyword matching
- Support Chinese language complexity indicators
- Add proper CJK character word counting estimation
- Update analyze stage to use LLM-enhanced complexity detection
- Create new pilot complexity module with JSON response parsing
- Include comprehensive test coverage for both approaches
1 parent 68a10ac commit 1de9482

File tree

7 files changed

+276
-106
lines changed

7 files changed

+276
-106
lines changed

rust/src/retrieval/complexity/detector.rs

Lines changed: 168 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,143 +2,133 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
//! Query complexity detector implementation.
5+
//!
6+
//! Uses Pilot's LLM client for accurate complexity classification when available.
7+
//! Falls back to heuristic rules (keyword + word count) when no LLM client.
58
69
use std::collections::HashSet;
710

811
use super::QueryComplexity;
912

10-
/// Configuration for complexity detection.
11-
#[derive(Debug, Clone)]
12-
pub struct ComplexityConfig {
13-
/// Maximum words for simple query.
14-
pub simple_max_words: usize,
15-
/// Maximum words for medium query.
16-
pub medium_max_words: usize,
17-
/// Complexity indicators (words that suggest complex queries).
18-
pub complex_indicators: Vec<String>,
19-
/// Simple query indicators.
20-
pub simple_indicators: Vec<String>,
21-
}
22-
23-
impl Default for ComplexityConfig {
24-
fn default() -> Self {
25-
Self {
26-
simple_max_words: 5,
27-
medium_max_words: 15,
28-
complex_indicators: vec![
29-
"compare".to_string(),
30-
"contrast".to_string(),
31-
"analyze".to_string(),
32-
"evaluate".to_string(),
33-
"synthesize".to_string(),
34-
"explain why".to_string(),
35-
"how does".to_string(),
36-
"what are the implications".to_string(),
37-
"relationship between".to_string(),
38-
"cause and effect".to_string(),
39-
],
40-
simple_indicators: vec![
41-
"what is".to_string(),
42-
"define".to_string(),
43-
"list".to_string(),
44-
"who".to_string(),
45-
"when".to_string(),
46-
"where".to_string(),
47-
],
48-
}
49-
}
50-
}
51-
5213
/// Query complexity detector.
5314
///
54-
/// Analyzes queries to determine their complexity level,
55-
/// which influences strategy selection.
15+
/// Uses LLM for classification when available; falls back to heuristic rules.
5616
pub struct ComplexityDetector {
57-
config: ComplexityConfig,
17+
/// Optional LLM client for LLM-based detection.
18+
llm_client: Option<crate::llm::LlmClient>,
5819
}
5920

6021
impl ComplexityDetector {
61-
/// Create a new complexity detector.
22+
/// Create a new complexity detector (heuristic only).
6223
pub fn new() -> Self {
24+
Self { llm_client: None }
25+
}
26+
27+
/// Create with LLM client for accurate detection.
28+
pub fn with_llm_client(client: crate::llm::LlmClient) -> Self {
6329
Self {
64-
config: ComplexityConfig::default(),
30+
llm_client: Some(client),
6531
}
6632
}
6733

68-
/// Create with custom configuration.
69-
pub fn with_config(config: ComplexityConfig) -> Self {
70-
Self { config }
34+
/// Detect the complexity of a query.
35+
///
36+
/// Uses LLM when available; falls back to heuristic rules.
37+
pub async fn detect(&self, query: &str) -> QueryComplexity {
38+
if let Some(ref client) = self.llm_client {
39+
if let Some(complexity) = crate::retrieval::pilot::detect_with_llm(client, query).await
40+
{
41+
return complexity;
42+
}
43+
tracing::warn!("LLM complexity detection failed, falling back to heuristic");
44+
}
45+
self.detect_heuristic(query)
7146
}
7247

73-
/// Detect the complexity of a query.
74-
pub fn detect(&self, query: &str) -> QueryComplexity {
48+
/// Heuristic-based fallback: keyword matching + word count.
49+
fn detect_heuristic(&self, query: &str) -> QueryComplexity {
7550
let query_lower = query.to_lowercase();
76-
let word_count = query.split_whitespace().count();
51+
let word_count = estimate_word_count(query);
52+
53+
// Complex indicators (English + Chinese)
54+
let complex_indicators = [
55+
"compare",
56+
"contrast",
57+
"analyze",
58+
"evaluate",
59+
"synthesize",
60+
"explain why",
61+
"how does",
62+
"relationship between",
63+
"cause and effect",
64+
"对比",
65+
"分析",
66+
"评估",
67+
"综合",
68+
"为什么",
69+
"原因",
70+
"关系",
71+
"影响",
72+
"区别",
73+
"异同",
74+
];
7775

78-
// Check for complex indicators
79-
for indicator in &self.config.complex_indicators {
76+
for indicator in &complex_indicators {
8077
if query_lower.contains(indicator) {
8178
return QueryComplexity::Complex;
8279
}
8380
}
8481

85-
// Check for simple indicators
86-
for indicator in &self.config.simple_indicators {
87-
if query_lower.contains(indicator) {
88-
// Simple indicator found, but check word count
89-
if word_count <= self.config.medium_max_words {
90-
return QueryComplexity::Simple;
91-
}
82+
// Simple indicators
83+
let simple_indicators = [
84+
"what is",
85+
"define",
86+
"list",
87+
"who",
88+
"when",
89+
"where",
90+
"什么是",
91+
"定义",
92+
"列表",
93+
"谁",
94+
"何时",
95+
"哪里",
96+
"在哪",
97+
];
98+
99+
for indicator in &simple_indicators {
100+
if query_lower.contains(indicator) && word_count <= 15 {
101+
return QueryComplexity::Simple;
92102
}
93103
}
94104

95-
// Check for multiple questions
96-
let question_marks = query.matches('?').count();
105+
// Multiple questions
106+
let question_marks = query.matches('?').count() + query.matches('?').count();
97107
if question_marks > 1 {
98108
return QueryComplexity::Complex;
99109
}
100110

101-
// Check for conjunctions suggesting multiple parts
102-
let conjunctions = ["and", "or", "but", "however", "although"];
103-
let conjunction_count = conjunctions
104-
.iter()
105-
.filter(|c| query_lower.split_whitespace().any(|w| w == **c))
106-
.count();
107-
108-
if conjunction_count >= 2 {
109-
return QueryComplexity::Complex;
110-
}
111-
112-
// Check for nested concepts
113-
let depth_indicators = ["in the context of", "with respect to", "regarding", "about"];
114-
for indicator in depth_indicators {
115-
if query_lower.contains(indicator) {
116-
return QueryComplexity::Medium;
117-
}
118-
}
119-
120-
// Word count based classification
121-
if word_count <= self.config.simple_max_words {
111+
// Word count classification
112+
if word_count <= 5 {
122113
QueryComplexity::Simple
123-
} else if word_count <= self.config.medium_max_words {
114+
} else if word_count <= 15 {
124115
QueryComplexity::Medium
125116
} else {
126117
QueryComplexity::Complex
127118
}
128119
}
129120

130121
/// Get complexity score (0.0 - 1.0).
131-
pub fn complexity_score(&self, query: &str) -> f32 {
132-
match self.detect(query) {
122+
pub fn complexity_score(&self, complexity: QueryComplexity) -> f32 {
123+
match complexity {
133124
QueryComplexity::Simple => 0.2,
134125
QueryComplexity::Medium => 0.5,
135126
QueryComplexity::Complex => 0.8,
136127
}
137128
}
138129

139-
/// Analyze query features.
130+
/// Analyze query features (heuristic only, no LLM call).
140131
pub fn analyze(&self, query: &str) -> QueryAnalysis {
141-
let query_lower = query.to_lowercase();
142132
let words: Vec<&str> = query.split_whitespace().collect();
143133
let unique_words: HashSet<&str> = words.iter().copied().collect();
144134

@@ -149,10 +139,10 @@ impl ComplexityDetector {
149139
} else {
150140
unique_words.len() as f32 / words.len() as f32
151141
},
152-
has_question_mark: query.contains('?'),
153-
question_count: query.matches('?').count(),
154-
complexity: self.detect(query),
155-
complexity_score: self.complexity_score(query),
142+
has_question_mark: query.contains('?') || query.contains('?'),
143+
question_count: query.matches('?').count() + query.matches('?').count(),
144+
complexity: self.detect_heuristic(query),
145+
complexity_score: self.complexity_score(self.detect_heuristic(query)),
156146
}
157147
}
158148
}
@@ -163,6 +153,52 @@ impl Default for ComplexityDetector {
163153
}
164154
}
165155

156+
/// Estimate the number of words in `text`, handling both CJK and
/// space-delimited scripts.
///
/// Counting rules:
///
/// * Each Han ideograph or kana character (see [`is_cjk_char`]) counts as
///   one word, because these scripts do not separate words with spaces.
/// * Any other run of Unicode alphanumeric characters (Latin with
///   diacritics, Cyrillic, Hangul, digits, ...) counts as one word.
/// * An apostrophe or hyphen inside such a run does not split it, so
///   "don't" and "well-known" are one word each.
/// * Whitespace and all remaining punctuation only separate words.
///
/// Returns 0 for empty or punctuation-only input.
fn estimate_word_count(text: &str) -> usize {
    let mut count = 0usize;
    // True while scanning a run of non-CJK alphanumeric characters that has
    // not been added to `count` yet.
    let mut in_word = false;

    for ch in text.chars() {
        if is_cjk_char(ch) {
            // Close any pending run, then count the CJK character itself.
            // This check must come first: Han and kana characters are also
            // Unicode-alphanumeric.
            count += usize::from(in_word) + 1;
            in_word = false;
        } else if ch.is_alphanumeric() {
            in_word = true;
        } else if !(in_word && matches!(ch, '\'' | '\u{2019}' | '-')) {
            // Separator: whitespace or punctuation. Intra-word apostrophes
            // and hyphens are skipped so they do not break the run.
            count += usize::from(in_word);
            in_word = false;
        }
    }

    count + usize::from(in_word)
}

/// Check whether `ch` belongs to a script that is written without spaces
/// between words: Han ideographs (Chinese, Japanese kanji, Korean hanja) or
/// Japanese kana.
///
/// CJK punctuation (for example the ideographic full stop U+3002) is
/// deliberately excluded so that it is not counted as a word. Hangul is also
/// excluded because Korean text separates words with spaces.
fn is_cjk_char(ch: char) -> bool {
    matches!(
        u32::from(ch),
        0x3005..=0x3007       // ideographic iteration mark, closing mark, number zero
            | 0x3040..=0x309F // Hiragana
            | 0x30A0..=0x30FF // Katakana
            | 0x3400..=0x4DBF // CJK Unified Ideographs Extension A
            | 0x4E00..=0x9FFF // CJK Unified Ideographs
            | 0xF900..=0xFAFF // CJK Compatibility Ideographs
            | 0x20000..=0x2A6DF // Extension B
            | 0x2A700..=0x2EBEF // Extensions C, D, E and F
            | 0x2F800..=0x2FA1F // CJK Compatibility Ideographs Supplement
            | 0x30000..=0x323AF // Extensions G and H
    )
}
201+
166202
/// Analysis result for a query.
167203
#[derive(Debug, Clone)]
168204
pub struct QueryAnalysis {
@@ -188,21 +224,40 @@ mod tests {
188224
fn test_simple_queries() {
189225
let detector = ComplexityDetector::new();
190226

191-
assert_eq!(detector.detect("What is Rust?"), QueryComplexity::Simple);
192-
assert_eq!(detector.detect("Define async"), QueryComplexity::Simple);
193-
assert_eq!(detector.detect("List features"), QueryComplexity::Simple);
227+
assert_eq!(
228+
detector.detect_heuristic("What is Rust?"),
229+
QueryComplexity::Simple
230+
);
231+
assert_eq!(
232+
detector.detect_heuristic("Define async"),
233+
QueryComplexity::Simple
234+
);
235+
assert_eq!(
236+
detector.detect_heuristic("什么是向量检索"),
237+
QueryComplexity::Simple
238+
);
194239
}
195240

196241
#[test]
fn test_complex_queries() {
    let detector = ComplexityDetector::new();

    // Each query contains at least one complex indicator, in English or
    // Chinese, so the heuristic must classify all of them as complex.
    let complex_queries = [
        "Compare and contrast the different approaches to async programming",
        "What is the relationship between ownership and borrowing?",
        "对比A和B的区别",
        "分析索引和检索的关系",
    ];

    for &query in complex_queries.iter() {
        assert_eq!(detector.detect_heuristic(query), QueryComplexity::Complex);
    }
}
@@ -211,8 +266,20 @@ mod tests {
211266
fn test_medium_queries() {
212267
let detector = ComplexityDetector::new();
213268

214-
// Medium length without complex indicators
215269
let medium_query = "How do I implement a simple web server with error handling?";
216-
assert_eq!(detector.detect(medium_query), QueryComplexity::Medium);
270+
assert_eq!(detector.detect_heuristic(medium_query), QueryComplexity::Medium);
271+
}
272+
273+
#[test]
fn test_estimate_word_count() {
    // Whitespace-separated Latin words count once each.
    assert_eq!(estimate_word_count("hello world"), 2);
    // Every Han character counts as one word: five characters -> 5.
    assert_eq!(estimate_word_count("什么是向量"), 5);
    // Three Han characters plus two Latin words -> 5.
    assert_eq!(estimate_word_count("什么是 vector search"), 5);
}
279+
280+
#[test]
281+
fn test_no_llm_is_ok() {
282+
let detector = ComplexityDetector::new();
283+
assert!(detector.llm_client.is_none());
217284
}
218285
}

0 commit comments

Comments
 (0)