Use Levenshtein distance to score documents in fuzzy term queries #998
@@ -1,12 +1,14 @@
-use crate::common::BitSet;
 use crate::core::SegmentReader;
-use crate::query::ConstScorer;
-use crate::query::{BitSetDocSet, Explanation};
+use crate::query::fuzzy_query::DFAWrapper;
+use crate::query::score_combiner::SumCombiner;
+use crate::query::Explanation;
+use crate::query::{ConstScorer, Union};
 use crate::query::{Scorer, Weight};
 use crate::schema::{Field, IndexRecordOption};
-use crate::termdict::{TermDictionary, TermStreamer};
+use crate::termdict::{TermDictionary, TermWithStateStreamer};
 use crate::TantivyError;
 use crate::{DocId, Score};
+use std::any::{Any, TypeId};
 use std::io;
 use std::sync::Arc;
 use tantivy_fst::Automaton;
@@ -33,9 +35,9 @@ where
     fn automaton_stream<'a>(
         &'a self,
         term_dict: &'a TermDictionary,
-    ) -> io::Result<TermStreamer<'a, &'a A>> {
+    ) -> io::Result<TermWithStateStreamer<'a, &'a A>> {
         let automaton: &A = &*self.automaton;
-        let term_stream_builder = term_dict.search(automaton);
+        let term_stream_builder = term_dict.search_with_state(automaton);
         term_stream_builder.into_stream()
     }
 }
@@ -46,35 +48,27 @@ where
     A::State: Clone,
 {
     fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
-        let max_doc = reader.max_doc();
-        let mut doc_bitset = BitSet::with_max_value(max_doc);
         let inverted_index = reader.inverted_index(self.field)?;
         let term_dict = inverted_index.terms();
         let mut term_stream = self.automaton_stream(term_dict)?;
-        while term_stream.advance() {
-            let term_info = term_stream.value();
-            let mut block_segment_postings = inverted_index
-                .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
-            loop {
-                let docs = block_segment_postings.docs();
-                if docs.is_empty() {
-                    break;
-                }
-                for &doc in docs {
-                    doc_bitset.insert(doc);
-                }
-                block_segment_postings.advance();
-            }
+
+        let mut scorers = vec![];
+        while let Some((_term, term_info, state)) = term_stream.next() {
+            let score = automaton_score(self.automaton.as_ref(), state);
+            let segment_postings =
+                inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
+            let scorer = ConstScorer::new(segment_postings, boost * score);
+            scorers.push(scorer);
         }
-        let doc_bitset = BitSetDocSet::from(doc_bitset);
-        let const_scorer = ConstScorer::new(doc_bitset, boost);
-        Ok(Box::new(const_scorer))
+
+        let scorer = Union::<_, SumCombiner>::from(scorers);

Will this sum up the score of all …

As I understand it, yes. So if a doc contains …

Maybe I did not understand this method correctly then. I thought it would return a scorer for multiple different documents, and in this case sum the scores of different documents.

+        Ok(Box::new(scorer))
     }

     fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
         let mut scorer = self.scorer(reader, 1.0)?;
         if scorer.seek(doc) == doc {
-            Ok(Explanation::new("AutomatonScorer", 1.0))
+            Ok(Explanation::new("AutomatonScorer", scorer.score()))
         } else {
             Err(TantivyError::InvalidArgument(
                 "Document does not exist".to_string(),
@@ -83,6 +77,25 @@ where
     }
 }

+fn automaton_score<A>(automaton: &A, state: A::State) -> f32
+where
+    A: Automaton + Send + Sync + 'static,
+    A::State: Clone,
+{
+    if TypeId::of::<DFAWrapper>() == automaton.type_id() && TypeId::of::<u32>() == state.type_id() {
+        let dfa = automaton as *const A as *const DFAWrapper;
+        let dfa = unsafe { &*dfa };
+
+        let id = &state as *const A::State as *const u32;
+        let id = unsafe { *id };
+
+        let dist = dfa.0.distance(id).to_u8() as f32;
+        1.0 / (1.0 + dist)

Can we document this score somewhere in the doc of the FuzzyTermScorer? Ideally I think this should be implemented as an array … I'm ok with …

Just to be sure, isn't the max levenshtein distance 2? In this range …

Or does the transposition below multiply it by 2 again?

Ah yes, you are right!

+    } else {
+        1.0
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::AutomatonWeight;
@@ -4,7 +4,7 @@ use super::TermDictionary;
 use crate::postings::TermInfo;
 use crate::termdict::TermOrdinal;
 use tantivy_fst::automaton::AlwaysMatch;
-use tantivy_fst::map::{Stream, StreamBuilder};
+use tantivy_fst::map::{Stream, StreamBuilder, StreamWithState};
 use tantivy_fst::Automaton;
 use tantivy_fst::{IntoStreamer, Streamer};
@@ -149,3 +149,153 @@
         }
     }
 }
+
+/// `TermWithStateStreamerBuilder` is a helper object used to define
+/// a range of terms that should be streamed.
+pub struct TermWithStateStreamerBuilder<'a, A = AlwaysMatch>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    fst_map: &'a TermDictionary,
+    stream_builder: StreamBuilder<'a, A>,
+}

I think it is ok to just have a TermWithStateStreamer and avoid the code duplication. Most of the time the state is small and "Copy".

So is it OK if I remove the extra …

Yes
+impl<'a, A> TermWithStateStreamerBuilder<'a, A>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self {
+        TermWithStateStreamerBuilder {
+            fst_map,
+            stream_builder,
+        }
+    }
+
+    /// Limit the range to terms greater or equal to the bound
+    pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.ge(bound);
+        self
+    }
+
+    /// Limit the range to terms strictly greater than the bound
+    pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.gt(bound);
+        self
+    }
+
+    /// Limit the range to terms lesser or equal to the bound
+    pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.le(bound);
+        self
+    }
+
+    /// Limit the range to terms strictly lesser than the bound
+    pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.lt(bound);
+        self
+    }
+
+    /// Iterate over the range backwards.
+    pub fn backward(mut self) -> Self {
+        self.stream_builder = self.stream_builder.backward();
+        self
+    }
+
+    /// Creates the stream corresponding to the range
+    /// of terms defined using the `TermWithStateStreamerBuilder`.
+    pub fn into_stream(self) -> io::Result<TermWithStateStreamer<'a, A>> {
+        Ok(TermWithStateStreamer {
+            fst_map: self.fst_map,
+            stream: self.stream_builder.with_state().into_stream(),
+            term_ord: 0u64,
+            current_key: Vec::with_capacity(100),
+            current_value: TermInfo::default(),
+            current_state: None,
+        })
+    }
+}
+
+/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment.
+/// Terms are guaranteed to be sorted.
+pub struct TermWithStateStreamer<'a, A = AlwaysMatch>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    fst_map: &'a TermDictionary,
+    stream: StreamWithState<'a, A>,
+    term_ord: TermOrdinal,
+    current_key: Vec<u8>,
+    current_value: TermInfo,
+    current_state: Option<A::State>,
+}
+
+impl<'a, A> TermWithStateStreamer<'a, A>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    /// Advances the stream to the next item.
+    /// Before the first call to `.advance()`, the stream
+    /// is in an uninitialized state.
+    pub fn advance(&mut self) -> bool {
+        if let Some((term, term_ord, state)) = self.stream.next() {
+            self.current_key.clear();
+            self.current_key.extend_from_slice(term);
+            self.term_ord = term_ord;
+            self.current_value = self.fst_map.term_info_from_ord(term_ord);
+            self.current_state = Some(state);
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Returns the `TermOrdinal` of the current term.
+    ///
+    /// May panic if `.advance()` has never
+    /// been called before.
+    pub fn term_ord(&self) -> TermOrdinal {
+        self.term_ord
+    }
+
+    /// Accesses the current key.
+    ///
+    /// `.key()` returns the key that was returned
+    /// by the last call to the `.next()` method.
+    ///
+    /// If the end of the stream has been reached, and `.next()`
+    /// has been called and returned `None`, `.key()` keeps
+    /// the value of the last key encountered.
+    ///
+    /// Before any call to `.next()`, `.key()` returns an empty array.
+    pub fn key(&self) -> &[u8] {
+        &self.current_key
+    }
+
+    /// Accesses the current value.
+    ///
+    /// Calling `.value()` after the end of the stream will return the
+    /// last `.value()` encountered.
+    ///
+    /// Calling `.value()` before the first call to `.advance()` returns
+    /// `TermInfo::default()`.
+    pub fn value(&self) -> &TermInfo {
+        &self.current_value
+    }
+
+    /// Returns the next `(key, value, state)` triplet.
+    #[cfg_attr(feature = "cargo-clippy", allow(clippy::should_implement_trait))]
+    pub fn next(&mut self) -> Option<(&[u8], &TermInfo, A::State)> {
+        if self.advance() {
+            let state = self.current_state.take().unwrap(); // always Some(_) after advance
+            Some((self.key(), self.value(), state))
+        } else {
+            None
+        }
+    }
+}
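As a usage sketch of the new API (assumptions: a `term_dict: &TermDictionary` and the fuzzy query's `dfa_wrapper: DFAWrapper` are in scope; `TermInfo` is shown via `Debug`), iteration would look roughly like this:

```rust
// Hedged sketch: iterate all terms accepted by the Levenshtein DFA,
// recovering the automaton state (and hence the edit distance) per term.
let mut stream = term_dict
    .search_with_state(&dfa_wrapper)
    .into_stream()
    .expect("open term stream");
while let Some((term_bytes, term_info, state)) = stream.next() {
    // For DFAWrapper the state is a u32 id; the wrapped DFA maps it
    // back to a distance, exactly as automaton_score does above.
    let dist = dfa_wrapper.0.distance(state).to_u8();
    println!("term={:?} dist={} info={:?}", term_bytes, dist, term_info);
}
```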
This is impractical. It may end up being a list of way too many scorers. You should probably use a TAAT (term-at-a-time) strategy here.

Let's say you know how to compute the score associated with one term match, as a function of its doc_freq, term_freq and, most importantly, the distance (note that the doc_freq alone is not necessarily the best indication... If I am at levenshtein distance 1 of two words, the one with the highest document frequency is probably the one that I was shooting for).

You probably want the score for the fuzzy term query to be the max of the term scores of all of the terms that were found in the document, not the sum. So you could keep the bitset that was there before. In addition, you could allocate a Vec<Score> with a len equal to max_doc and initialize it to 0. As you go through the terms, you then update the score of the docs that match:

score[doc] = max(score[doc], term_score(lev_distance, term_frequency, doc_frequency, ...))

Once the bitset and the vector of scores have been populated, you can return a Scorer that iterates through the bitset and returns the computed scores. You can for instance implement a wrapper around the BitSetDocSet... (see the sketch below)
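A rough, self-contained sketch of that TAAT strategy, with postings simplified to plain doc-id lists and all names invented for illustration:

```rust
type DocId = u32;
type Score = f32;

// Rough sketch of the TAAT (term-at-a-time) strategy described above.
// One pass over the matching terms fills a set of matching docs and a
// per-doc score array, max-combining the per-term scores.
fn taat_scores(max_doc: DocId, term_matches: &[(Vec<DocId>, Score)]) -> (Vec<bool>, Vec<Score>) {
    // doc_id < max_doc is guaranteed, so dense arrays indexed by doc work.
    let mut matches = vec![false; max_doc as usize]; // stand-in for the BitSet
    let mut scores = vec![0f32; max_doc as usize];
    for (docs, term_score) in term_matches {
        for &doc in docs {
            matches[doc as usize] = true;
            // Max-combine: the best-matching term determines the doc's score.
            scores[doc as usize] = scores[doc as usize].max(*term_score);
        }
    }
    (matches, scores)
}
```

A scorer can then walk the bitset and look each doc's score up in the array, e.g. through a wrapper around BitSetDocSet.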
The AutomatonWeight is used by more than the fuzzy term query. For the regex query for instance, this extra vec with scores is too big an overhead. Please find a solution to avoid the extra cost.
My understanding was that the doc_id were not necessarily between 0 and max_doc, so I didn't consider using a score vector like scores[doc as usize]. Was I wrong, or should it be a HashMap instead of a Vec?

As for the scoring formula, I see how to get doc_freq from term_info for the current segment, but I still need to sum those across all segments, right? And as we only have access to a specific SegmentReader here, how do we get the others without a Searcher?

Moreover, to get the term_freq I still need the SegmentPostings, so I don't see how to reconcile that with the previous Bitset/BitSetDocSet approach. All this made me use the simple 1 / (1 + dist) scoring instead.

Finally, to avoid impacting regex queries I think we can get back to the previous implementation for the generic case, and just have a specialized impl Weight for AutomatonWeight<DFAWrapper> as @maufl suggested.
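For reference, that kind of specialization can be approximated on stable Rust with the same TypeId test the PR already uses in automaton_score. A minimal, safe illustration of the pattern (all names invented):

```rust
use std::any::Any;

// Sketch of the runtime-specialization pattern: test the concrete type at
// runtime, then take the specialized path, falling back to the generic one.
// This uses the safe Any downcast rather than the PR's raw-pointer cast.
fn describe<T: Any>(value: &T) -> &'static str {
    let value_any = value as &dyn Any;
    if value_any.downcast_ref::<u32>().is_some() {
        "a u32 (e.g. a Levenshtein DFA state id)"
    } else {
        "some other type, handled generically"
    }
}

fn main() {
    assert_eq!(describe(&7u32), "a u32 (e.g. a Levenshtein DFA state id)");
    assert_eq!(describe(&"regex"), "some other type, handled generically");
}
```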
You are guaranteed that doc_id < max_doc.
The scoring part seems very difficult. I commented on the issue. I'd like us to take a step back and think longer about what spec we want.
Here are my two stabs at the score problem.

I think it might be useful to approach the question of "how do you adjust scores in fuzzy searching" by asking "how likely was it that the initial query term was really an incorrect version of another term". If you knew the probability of an error you could weight the score of each fuzzy term: score_adj = score * P(T|Q), where P(T|Q) is "the probability that term T is being searched for given query Q". Just multiplying by P(T|Q) has issues (floating point error, a few high-probability events tend to dominate, etc.), so -log(P(T_i)), i.e. I(T_i) (the Shannon information), should be used instead. This leaves us with score_adjusted = score * (1 / I(T_i)) (I'm not entirely sure this is correct). This leaves the question of how you get P(T_i) or I(T_i). Long story short, I(T_i) can be approximated by the levenshtein distance. This leaves us at score_adj = score * (1 / (α + dist)), where α is I(Q) (the information given that an initial query with levenshtein distance = 0 is correct). It just so happens that score_adj = score * 1 / (1 + dist) is valid here! α = 1 implies that the probability that a lev distance of 0 is correct is 10%; α = 0.04575749056 would correspond to a 90% probability.

Another approach would be to try and fix the scoring function itself. BM25 uses inverse document frequency to compensate for words that rarely show up in documents. One interpretation is that the IDF function is a stand-in for the amount of information a document containing a term conveys. As explained above, lev distance can represent the information that a search being incorrect conveys. To incorporate lev distance into BM25, it might be useful to investigate the effect of subtracting the lev distance from the IDF:

score(D, Q) = Σ_{i=1}^{n} (IDF(q_i) - LevDist(q_i)) * (f(q_i, D) * (k1 + 1)) / (f(q_i, D) + k1 * (1 - b + b * |D| / avgdl))

This would leave the score of dist = 0 terms unchanged and would lower the scores of results with a higher lev dist.
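A quick numeric check of the first proposal, assuming base-10 logarithms as the 10%/90% figures above imply (`score_adj` is a hypothetical helper; exploration only, not tantivy code):

```rust
// Numeric sketch of score_adj = score / (alpha + dist).
fn score_adj(score: f32, dist: u8, alpha: f32) -> f32 {
    score / (alpha + dist as f32)
}

fn main() {
    // alpha = 1.0 corresponds to a 10% prior that the exact query is correct:
    // dist=0 -> 1.000, dist=1 -> 0.500, dist=2 -> 0.333.
    for dist in 0..=2u8 {
        println!("alpha=1.0 dist={} -> {:.3}", dist, score_adj(1.0, dist, 1.0));
    }
    // alpha = -log10(0.9) ≈ 0.04576 corresponds to a 90% prior.
    let alpha90 = -(0.9f32.log10());
    println!("alpha for 90% prior: {:.5}", alpha90);
}
```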
This would leave the score of dist=0 terms unchanged and would lower scores of results with higher Lev dist.