22// SPDX-License-Identifier: Apache-2.0
33
44//! Query complexity detector implementation.
5+ //!
6+ //! Uses Pilot's LLM client for accurate complexity classification when available.
7+ //! Falls back to heuristic rules (keyword + word count) when no LLM client.
58
69use std:: collections:: HashSet ;
710
811use super :: QueryComplexity ;
912
10- /// Configuration for complexity detection.
11- #[ derive( Debug , Clone ) ]
12- pub struct ComplexityConfig {
13- /// Maximum words for simple query.
14- pub simple_max_words : usize ,
15- /// Maximum words for medium query.
16- pub medium_max_words : usize ,
17- /// Complexity indicators (words that suggest complex queries).
18- pub complex_indicators : Vec < String > ,
19- /// Simple query indicators.
20- pub simple_indicators : Vec < String > ,
21- }
22-
23- impl Default for ComplexityConfig {
24- fn default ( ) -> Self {
25- Self {
26- simple_max_words : 5 ,
27- medium_max_words : 15 ,
28- complex_indicators : vec ! [
29- "compare" . to_string( ) ,
30- "contrast" . to_string( ) ,
31- "analyze" . to_string( ) ,
32- "evaluate" . to_string( ) ,
33- "synthesize" . to_string( ) ,
34- "explain why" . to_string( ) ,
35- "how does" . to_string( ) ,
36- "what are the implications" . to_string( ) ,
37- "relationship between" . to_string( ) ,
38- "cause and effect" . to_string( ) ,
39- ] ,
40- simple_indicators : vec ! [
41- "what is" . to_string( ) ,
42- "define" . to_string( ) ,
43- "list" . to_string( ) ,
44- "who" . to_string( ) ,
45- "when" . to_string( ) ,
46- "where" . to_string( ) ,
47- ] ,
48- }
49- }
50- }
51-
5213/// Query complexity detector.
5314///
54- /// Analyzes queries to determine their complexity level,
55- /// which influences strategy selection.
15+ /// Uses LLM for classification when available; falls back to heuristic rules.
5616pub struct ComplexityDetector {
57- config : ComplexityConfig ,
17+ /// Optional LLM client for LLM-based detection.
18+ llm_client : Option < crate :: llm:: LlmClient > ,
5819}
5920
6021impl ComplexityDetector {
61- /// Create a new complexity detector.
22+ /// Create a new complexity detector (heuristic only) .
6223 pub fn new ( ) -> Self {
24+ Self { llm_client : None }
25+ }
26+
27+ /// Create with LLM client for accurate detection.
28+ pub fn with_llm_client ( client : crate :: llm:: LlmClient ) -> Self {
6329 Self {
64- config : ComplexityConfig :: default ( ) ,
30+ llm_client : Some ( client ) ,
6531 }
6632 }
6733
68- /// Create with custom configuration.
69- pub fn with_config ( config : ComplexityConfig ) -> Self {
70- Self { config }
34+ /// Detect the complexity of a query.
35+ ///
36+ /// Uses LLM when available; falls back to heuristic rules.
37+ pub async fn detect ( & self , query : & str ) -> QueryComplexity {
38+ if let Some ( ref client) = self . llm_client {
39+ if let Some ( complexity) = crate :: retrieval:: pilot:: detect_with_llm ( client, query) . await
40+ {
41+ return complexity;
42+ }
43+ tracing:: warn!( "LLM complexity detection failed, falling back to heuristic" ) ;
44+ }
45+ self . detect_heuristic ( query)
7146 }
7247
73- /// Detect the complexity of a query .
74- pub fn detect ( & self , query : & str ) -> QueryComplexity {
48+ /// Heuristic-based fallback: keyword matching + word count .
49+ fn detect_heuristic ( & self , query : & str ) -> QueryComplexity {
7550 let query_lower = query. to_lowercase ( ) ;
76- let word_count = query. split_whitespace ( ) . count ( ) ;
51+ let word_count = estimate_word_count ( query) ;
52+
53+ // Complex indicators (English + Chinese)
54+ let complex_indicators = [
55+ "compare" ,
56+ "contrast" ,
57+ "analyze" ,
58+ "evaluate" ,
59+ "synthesize" ,
60+ "explain why" ,
61+ "how does" ,
62+ "relationship between" ,
63+ "cause and effect" ,
64+ "对比" ,
65+ "分析" ,
66+ "评估" ,
67+ "综合" ,
68+ "为什么" ,
69+ "原因" ,
70+ "关系" ,
71+ "影响" ,
72+ "区别" ,
73+ "异同" ,
74+ ] ;
7775
78- // Check for complex indicators
79- for indicator in & self . config . complex_indicators {
76+ for indicator in & complex_indicators {
8077 if query_lower. contains ( indicator) {
8178 return QueryComplexity :: Complex ;
8279 }
8380 }
8481
85- // Check for simple indicators
86- for indicator in & self . config . simple_indicators {
87- if query_lower. contains ( indicator) {
88- // Simple indicator found, but check word count
89- if word_count <= self . config . medium_max_words {
90- return QueryComplexity :: Simple ;
91- }
82+ // Simple indicators
83+ let simple_indicators = [
84+ "what is" ,
85+ "define" ,
86+ "list" ,
87+ "who" ,
88+ "when" ,
89+ "where" ,
90+ "什么是" ,
91+ "定义" ,
92+ "列表" ,
93+ "谁" ,
94+ "何时" ,
95+ "哪里" ,
96+ "在哪" ,
97+ ] ;
98+
99+ for indicator in & simple_indicators {
100+ if query_lower. contains ( indicator) && word_count <= 15 {
101+ return QueryComplexity :: Simple ;
92102 }
93103 }
94104
95- // Check for multiple questions
96- let question_marks = query. matches ( '?' ) . count ( ) ;
105+ // Multiple questions
106+ let question_marks = query. matches ( '?' ) . count ( ) + query . matches ( '?' ) . count ( ) ;
97107 if question_marks > 1 {
98108 return QueryComplexity :: Complex ;
99109 }
100110
101- // Check for conjunctions suggesting multiple parts
102- let conjunctions = [ "and" , "or" , "but" , "however" , "although" ] ;
103- let conjunction_count = conjunctions
104- . iter ( )
105- . filter ( |c| query_lower. split_whitespace ( ) . any ( |w| w == * * c) )
106- . count ( ) ;
107-
108- if conjunction_count >= 2 {
109- return QueryComplexity :: Complex ;
110- }
111-
112- // Check for nested concepts
113- let depth_indicators = [ "in the context of" , "with respect to" , "regarding" , "about" ] ;
114- for indicator in depth_indicators {
115- if query_lower. contains ( indicator) {
116- return QueryComplexity :: Medium ;
117- }
118- }
119-
120- // Word count based classification
121- if word_count <= self . config . simple_max_words {
111+ // Word count classification
112+ if word_count <= 5 {
122113 QueryComplexity :: Simple
123- } else if word_count <= self . config . medium_max_words {
114+ } else if word_count <= 15 {
124115 QueryComplexity :: Medium
125116 } else {
126117 QueryComplexity :: Complex
127118 }
128119 }
129120
130121 /// Get complexity score (0.0 - 1.0).
131- pub fn complexity_score ( & self , query : & str ) -> f32 {
132- match self . detect ( query ) {
122+ pub fn complexity_score ( & self , complexity : QueryComplexity ) -> f32 {
123+ match complexity {
133124 QueryComplexity :: Simple => 0.2 ,
134125 QueryComplexity :: Medium => 0.5 ,
135126 QueryComplexity :: Complex => 0.8 ,
136127 }
137128 }
138129
139- /// Analyze query features.
130+ /// Analyze query features (heuristic only, no LLM call) .
140131 pub fn analyze ( & self , query : & str ) -> QueryAnalysis {
141- let query_lower = query. to_lowercase ( ) ;
142132 let words: Vec < & str > = query. split_whitespace ( ) . collect ( ) ;
143133 let unique_words: HashSet < & str > = words. iter ( ) . copied ( ) . collect ( ) ;
144134
@@ -149,10 +139,10 @@ impl ComplexityDetector {
149139 } else {
150140 unique_words. len ( ) as f32 / words. len ( ) as f32
151141 } ,
152- has_question_mark : query. contains ( '?' ) ,
153- question_count : query. matches ( '?' ) . count ( ) ,
154- complexity : self . detect ( query) ,
155- complexity_score : self . complexity_score ( query) ,
142+ has_question_mark : query. contains ( '?' ) || query . contains ( '?' ) ,
143+ question_count : query. matches ( '?' ) . count ( ) + query . matches ( '?' ) . count ( ) ,
144+ complexity : self . detect_heuristic ( query) ,
145+ complexity_score : self . complexity_score ( self . detect_heuristic ( query) ) ,
156146 }
157147 }
158148}
@@ -163,6 +153,52 @@ impl Default for ComplexityDetector {
163153 }
164154}
165155
156+ /// Estimate word count, handling both CJK and Latin text.
157+ fn estimate_word_count ( text : & str ) -> usize {
158+ let mut count = 0usize ;
159+ let mut in_latin_word = false ;
160+
161+ for ch in text. chars ( ) {
162+ if ch. is_whitespace ( ) {
163+ if in_latin_word {
164+ count += 1 ;
165+ in_latin_word = false ;
166+ }
167+ } else if ch. is_ascii_alphanumeric ( ) {
168+ in_latin_word = true ;
169+ } else if is_cjk_char ( ch) {
170+ if in_latin_word {
171+ count += 1 ;
172+ in_latin_word = false ;
173+ }
174+ count += 1 ;
175+ } else {
176+ if in_latin_word {
177+ count += 1 ;
178+ in_latin_word = false ;
179+ }
180+ }
181+ }
182+ if in_latin_word {
183+ count += 1 ;
184+ }
185+ count
186+ }
187+
188+ /// Check if a character is CJK (Chinese/Japanese/Korean).
189+ fn is_cjk_char ( ch : char ) -> bool {
190+ let cp = ch as u32 ;
191+ ( 0x4E00 ..=0x9FFF ) . contains ( & cp)
192+ || ( 0x3400 ..=0x4DBF ) . contains ( & cp)
193+ || ( 0x20000 ..=0x2A6DF ) . contains ( & cp)
194+ || ( 0x2A700 ..=0x2B73F ) . contains ( & cp)
195+ || ( 0xF900 ..=0xFAFF ) . contains ( & cp)
196+ || ( 0x2F800 ..=0x2FA1F ) . contains ( & cp)
197+ || ( 0x3000 ..=0x303F ) . contains ( & cp)
198+ || ( 0x3040 ..=0x309F ) . contains ( & cp)
199+ || ( 0x30A0 ..=0x30FF ) . contains ( & cp)
200+ }
201+
166202/// Analysis result for a query.
167203#[ derive( Debug , Clone ) ]
168204pub struct QueryAnalysis {
@@ -188,21 +224,40 @@ mod tests {
188224 fn test_simple_queries ( ) {
189225 let detector = ComplexityDetector :: new ( ) ;
190226
191- assert_eq ! ( detector. detect( "What is Rust?" ) , QueryComplexity :: Simple ) ;
192- assert_eq ! ( detector. detect( "Define async" ) , QueryComplexity :: Simple ) ;
193- assert_eq ! ( detector. detect( "List features" ) , QueryComplexity :: Simple ) ;
227+ assert_eq ! (
228+ detector. detect_heuristic( "What is Rust?" ) ,
229+ QueryComplexity :: Simple
230+ ) ;
231+ assert_eq ! (
232+ detector. detect_heuristic( "Define async" ) ,
233+ QueryComplexity :: Simple
234+ ) ;
235+ assert_eq ! (
236+ detector. detect_heuristic( "什么是向量检索" ) ,
237+ QueryComplexity :: Simple
238+ ) ;
194239 }
195240
196241 #[ test]
197242 fn test_complex_queries ( ) {
198243 let detector = ComplexityDetector :: new ( ) ;
199244
200245 assert_eq ! (
201- detector. detect( "Compare and contrast the different approaches to async programming" ) ,
246+ detector. detect_heuristic(
247+ "Compare and contrast the different approaches to async programming"
248+ ) ,
249+ QueryComplexity :: Complex
250+ ) ;
251+ assert_eq ! (
252+ detector. detect_heuristic( "What is the relationship between ownership and borrowing?" ) ,
253+ QueryComplexity :: Complex
254+ ) ;
255+ assert_eq ! (
256+ detector. detect_heuristic( "对比A和B的区别" ) ,
202257 QueryComplexity :: Complex
203258 ) ;
204259 assert_eq ! (
205- detector. detect ( "What is the relationship between ownership and borrowing? ") ,
260+ detector. detect_heuristic ( "分析索引和检索的关系 ") ,
206261 QueryComplexity :: Complex
207262 ) ;
208263 }
@@ -211,8 +266,20 @@ mod tests {
211266 fn test_medium_queries ( ) {
212267 let detector = ComplexityDetector :: new ( ) ;
213268
214- // Medium length without complex indicators
215269 let medium_query = "How do I implement a simple web server with error handling?" ;
216- assert_eq ! ( detector. detect( medium_query) , QueryComplexity :: Medium ) ;
270+ assert_eq ! ( detector. detect_heuristic( medium_query) , QueryComplexity :: Medium ) ;
271+ }
272+
273+ #[ test]
274+ fn test_estimate_word_count ( ) {
275+ assert_eq ! ( estimate_word_count( "hello world" ) , 2 ) ;
276+ assert_eq ! ( estimate_word_count( "什么是向量" ) , 4 ) ;
277+ assert_eq ! ( estimate_word_count( "什么是 vector search" ) , 4 ) ;
278+ }
279+
280+ #[ test]
281+ fn test_no_llm_is_ok ( ) {
282+ let detector = ComplexityDetector :: new ( ) ;
283+ assert ! ( detector. llm_client. is_none( ) ) ;
217284 }
218285}
0 commit comments