
Commit 385f965

aoiasd committed
enhance: support use lindera tag filter (milvus-io#40416)
relate: milvus-io#39659
Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
1 parent ef00957 commit 385f965

File tree

1 file changed (+253, -35 lines)

internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/lindera_tokenizer.rs

Lines changed: 253 additions & 35 deletions
@@ -1,22 +1,31 @@
-
 use core::result::Result::Err;
-use log::warn;
+use std::collections::HashSet;
 
+use lindera::dictionary::{load_dictionary_from_kind, DictionaryKind};
 use lindera::mode::Mode;
 use lindera::segmenter::Segmenter;
 use lindera::token::Token as LToken;
-use lindera::tokenizer::{Tokenizer as LTokenizer, TokenizerBuilder};
-use lindera::dictionary::{load_dictionary_from_kind, DictionaryKind};
-use tantivy::tokenizer::{Token, Tokenizer, TokenStream};
+use lindera::tokenizer::Tokenizer as LTokenizer;
+use tantivy::tokenizer::{Token, TokenStream, Tokenizer};
 
+use lindera::token_filter::japanese_compound_word::JapaneseCompoundWordTokenFilter;
+use lindera::token_filter::japanese_keep_tags::JapaneseKeepTagsTokenFilter;
+use lindera::token_filter::japanese_stop_tags::JapaneseStopTagsTokenFilter;
+use lindera::token_filter::korean_keep_tags::KoreanKeepTagsTokenFilter;
+use lindera::token_filter::korean_stop_tags::KoreanStopTagsTokenFilter;
+use lindera::token_filter::BoxTokenFilter as LTokenFilter;
+
+use crate::error::{Result, TantivyBindingError};
 use serde_json as json;
-use crate::error::{Result,TantivyBindingError};
 
 pub struct LinderaTokenStream<'a> {
     pub tokens: Vec<LToken<'a>>,
     pub token: &'a mut Token,
 }
 
+const DICTKINDKEY: &str = "dict_kind";
+const FILTERKEY: &str = "filter";
+
 impl<'a> TokenStream for LinderaTokenStream<'a> {
     fn advance(&mut self) -> bool {
         if self.tokens.is_empty() {
@@ -49,17 +58,25 @@ pub struct LinderaTokenizer {
 
 impl LinderaTokenizer {
     /// Create a new `LinderaTokenizer`.
-    /// This function will create a new `LinderaTokenizer` with settings from the YAML file specified in the `LINDERA_CONFIG_PATH` environment variable.
+    /// This function will create a new `LinderaTokenizer` with json parameters.
     pub fn from_json(params: &json::Map<String, json::Value>) -> Result<LinderaTokenizer> {
         let kind = fetch_lindera_kind(params)?;
-        let dictionary = load_dictionary_from_kind(kind);
-        if dictionary.is_err(){
-            return Err(TantivyBindingError::InvalidArgument(format!(
+        let dictionary = load_dictionary_from_kind(kind.clone()).map_err(|_| {
+            TantivyBindingError::InvalidArgument(format!(
                 "lindera tokenizer with invalid dict_kind"
-            )));
+            ))
+        })?;
+
+        let segmenter = Segmenter::new(Mode::Normal, dictionary, None);
+        let mut tokenizer = LinderaTokenizer::from_segmenter(segmenter);
+
+        // append lindera filter
+        let filters = fetch_lindera_token_filters(&kind, params)?;
+        for filter in filters {
+            tokenizer.append_token_filter(filter)
         }
-        let segmenter = Segmenter::new(Mode::Normal, dictionary.unwrap(), None);
-        Ok(LinderaTokenizer::from_segmenter(segmenter))
+
+        Ok(tokenizer)
     }
 
     /// Create a new `LinderaTokenizer`.
@@ -70,6 +87,10 @@ impl LinderaTokenizer {
             token: Default::default(),
         }
     }
+
+    pub fn append_token_filter(&mut self, filter: LTokenFilter) {
+        self.tokenizer.append_token_filter(filter);
+    }
 }
 
 impl Tokenizer for LinderaTokenizer {
@@ -88,9 +109,9 @@ trait DictionaryKindParser {
     fn into_dict_kind(self) -> Result<DictionaryKind>;
 }
 
-impl DictionaryKindParser for &str{
+impl DictionaryKindParser for &str {
     fn into_dict_kind(self) -> Result<DictionaryKind> {
-        match self{
+        match self {
             "ipadic" => Ok(DictionaryKind::IPADIC),
             "ipadic-neologd" => Ok(DictionaryKind::IPADICNEologd),
             "unidic" => Ok(DictionaryKind::UniDic),
@@ -99,59 +120,256 @@ impl DictionaryKindParser for &str{
             other => Err(TantivyBindingError::InvalidArgument(format!(
                 "unsupported lindera dict type: {}",
                 other
-            )))
+            ))),
         }
     }
 }
 
-fn fetch_lindera_kind(params:&json::Map<String, json::Value>) -> Result<DictionaryKind>{
-    match params.get("dict_kind"){
-        Some(val) => {
-            if !val.is_string(){
-                return Err(TantivyBindingError::InvalidArgument(format!(
-                    "lindera tokenizer dict kind should be string"
-                )))
+fn fetch_lindera_kind(params: &json::Map<String, json::Value>) -> Result<DictionaryKind> {
+    params
+        .get(DICTKINDKEY)
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!("lindera tokenizer dict_kind must be set"))
+        })?
+        .as_str()
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!(
+                "lindera tokenizer dict kind should be string"
+            ))
+        })?
+        .into_dict_kind()
+}
+
+fn fetch_lindera_tags_from_params(
+    params: &json::Map<String, json::Value>,
+) -> Result<HashSet<String>> {
+    params
+        .get("tags")
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!(
+                "lindera japanese stop tag filter tags must be set"
+            ))
+        })?
+        .as_array()
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!(
+                "lindera japanese stop tags filter tags must be array"
+            ))
+        })?
+        .iter()
+        .map(|v| {
+            v.as_str()
+                .ok_or_else(|| {
+                    TantivyBindingError::InvalidArgument(format!(
+                        "lindera japanese stop tags filter tags must be string"
+                    ))
+                })
+                .map(|s| s.to_string())
+        })
+        .collect::<Result<HashSet<String>>>()
+}
+
+fn fetch_japanese_compound_word_token_filter(
+    kind: &DictionaryKind,
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<LTokenFilter> {
+    let filter_param = params.ok_or_else(|| {
+        TantivyBindingError::InvalidArgument(format!(
+            "lindera japanese compound word filter must use with params"
+        ))
+    })?;
+
+    let tags: HashSet<String> = fetch_lindera_tags_from_params(filter_param)?;
+
+    let new_tag: Option<String> = filter_param
+        .get("new_tag")
+        .map(|v| {
+            v.as_str()
+                .ok_or_else(|| {
+                    TantivyBindingError::InvalidArgument(format!(
+                        "lindera japanese compound word filter new_tag must be string"
+                    ))
+                })
+                .map(|s| s.to_string())
+        })
+        .transpose()?;
+    Ok(JapaneseCompoundWordTokenFilter::new(kind.clone(), tags, new_tag).into())
+}
+
+fn fetch_japanese_keep_tags_token_filter(
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<LTokenFilter> {
+    Ok(
+        JapaneseKeepTagsTokenFilter::new(fetch_lindera_tags_from_params(params.ok_or_else(
+            || {
+                TantivyBindingError::InvalidArgument(format!(
+                    "lindera japanese keep tags filter must use with params"
+                ))
+            },
+        )?)?)
+        .into(),
+    )
+}
+
+fn fetch_japanese_stop_tags_token_filter(
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<LTokenFilter> {
+    Ok(
+        JapaneseStopTagsTokenFilter::new(fetch_lindera_tags_from_params(params.ok_or_else(
+            || {
+                TantivyBindingError::InvalidArgument(format!(
+                    "lindera japanese stop tags filter must use with params"
+                ))
+            },
+        )?)?)
+        .into(),
+    )
+}
+
+fn fetch_korean_keep_tags_token_filter(
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<LTokenFilter> {
+    Ok(
+        KoreanKeepTagsTokenFilter::new(fetch_lindera_tags_from_params(params.ok_or_else(
+            || {
+                TantivyBindingError::InvalidArgument(format!(
+                    "lindera korean keep tags filter must use with params"
+                ))
+            },
+        )?)?)
+        .into(),
+    )
+}
+
+fn fetch_korean_stop_tags_token_filter(
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<LTokenFilter> {
+    Ok(
+        KoreanStopTagsTokenFilter::new(fetch_lindera_tags_from_params(params.ok_or_else(
+            || {
+                TantivyBindingError::InvalidArgument(format!(
+                    "lindera korean stop tags filter must use with params"
+                ))
+            },
+        )?)?)
+        .into(),
+    )
+}
+
+fn fetch_lindera_token_filter_params(
+    params: &json::Value,
+) -> Result<(&str, Option<&json::Map<String, json::Value>>)> {
+    if params.is_string() {
+        return Ok((params.as_str().unwrap(), None));
+    }
+
+    let kind = params
+        .as_object()
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!(
+                "lindera tokenizer filter params must be object"
+            ))
+        })?
+        .get("kind")
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!("lindera tokenizer filter must have type"))
+        })?
+        .as_str()
+        .ok_or_else(|| {
+            TantivyBindingError::InvalidArgument(format!(
+                "lindera tokenizer filter type should be string"
+            ))
+        })?;
+
+    Ok((kind, Some(params.as_object().unwrap())))
+}
+
+fn fetch_lindera_token_filter(
+    type_name: &str,
+    kind: &DictionaryKind,
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<LTokenFilter> {
+    match type_name {
+        "japanese_compound_word" => fetch_japanese_compound_word_token_filter(kind, params),
+        "japanese_keep_tags" => fetch_japanese_keep_tags_token_filter(params),
+        "japanese_stop_tags" => fetch_japanese_stop_tags_token_filter(params),
+        "korean_keep_tags" => fetch_korean_keep_tags_token_filter(params),
+        "korean_stop_tags" => fetch_korean_stop_tags_token_filter(params),
+        _ => Err(TantivyBindingError::InvalidArgument(format!(
+            "unknown lindera filter type"
+        ))),
+    }
+}
+
+fn fetch_lindera_token_filters(
+    kind: &DictionaryKind,
+    params: &json::Map<String, json::Value>,
+) -> Result<Vec<LTokenFilter>> {
+    let mut result: Vec<LTokenFilter> = vec![];
+
+    match params.get(FILTERKEY) {
+        Some(v) => {
+            let filter_list = v.as_array().ok_or_else(|| {
+                TantivyBindingError::InvalidArgument(format!("lindera filters should be array"))
+            })?;
+
+            for filter_params in filter_list {
+                let (name, params) = fetch_lindera_token_filter_params(filter_params)?;
+                let filter = fetch_lindera_token_filter(name, kind, params)?;
+                result.push(filter);
             }
-            val.as_str().unwrap().into_dict_kind()
-        },
-        _ => {
-            return Err(TantivyBindingError::InvalidArgument(format!(
-                "lindera tokenizer dict_kind must be set"
-            )))
         }
+        _ => {}
     }
+
+    Ok(result)
 }
 
 #[cfg(test)]
 mod tests {
     use serde_json as json;
+    use tantivy::tokenizer::Tokenizer;
 
     use crate::analyzer::tokenizers::lindera_tokenizer::LinderaTokenizer;
 
     #[test]
-    fn test_lindera_tokenizer(){
+    fn test_lindera_tokenizer() {
         let params = r#"{
             "type": "lindera",
-            "dict_kind": "ipadic"
+            "dict_kind": "ipadic",
+            "filter": [{
+                "kind": "japanese_stop_tags",
+                "tags": ["接続詞", "助詞", "助詞,格助詞", "助詞,連体化"]
+            }]
         }"#;
         let json_param = json::from_str::<json::Map<String, json::Value>>(&params);
         assert!(json_param.is_ok());
-
+
         let tokenizer = LinderaTokenizer::from_json(&json_param.unwrap());
         assert!(tokenizer.is_ok(), "error: {}", tokenizer.err().unwrap());
+
+        let mut binding = tokenizer.unwrap();
+        let stream =
+            binding.token_stream("東京スカイツリーの最寄り駅はとうきょうスカイツリー駅です");
+        let mut results = Vec::<String>::new();
+        for token in stream.tokens {
+            results.push(token.text.to_string());
+        }
+
+        print!("test tokens :{:?}\n", results)
     }
 
     #[test]
     #[cfg(feature = "lindera-cc-cedict")]
-    fn test_lindera_tokenizer_cc(){
+    fn test_lindera_tokenizer_cc() {
         let params = r#"{
             "type": "lindera",
             "dict_kind": "cc-cedict"
         }"#;
         let json_param = json::from_str::<json::Map<String, json::Value>>(&params);
         assert!(json_param.is_ok());
-
+
         let tokenizer = LinderaTokenizer::from_json(&json_param.unwrap());
         assert!(tokenizer.is_ok(), "error: {}", tokenizer.err().unwrap());
     }
-}
+}
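
For reference, a minimal sketch of driving the new filter support through LinderaTokenizer::from_json, mirroring the test above but with a japanese_keep_tags filter instead of japanese_stop_tags; the helper name, the reused sample sentence, and the POS tag values are illustrative, not part of this commit:

    use serde_json as json;
    use tantivy::tokenizer::Tokenizer;

    use crate::analyzer::tokenizers::lindera_tokenizer::LinderaTokenizer;

    // Hypothetical helper: build an ipadic tokenizer that keeps only tokens whose
    // POS tags match the "filter" entry, then return the surviving token texts.
    fn keep_tags_example() -> crate::error::Result<Vec<String>> {
        let params = r#"{
            "type": "lindera",
            "dict_kind": "ipadic",
            "filter": [{
                "kind": "japanese_keep_tags",
                "tags": ["名詞", "動詞"]
            }]
        }"#;
        let map = json::from_str::<json::Map<String, json::Value>>(params)
            .expect("analyzer params should be a JSON object");

        // from_json loads the dictionary, then appends each configured token filter.
        let mut tokenizer = LinderaTokenizer::from_json(&map)?;
        let stream = tokenizer.token_stream("東京スカイツリーの最寄り駅はとうきょうスカイツリー駅です");
        Ok(stream.tokens.iter().map(|t| t.text.to_string()).collect())
    }

Note that the bare-string form accepted by fetch_lindera_token_filter_params carries no parameters, and every filter kind wired up in fetch_lindera_token_filter requires at least a tags argument, so in practice filters are declared as objects like the one above.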
