Change sorting and early finishing in the main algorithms.
Previously there was a "bug" where the wrong score was compared and result
scanning finished early. Now the cutoff is configurable, the data is sorted
first, and phrase scanning stops once the limit is reached. For each phrase we
sort the tokens and stop measuring Levenshtein distance at the first valid
match.

Data is sorted by score (decreasing), then by length (increasing), as it's
best to have a high score from a shorter phrase. Maybe trigram scores should
always be divided by the number of letters or trigrams they come from.
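
In outline, the new scan works like the following sketch (a minimal,
standalone approximation, not the index code itself; the candidate tuples,
helper name and example values are illustrative):

    /// Sort candidates by score (decreasing), then phrase length (increasing);
    /// stop scanning once a score drops below `cutoff` times the leader's score
    /// or once `limit` results have been collected.
    fn scan_phrases(mut candidates: Vec<(f32, usize)>, cutoff: f32, limit: usize)
            -> Vec<(f32, usize)> {
        candidates.sort_by(|a, b| {
            b.0.partial_cmp(&a.0).unwrap()   // score: high first
                .then(a.1.cmp(&b.1))         // length: short first
        });
        let leader = candidates.first().map_or(0.0, |c| c.0);
        candidates
            .into_iter()
            .take_while(|&(score, _)| score >= cutoff * leader)  // configurable cutoff
            .take(limit)                                         // early finish at limit
            .collect()
    }

The same sort-then-early-exit pattern applies within each phrase: its tokens
are sorted by trigram score and Levenshtein distance is measured only until
the first token within the distance limit is found.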
Tomasz bla Fortuna committed Jul 5, 2022
1 parent 4068a09 commit 1d17ffb
Showing 5 changed files with 164 additions and 77 deletions.
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -21,6 +21,7 @@ lazy_static = "1"
unicode-segmentation = "1"
unicode-normalization= "0.1.19"
unicode_categories = "0.1"
+itertools = "=0.10"

# Requires AESNI extensions
# As hashmaps/hashsets are used extensively it speeds up some testcases
199 changes: 127 additions & 72 deletions src/fuzzdex.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex;
+use itertools::Itertools;

use lru::LruCache;
use super::utils;
@@ -28,13 +29,40 @@ pub struct Result<'a> {
    pub should_score: f32,
}

+/* TODO Maybe instead of cloning - use Rc<>? */
+/* Trigram heatmap is a partial query result */
+#[derive(Debug, Clone)]
+struct PhraseHeatmap {
+    /* Token trigram score */
+    tokens: HashMap<u16, f32, FastHash>,
+    /* Total phrase score */
+    total_score: f32,
+}
+
+impl PhraseHeatmap {
+    fn new() -> PhraseHeatmap {
+        PhraseHeatmap {
+            tokens: HashMap::with_hasher(FastHash::new()),
+            total_score: 0.0,
+        }
+    }
+}
+
#[derive(Debug, Clone)]
struct Heatmap {
-    /* Trigram score */
-    /* phrase_idx -> token_idx -> score */
-    score: HashMap<usize, HashMap<u16, f32, FastHash>, FastHash>,
-    max: f32,
+    phrases: HashMap<usize, PhraseHeatmap, FastHash>,
+    /* Max phrase score */
+    max_score: f32,
}
+
+impl Heatmap {
+    fn new() -> Heatmap {
+        Heatmap {
+            phrases: HashMap::with_capacity_and_hasher(8, FastHash::new()),
+            max_score: 0.0,
+        }
+    }
+}

#[derive(Debug)]
@@ -75,7 +103,7 @@ pub struct Index {
    phrases: HashMap<usize, PhraseEntry, FastHash>,

    /// LRU cache of must tokens.
-    cache: Mutex<LruCache<String, Arc<Heatmap>>>,
+    cache: Mutex<LruCache<String, Arc<Heatmap>, FastHash>>,
}

/// Produced by Index::finish() and can be queried.
@@ -87,7 +115,7 @@ impl Index {
        Index {
            db: HashMap::with_capacity_and_hasher(32768, FastHash::new()),
            phrases: HashMap::with_hasher(FastHash::new()),
-            cache: Mutex::new(LruCache::new(30000)),
+            cache: Mutex::new(LruCache::with_hasher(30000, FastHash::new())),
        }
    }

@@ -150,49 +178,52 @@ impl IndexReady {

    /// Create trigram heatmap for a given token.
    fn create_heatmap(&self, token: &str) -> Arc<Heatmap> {
-        let db = &self.0.db;
+        let index = &self.0;
+        let db = &index.db;

        /* LRU cache updates position even on get and needs mutable reference */
        {
-            let mut cache = self.0.cache.lock().unwrap();
+            let mut cache = index.cache.lock().unwrap();
            if let Some(heatmap) = cache.get(token) {
                /* We operate on reference-counted heatmaps to eliminate unnecessary copying */
                return heatmap.clone();
            }
        }

-        let mut heatmap = Heatmap {
-            score: HashMap::with_capacity_and_hasher(1024, FastHash::new()),
-            max: 0.0,
-        };
+        let mut heatmap = Heatmap::new();

        for trigram in utils::trigramize(token) {
            if let Some(entry) = db.get(&trigram) {
                for position in entry.positions.iter() {
-                    let by_token = heatmap.score.entry(position.phrase_idx).or_insert_with(
-                        || HashMap::with_capacity_and_hasher(32, FastHash::new()));
-                    let token_score = by_token.entry(position.token_idx).or_insert(0.0);
+                    /* Get or create phrase-level entry */
+                    let phrase_heatmap = heatmap.phrases.entry(position.phrase_idx).or_insert_with(
+                        PhraseHeatmap::new);
+
+                    /* Get or create token-level entry */
+                    let token_score = phrase_heatmap.tokens.entry(position.token_idx).or_insert(0.0);
                    *token_score += entry.score;

-                    if *token_score > heatmap.max {
-                        heatmap.max = *token_score;
+                    phrase_heatmap.total_score += entry.score;
+                    if phrase_heatmap.total_score > heatmap.max_score {
+                        heatmap.max_score = phrase_heatmap.total_score;
                    }
                }
            }
        }

        let heatmap = Arc::new(heatmap);
        {
-            let mut cache = self.0.cache.lock().unwrap();
+            let mut cache = index.cache.lock().unwrap();
            cache.put(token.to_string(), heatmap.clone());
        }
        heatmap
    }

    fn should_scores(&self, heatmap: &Heatmap, should_tokens: &[String])
            -> HashMap<usize, f32, FastHash> {
-        let mut map: HashMap<usize, f32, FastHash> = HashMap::with_capacity_and_hasher(heatmap.score.len(),
-                                                                                       FastHash::new());
+        let mut map: HashMap<usize, f32, FastHash> = HashMap::with_capacity_and_hasher(
+            heatmap.phrases.len(), FastHash::new()
+        );
        let db = &self.0.db;

        for token in should_tokens {
@@ -202,7 +233,7 @@
            for trigram in utils::trigramize(token) {
                if let Some(entry) = db.get(&trigram) {
                    for position in entry.positions.iter() {
-                        if heatmap.score.contains_key(&position.phrase_idx) {
+                        if heatmap.phrases.contains_key(&position.phrase_idx) {
                            /* This phrase is within heatmap, we can calculate should score */
                            let score = map.entry(position.phrase_idx).or_insert(0.0);
                            *score += entry.score;
@@ -216,85 +247,96 @@

    fn filtered_results(&self, query: &Query, heatmap: &Heatmap,
                        should_scores: HashMap<usize, f32, FastHash>) -> Vec<Result> {
-        let mut results: Vec<Result> = Vec::new();
-        if let Some(limit) = query.limit {
-            results.reserve(limit);
-        }
+        let mut results: Vec<Result> = Vec::with_capacity(query.limit.unwrap_or(3));
        let index = &self.0;
-        let max_distance: usize = query.max_distance.unwrap_or(100);
-
-
-        /* TODO: For now, we convert all entries into results, we could stop earlier */
-        for (phrase_idx, tokens) in heatmap.score.iter() {
-            let phrase = &index.phrases[phrase_idx];
-            if let Some(constraint) = query.constraint {
-                if !phrase.constraints.contains(&constraint) {
-                    continue
+        let max_distance: usize = query.max_distance.unwrap_or(usize::MAX);
+        let limit: usize = query.limit.unwrap_or(usize::MAX);
+
+        let phrases_by_score = heatmap.phrases
+            .iter()
+            .map(|(idx, heatmap)| {
+                let phrase = &index.phrases[idx];
+                (idx, heatmap, phrase, *should_scores.get(idx).unwrap_or(&0.0))
+            })
+            .filter(|(_, _, phrase, _)| {
+                if let Some(constraint) = query.constraint {
+                    phrase.constraints.contains(&constraint)
+                } else {
+                    /* No constraint - return all */
+                    true
+                }
+            })
+            .sorted_by(|(_, heat_a, phrase_a, should_a), (_, heat_b, phrase_b, should_b)| {
+                /* Sort by score and then by a should score; for identical - prefer shortest. */
+                (heat_b.total_score, should_b, phrase_a.origin.len()).partial_cmp(
+                    &(heat_a.total_score, should_a, phrase_b.origin.len())).unwrap()
+            });
+
+        for (phrase_idx, phrase_heatmap, phrase, should_score) in phrases_by_score {
+            /* Iterate over potential phrases */
+
+            /* Drop scanning if the total score dropped below the cutoff*leader. */
+            if phrase_heatmap.total_score < query.scan_cutoff * heatmap.max_score {
+                // If the score is too low - it won't grow.
+                break;
            }

-            let mut valid_tokens: Vec<(&String, usize, f32)> = tokens.iter()
-                /* Cut off low scoring tokens. TODO: Use trigram count */
-                .filter(|(_idx, score)| (**score) > 0.4 * heatmap.max)
-                /* Measure levenhstein distance if filter is enabled */
-                .map(|(idx, score)| {
-                    let token = &phrase.tokens[*idx as usize];
-                    if query.max_distance.is_some() {
-                        let distance = utils::distance(&token, &query.must);
-                        (token, distance, *score)
-                    } else {
-                        (token, 0, *score)
-                    }
+            /* Iterate over tokens by decreasing trigram score until first matching is found */
+            let valid_token = phrase_heatmap.tokens
+                .iter()
+                .map(|(&idx, &score)|
+                     (score, &phrase.tokens[idx as usize]))
+                .sorted_by(|(score_a, token_a), (score_b, token_b)| {
+                    /* Prefer shortest for a given score */
+                    /* TODO: Maybe score could be divided by token length */
+                    let side_a = (score_a, token_b.len());
+                    let side_b = (score_b, token_a.len());
+                    side_b.partial_cmp(&side_a).unwrap()
                })
-                /* Drop ones that are too far away */
-                .filter(|(_token, distance, _score)| distance <= &max_distance)
-                .collect();
-
-            valid_tokens.sort_unstable_by_key(
-                /* Solves PartialOrd for floats in a peculiar way. Should be fine though. */
-                |(token, distance, score)| (*distance,
-                                            - ((*score) * 10000.0) as i64,
-                                            (token.len() as i32))
-            );
-
-            if !valid_tokens.is_empty() {
+                .map(|(token_score, token)| {
+                    let distance = utils::distance(token, &query.must);
+                    (token, token_score, distance)
+                }).find(|(_token, _score, distance)| {
+                    *distance <= max_distance
+                });

+            if let Some((token, token_score, distance)) = valid_token {
                /* Add result based on best token matching this phrase (lowest
                 * distance, highest score) */

-                let best = valid_tokens[0];
-                let should_score: f32 = *should_scores.get(phrase_idx).unwrap_or(&0.0);
                results.push(
                    Result {
                        origin: &phrase.origin,
                        index: *phrase_idx,
-                        token: best.0,
-                        distance: best.1,
-                        score: best.2,
+                        score: token_score,
                        should_score,
+                        token,
+                        distance,
                    });

+                /* Early break if we reached limit */
+                if results.len() >= limit {
+                    break;
+                }
            }
        }

-        results.sort_unstable_by_key(|result|
-            (result.distance,
-             (- 1000.0 * result.score) as i64,
-             (- 1000.0 * result.should_score) as i64,
-             result.origin.len())
-
-        );
-
-        if let Some(limit) = query.limit {
-            results.truncate(limit);
-        }
+        results.sort_unstable_by(|a, b| {
+            let side_a = (a.distance, -a.score, -a.should_score, a.origin.len());
+            let side_b = (b.distance, -b.score, -b.should_score, b.origin.len());
+            side_a.partial_cmp(&side_b).unwrap()
+        });

        results
    }

    pub fn search(&self, query: &Query) -> Vec<Result> {
        let heatmap = self.create_heatmap(&query.must);
        let should_scores = self.should_scores(&heatmap, &query.should);
-        let results = self.filtered_results(query, &heatmap, should_scores);
-        results
+        self.filtered_results(query, &heatmap, should_scores)
    }
}

@@ -312,6 +354,7 @@ mod tests {
idx.add_phrase("This is an entry", 1, None);
idx.add_phrase("Another entry entered.", 2, Some(&constraints));
idx.add_phrase("Another about the testing.", 3, None);
idx.add_phrase("Tester tested a test suite.", 4, None);
let idx = idx.finish();

/* First query */
Expand Down Expand Up @@ -349,5 +392,17 @@ mod tests {
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].index, 1);
        assert!(results[0].should_score > 0.0, "First result should have non-zero should-score");
+
+        /* Test multiple tokens matching in single phrase */
+        let query = Query::new("test", &[]).limit(Some(60));
+        println!("Querying {:?}", query);
+        let results = idx.search(&query);
+
+        for result in &results {
+            println!("Got result {:?}", result);
+        }
+
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].index, 4);
    }
}
8 changes: 5 additions & 3 deletions src/lib.rs
@@ -58,7 +58,8 @@ impl FuzzDex {
    fn search<'py>(&self, py: Python<'py>,
                   must: &str, should: Vec<&str>,
                   constraint: Option<usize>, limit: Option<usize>,
-                  max_distance: Option<usize>) -> PyResult<PyObject> {
+                  max_distance: Option<usize>,
+                  scan_cutoff: Option<f32>) -> PyResult<PyObject> {
        match &self.index_ready {
            None => {
                Err(PyErr::new::<exceptions::PyRuntimeError, _>("Index is not yet finished."))
@@ -67,7 +68,8 @@ impl FuzzDex {
                let query = query::Query::new(must, &should)
                    .constraint(constraint)
                    .max_distance(max_distance)
-                    .limit(limit);
+                    .limit(limit)
+                    .scan_cutoff(scan_cutoff.unwrap_or(0.3));

                let search_results = py.allow_threads(
                    move || {
@@ -93,7 +95,7 @@

}

-/// Helper to calculate levenhstein distance from Python without additional libs.
+/// Helper to calculate levenshtein distance from Python without additional libs.
#[pyfunction]
fn distance(side_a: &str, side_b: &str) -> PyResult<usize> {
    Ok(utils::distance(side_a, side_b))
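
On the Rust side a query with the new cutoff is assembled as below (a usage
sketch pieced together from the diffs above; the wrapper function and module
paths are assumptions, while the builder methods and the 0.3 default come
from this commit):

    fn example(idx: &fuzzdex::IndexReady) {
        let query = query::Query::new("test", &["testing"])
            .constraint(None)        // no constraint filtering
            .max_distance(Some(2))   // Levenshtein limit for the must-token
            .limit(Some(10))         // stop phrase scanning after 10 results
            .scan_cutoff(0.3);       // skip phrases under 30% of the leader's score

        for result in idx.search(&query) {
            println!("{} (score {}, distance {})",
                     result.origin, result.score, result.distance);
        }
    }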
