Use Levenshtein distance to score documents in fuzzy term queries #998

Closed · wants to merge 2 commits
65 changes: 39 additions & 26 deletions src/query/automaton_weight.rs
@@ -1,12 +1,14 @@
use crate::common::BitSet;
use crate::core::SegmentReader;
use crate::query::ConstScorer;
use crate::query::{BitSetDocSet, Explanation};
use crate::query::fuzzy_query::DFAWrapper;
use crate::query::score_combiner::SumCombiner;
use crate::query::Explanation;
use crate::query::{ConstScorer, Union};
use crate::query::{Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::termdict::{TermDictionary, TermWithStateStreamer};
use crate::TantivyError;
use crate::{DocId, Score};
use std::any::{Any, TypeId};
use std::io;
use std::sync::Arc;
use tantivy_fst::Automaton;
@@ -33,9 +35,9 @@ where
fn automaton_stream<'a>(
&'a self,
term_dict: &'a TermDictionary,
) -> io::Result<TermStreamer<'a, &'a A>> {
) -> io::Result<TermWithStateStreamer<'a, &'a A>> {
let automaton: &A = &*self.automaton;
let term_stream_builder = term_dict.search(automaton);
let term_stream_builder = term_dict.search_with_state(automaton);
term_stream_builder.into_stream()
}
}
@@ -46,35 +48,27 @@ where
A::State: Clone,
{
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field)?;
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict)?;
while term_stream.advance() {
let term_info = term_stream.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
for &doc in docs {
doc_bitset.insert(doc);
}
block_segment_postings.advance();
}

let mut scorers = vec![];
while let Some((_term, term_info, state)) = term_stream.next() {
let score = automaton_score(self.automaton.as_ref(), state);
let segment_postings =
inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
let scorer = ConstScorer::new(segment_postings, boost * score);
Collaborator:

This is impractical: it may end up being a list of way too many scorers.

You should probably use a TAAT (term-at-a-time) strategy here.
Let's say you know how to compute the score associated with one term match, as a function of its doc_freq, term_freq and, most importantly, the distance (note that the doc_freq alone is not necessarily the best indication...
If I am at Levenshtein distance 1 of two words, the one with the highest document frequency is probably the one that I was shooting for).

You probably want the score for the fuzzy term query to be the max of the term scores of all of the terms that were found in the document, not the sum.

So you could keep the bitset that was there before.
In addition to it, you could open a Vec<Score> with a len equal to max_doc and initialize it to 0.

As you go through the terms, you then update the score of the docs that match:

score[doc] = max(score[doc], termscore(lev_distance, term_frequency, doc_frequency, ...))

Once the bitset and the vector of scores have been populated, you can return a Scorer that iterates through the bitset and returns the precomputed scores. You could for instance implement a wrapper around BitSetDocSet:

struct BitSetDocSetWithPrecomputedScore {
   underlying_bitset: BitSetDocSet,
   scores: Vec<Score>,
}

impl DocSet for BitSetDocSetWithPrecomputedScore {
    // ...
}

impl Scorer for BitSetDocSetWithPrecomputedScore {
   fn score(&self) -> Score {
     let doc = self.doc();
     self.scores[doc as usize]
   }
}
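The reviewer's sketch above can be fleshed out as a minimal, self-contained illustration of the proposed TAAT accumulation. The trait shapes and names here (`PrecomputedScoreDocSet`, `accumulate_term`) are simplified stand-ins for tantivy's actual `DocSet`/`Scorer` APIs, not the library's real types:

```rust
type DocId = u32;
type Score = f32;
const TERMINATED: DocId = u32::MAX;

/// One TAAT pass: for every doc matched by one term, keep the max
/// term score seen so far (max-combination, as the reviewer suggests).
fn accumulate_term(scores: &mut [Score], docs: &[DocId], term_score: Score) {
    for &doc in docs {
        let s = &mut scores[doc as usize];
        *s = s.max(term_score);
    }
}

/// Cursor over the populated score vector; a stand-in for a
/// `BitSetDocSet` wrapper with precomputed scores. A doc with score
/// 0.0 is treated as unmatched, which is fine here because every
/// term score (e.g. 1 / (1 + dist)) is strictly positive.
struct PrecomputedScoreDocSet {
    scores: Vec<Score>,
    doc: DocId,
}

impl PrecomputedScoreDocSet {
    fn new(scores: Vec<Score>) -> Self {
        let mut set = PrecomputedScoreDocSet { scores, doc: TERMINATED };
        set.doc = set.next_match(0);
        set
    }

    // Find the next doc id >= `from` with a nonzero score.
    fn next_match(&self, from: DocId) -> DocId {
        (from..self.scores.len() as DocId)
            .find(|&d| self.scores[d as usize] > 0.0)
            .unwrap_or(TERMINATED)
    }

    fn doc(&self) -> DocId {
        self.doc
    }

    fn advance(&mut self) -> DocId {
        self.doc = self.next_match(self.doc.saturating_add(1));
        self.doc
    }

    fn score(&self) -> Score {
        self.scores[self.doc as usize]
    }
}
```

This keeps the per-segment cost at one `Vec<Score>` of length max_doc, regardless of how many fuzzy terms matched.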

Collaborator:

The AutomatonWeight is used by more than the fuzzy term query.

For the regex query, for instance, this extra vec with scores is too big an overhead. Please find a solution to avoid the extra cost.

Author:

My understanding was that the doc_ids were not necessarily between 0 and max_doc, so I didn't consider using a score vector like scores[doc as usize]. Was I wrong, or should it be a HashMap instead of a Vec?

As for the scoring formula, I see how to get doc_freq from term_info for the current segment, but I still need to sum those across all segments, right? And as we only have access to a specific SegmentReader here, how do we get the others without a Searcher?

Moreover, to get the term_freq I still need the SegmentPostings, so I don't see how to reconcile that with the previous BitSet/BitSetDocSet approach. All this made me use the simple 1 / (1 + dist) scoring instead.

Finally, to avoid impacting regex queries I think we can go back to the previous implementation for the generic case, and just have a specialized impl Weight for Automaton<DFAWrapper> as @maufl suggested.

Collaborator:

> My understanding was that the doc_ids were not necessarily between 0 and max_doc

You are guaranteed that doc_id < max_doc.

Collaborator:

The scoring part seems very difficult. I commented on the issue.

I'd like us to take a step back and think longer about what spec we want.

Contributor:

Here are my two stabs at the score problem.

I think it might be useful to approach the question of "how do you adjust scores in fuzzy searching" by asking "how likely was it that the initial query term was really an incorrect version of another term". If you knew the probability of an error, you could weight the score of each fuzzy term as score_adj = score * P(T|Q), where P(T|Q) is "the probability that term T is being searched for given query Q". Just multiplying by P(T|Q) has issues (floating point error, a few high-probability events tend to dominate, etc.), so -log(P(T_i)), i.e. I(T_i) (the Shannon information), should be used instead. This leaves us with score_adj = score * (1 / I(T_i)) (I'm not entirely sure this is correct). That leaves the question of how you get P(T_i) or I(T_i). Long story short, I(T_i) can be approximated by the Levenshtein distance. This leaves us at score_adj = score * (1 / (α + dist)), where α is I(Q) (the information conveyed given that an initial query with Levenshtein distance 0 is correct). It just so happens that score_adj = score * 1 / (1 + dist) is valid here! α = 1 implies that the probability of a Levenshtein distance of 0 is 10%; α = 0.04575749056 would correspond to a 90% probability.

Another approach would be to try to fix the scoring function itself. BM25 uses inverse document frequency to compensate for words that rarely show up in documents. One interpretation is that the IDF function is a stand-in for the amount of information a document containing the term conveys. As explained above, the Levenshtein distance can represent the information that a search being incorrect conveys. To incorporate the Levenshtein distance into BM25, it might be useful to investigate the effect of subtracting it from the IDF:

score(D, Q) = Σ_i ( (IDF(q_i) - LevDist(q_i)) * (f(q_i, D) * (k + 1)) / (f(q_i, D) + k * (1 - b + b * |D| / avgdl)) )

This would leave the score of dist = 0 terms unchanged and would lower the scores of results with higher Levenshtein distance.
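The two adjustments proposed above can be put into code to see the numbers they produce. This is only a sketch of the contributor's formulas; α and the BM25 parameters (k, b) below are illustrative choices, not values anyone committed to in this PR:

```rust
/// First proposal: damp the base score by 1 / (α + dist).
/// With alpha = 1.0 this reduces to the PR's 1 / (1 + dist) scoring.
fn info_adjusted(score: f32, dist: u32, alpha: f32) -> f32 {
    score / (alpha + dist as f32)
}

/// Second proposal: subtract the Levenshtein distance from the IDF
/// factor of BM25. Here tf = f(q_i, D), k and b are the usual BM25
/// parameters, and len_ratio = |D| / avgdl.
fn bm25_with_lev(idf: f32, lev_dist: u32, tf: f32, k: f32, b: f32, len_ratio: f32) -> f32 {
    (idf - lev_dist as f32) * (tf * (k + 1.0)) / (tf + k * (1.0 - b + b * len_ratio))
}
```

With idf = 2.0, tf = 1.0, k = 1.2, b = 0.75 and an average-length document (len_ratio = 1.0), a distance of 0 scores 2.0 while a distance of 1 scores 1.0, matching the intended behavior: exact matches are unchanged, fuzzier matches are demoted.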

scorers.push(scorer);
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
let const_scorer = ConstScorer::new(doc_bitset, boost);
Ok(Box::new(const_scorer))

let scorer = Union::<_, SumCombiner>::from(scorers);
Will this sum up the scores of all ConstScorers and report the sum as the score for each document?

Author:

As I understand it, yes. So if a doc contains japan and japin and we search for japon, then its final score will be 0.5 + 0.5 = 1.0
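The difference being debated here, summing versus taking the max across fuzzy term matches, can be made concrete with a small sketch (the per-term scores below just reproduce the author's japan/japin example; the function names are illustrative, not tantivy APIs):

```rust
/// Combine per-term scores for one document the way a
/// Union with a SumCombiner would: by adding them.
fn combine_sum(term_scores: &[f32]) -> f32 {
    term_scores.iter().sum()
}

/// The max-based alternative the collaborator proposed earlier
/// in this review thread.
fn combine_max(term_scores: &[f32]) -> f32 {
    term_scores.iter().cloned().fold(0.0, f32::max)
}
```

For a document containing both `japan` and `japin` queried with `japon` (each at distance 1, so each term scores 0.5), summing yields 1.0 while max yields 0.5 — which is exactly why a document matching two fuzzy variants can outscore an exact match under the sum.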


Maybe I did not understand this method correctly then. I thought it would return a scorer for multiple different documents, and in this case sum the scores of different documents.

Ok(Box::new(scorer))
}

fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) == doc {
Ok(Explanation::new("AutomatonScorer", 1.0))
Ok(Explanation::new("AutomatonScorer", scorer.score()))
} else {
Err(TantivyError::InvalidArgument(
"Document does not exist".to_string(),
@@ -83,6 +77,25 @@ where
}
}

fn automaton_score<A>(automaton: &A, state: A::State) -> f32
where
A: Automaton + Send + Sync + 'static,
A::State: Clone,
{
if TypeId::of::<DFAWrapper>() == automaton.type_id() && TypeId::of::<u32>() == state.type_id() {
let dfa = automaton as *const A as *const DFAWrapper;
let dfa = unsafe { &*dfa };

let id = &state as *const A::State as *const u32;
let id = unsafe { *id };

let dist = dfa.0.distance(id).to_u8() as f32;
1.0 / (1.0 + dist)
Collaborator:

Can we document this score somewhere in the doc of the FuzzyTermScorer?

Ideally I think this should be implemented as an array [f32; 5] that is passed by the user. (5 is sufficient. We do not handle any levenshtein distance above 4.)

I'm ok with [1f32, 0.5f32, 0.3333f32, ...] as a default.
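The user-supplied table suggested above might look like the following sketch. The function name is hypothetical, and the default values simply continue the 1 / (1 + dist) pattern the PR already uses, matching the collaborator's proposed `[1f32, 0.5f32, 0.3333f32, ...]`:

```rust
/// Hypothetical per-distance score table, indexed by Levenshtein
/// distance. Five entries suffice: the levenshtein automata library
/// does not handle distances above 4.
const DEFAULT_DISTANCE_SCORES: [f32; 5] = [1.0, 0.5, 1.0 / 3.0, 0.25, 0.2];

fn score_for_distance(table: &[f32; 5], dist: u8) -> f32 {
    // Clamp defensively so an out-of-range distance cannot panic.
    table[(dist as usize).min(4)]
}
```

Letting the user pass this array in would also make the scoring policy documentable in one place, which addresses the first half of the comment.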

Author:

Just to be sure, isn't the max Levenshtein distance 2 in this range?

Author:

Or does the transposition below multiply it by 2 again?

Collaborator:

Ah yes you are right!
The limit of 4 is within the levenshtein automata library. I reduced it in tantivy.

} else {
1.0
}
}

#[cfg(test)]
mod tests {
use super::AutomatonWeight;
2 changes: 1 addition & 1 deletion src/query/fuzzy_query.rs
@@ -199,7 +199,7 @@ mod test {
.unwrap();
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
assert_nearly_equals!(0.5, score);
}

// fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
4 changes: 3 additions & 1 deletion src/termdict/fst_termdict/mod.rs
@@ -23,5 +23,7 @@ mod streamer;
mod term_info_store;
mod termdict;

pub use self::streamer::{TermStreamer, TermStreamerBuilder};
pub use self::streamer::{
TermStreamer, TermStreamerBuilder, TermWithStateStreamer, TermWithStateStreamerBuilder,
};
pub use self::termdict::{TermDictionary, TermDictionaryBuilder};
152 changes: 151 additions & 1 deletion src/termdict/fst_termdict/streamer.rs
@@ -4,7 +4,7 @@ use super::TermDictionary;
use crate::postings::TermInfo;
use crate::termdict::TermOrdinal;
use tantivy_fst::automaton::AlwaysMatch;
use tantivy_fst::map::{Stream, StreamBuilder};
use tantivy_fst::map::{Stream, StreamBuilder, StreamWithState};
use tantivy_fst::Automaton;
use tantivy_fst::{IntoStreamer, Streamer};

@@ -149,3 +149,153 @@
}
}
}

/// `TermWithStateStreamerBuilder` is a helper object used to define
/// a range of terms that should be streamed.
pub struct TermWithStateStreamerBuilder<'a, A = AlwaysMatch>
Collaborator @fulmicoton (Mar 25, 2021):

I think it is ok to just have a TermWithStateStreamer and avoid the code duplication.

Most of the time the state is small and "Copy".

Author:

So is it OK if I remove the extra TermWithState* structs I added, and instead add the state to all the TermStreamer* structs, as @maufl also suggested?

Collaborator:
Yes

where
A: Automaton,
A::State: Clone,
{
fst_map: &'a TermDictionary,
stream_builder: StreamBuilder<'a, A>,
}

impl<'a, A> TermWithStateStreamerBuilder<'a, A>
where
A: Automaton,
A::State: Clone,
{
pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self {
TermWithStateStreamerBuilder {
fst_map,
stream_builder,
}
}

/// Limit the range to terms greater or equal to the bound
pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.ge(bound);
self
}

/// Limit the range to terms strictly greater than the bound
pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.gt(bound);
self
}

/// Limit the range to terms lesser or equal to the bound
pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.le(bound);
self
}

/// Limit the range to terms strictly lesser than the bound
pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.lt(bound);
self
}

/// Iterate over the range backwards.
pub fn backward(mut self) -> Self {
self.stream_builder = self.stream_builder.backward();
self
}

/// Creates the stream corresponding to the range
/// of terms defined using the `TermWithStateStreamerBuilder`.
pub fn into_stream(self) -> io::Result<TermWithStateStreamer<'a, A>> {
Ok(TermWithStateStreamer {
fst_map: self.fst_map,
stream: self.stream_builder.with_state().into_stream(),
term_ord: 0u64,
current_key: Vec::with_capacity(100),
current_value: TermInfo::default(),
current_state: None,
})
}
}

/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment.
/// Terms are guaranteed to be sorted.
pub struct TermWithStateStreamer<'a, A = AlwaysMatch>
where
A: Automaton,
A::State: Clone,
{
fst_map: &'a TermDictionary,
stream: StreamWithState<'a, A>,
term_ord: TermOrdinal,
current_key: Vec<u8>,
current_value: TermInfo,
current_state: Option<A::State>,
}

impl<'a, A> TermWithStateStreamer<'a, A>
where
A: Automaton,
A::State: Clone,
{
/// Advances the stream to the next item.
/// Before the first call to `.advance()`, the stream
/// is in an uninitialized state.
pub fn advance(&mut self) -> bool {
if let Some((term, term_ord, state)) = self.stream.next() {
self.current_key.clear();
self.current_key.extend_from_slice(term);
self.term_ord = term_ord;
self.current_value = self.fst_map.term_info_from_ord(term_ord);
self.current_state = Some(state);
true
} else {
false
}
}

/// Returns the `TermOrdinal` of the given term.
///
/// May panic if `.advance()` has never
/// been called before.
pub fn term_ord(&self) -> TermOrdinal {
self.term_ord
}

/// Accesses the current key.
///
/// `.key()` should return the key that was returned
/// by the `.next()` method.
///
/// If the end of the stream has been reached, and `.next()`
/// has been called and returned `None`, `.key()` retains
/// the value of the last key encountered.
///
/// Before any call to `.next()`, `.key()` returns an empty array.
pub fn key(&self) -> &[u8] {
&self.current_key
}

/// Accesses the current value.
///
/// Calling `.value()` after the end of the stream will return the
/// last `.value()` encountered.
///
/// Calling `.value()` before the first call to `.advance()` returns
/// `TermInfo::default()`.
pub fn value(&self) -> &TermInfo {
&self.current_value
}

/// Return the next `(key, value, state)` triplet.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::should_implement_trait))]
pub fn next(&mut self) -> Option<(&[u8], &TermInfo, A::State)> {
if self.advance() {
let state = self.current_state.take().unwrap(); // always Some(_) after advance
Some((self.key(), self.value(), state))
} else {
None
}
}
}
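The `clippy::should_implement_trait` allowance on `next` above exists because a streaming cursor whose `next` hands out data borrowed from the cursor itself cannot implement `Iterator`. A minimal self-contained illustration of that pattern (the `ByteStreamer` type is invented for this sketch, not part of tantivy):

```rust
/// A streaming cursor: `next` returns a slice borrowed from an internal
/// buffer, so two successive items cannot coexist, which `Iterator`
/// would require.
struct ByteStreamer {
    items: Vec<Vec<u8>>,
    pos: usize,
    current: Vec<u8>,
}

impl ByteStreamer {
    fn new(items: Vec<Vec<u8>>) -> Self {
        ByteStreamer { items, pos: 0, current: Vec::new() }
    }

    // Clippy would suggest implementing Iterator here; the borrow of
    // `self` in the return type makes that impossible for this shape.
    #[allow(clippy::should_implement_trait)]
    fn next(&mut self) -> Option<&[u8]> {
        if self.pos < self.items.len() {
            // Reuse one buffer across calls, like TermWithStateStreamer
            // reuses `current_key`.
            self.current.clear();
            self.current.extend_from_slice(&self.items[self.pos]);
            self.pos += 1;
            Some(&self.current)
        } else {
            None
        }
    }
}
```

This is the same reason `TermWithStateStreamer::next` returns `(&[u8], &TermInfo, A::State)` borrowed from the streamer rather than owned items.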
13 changes: 12 additions & 1 deletion src/termdict/fst_termdict/termdict.rs
@@ -1,5 +1,5 @@
use super::term_info_store::{TermInfoStore, TermInfoStoreWriter};
use super::{TermStreamer, TermStreamerBuilder};
use super::{TermStreamer, TermStreamerBuilder, TermWithStateStreamerBuilder};
use crate::common::{BinarySerializable, CountingWriter};
use crate::directory::{FileSlice, OwnedBytes};
use crate::error::DataCorruption;
@@ -201,4 +201,15 @@ impl TermDictionary {
let stream_builder = self.fst_index.search(automaton);
TermStreamerBuilder::<A>::new(self, stream_builder)
}

/// Returns a search builder, to stream all of the terms
/// within the Automaton
pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A>
where
A: Automaton + 'a,
A::State: Clone,
{
let stream_builder = self.fst_index.search(automaton);
TermWithStateStreamerBuilder::<A>::new(self, stream_builder)
}
}
4 changes: 4 additions & 0 deletions src/termdict/mod.rs
@@ -54,3 +54,7 @@ pub type TermMerger<'a> = self::merger::TermMerger<'a>;
/// `TermStreamer` acts as a cursor over a range of terms of a segment.
/// Terms are guaranteed to be sorted.
pub type TermStreamer<'a, A = AlwaysMatch> = self::termdict::TermStreamer<'a, A>;

/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment.
/// Terms are guaranteed to be sorted.
pub type TermWithStateStreamer<'a, A = AlwaysMatch> = self::termdict::TermWithStateStreamer<'a, A>;