Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/python_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def main():
extractor = CharacterNgrams(n=2, endmarker="$")
print("Using extractor: CharacterNgrams(n=2, endmarker='$')")

sample_embedding = extractor.apply("Some text")
print(f"Sample embedding for 'Some text': {sample_embedding}")

# Choose a similarity measure.
# Options: Cosine(), Dice(), Jaccard(), Overlap(), ExactMatch()
measure = Cosine()
Expand Down
20 changes: 20 additions & 0 deletions src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ impl PyCharacterNgrams {
/// Builds the Python-facing character n-gram extractor.
///
/// Both arguments are forwarded unchanged to `CharacterNgrams::new`; this
/// wrapper performs no validation of its own.
fn new(n: usize, endmarker: &str) -> Self {
    let inner = CharacterNgrams::new(n, endmarker);
    Self(inner)
}

/// Extracts the features of `text` and returns them as owned strings.
///
/// The wrapped extractor hands back interned keys, so a throwaway
/// `lasso::Rodeo` is created for this call and every key is resolved to a
/// `String` before the interner goes out of scope.
fn apply(&self, text: &str) -> Vec<String> {
    let mut rodeo = lasso::Rodeo::default();
    let keys = self.0.features(text, &mut rodeo);

    let mut resolved = Vec::new();
    for key in keys {
        // Copy the text out of the interner; it does not outlive this call.
        resolved.push(rodeo.resolve(&key).to_string());
    }
    resolved
}
}

#[pyclass(name = "WordNgrams")]
Expand All @@ -48,6 +58,16 @@ impl PyWordNgrams {
/// Builds the Python-facing word n-gram extractor.
///
/// All three arguments are forwarded unchanged to `WordNgrams::new`; this
/// wrapper performs no validation of its own.
fn new(n: usize, splitter: &str, padder: &str) -> Self {
    let inner = WordNgrams::new(n, splitter, padder);
    Self(inner)
}

/// Extracts the features of `text` and returns them as owned strings.
///
/// The wrapped extractor hands back interned keys, so a throwaway
/// `lasso::Rodeo` is created for this call and every key is resolved to a
/// `String` before the interner goes out of scope.
fn apply(&self, text: &str) -> Vec<String> {
    let mut rodeo = lasso::Rodeo::default();
    let keys = self.0.features(text, &mut rodeo);

    let mut resolved = Vec::new();
    for key in keys {
        // Copy the text out of the interner; it does not outlive this call.
        resolved.push(rodeo.resolve(&key).to_string());
    }
    resolved
}
}

// Wrapper for Measure trait
Expand Down
18 changes: 17 additions & 1 deletion tests/python/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from collections import Counter

from simstring_rust.database import HashDb
from simstring_rust.errors import SearchError
from simstring_rust.extractors import CharacterNgrams
from simstring_rust.extractors import CharacterNgrams, WordNgrams
from simstring_rust.measures import Cosine
from simstring_rust.searcher import Searcher

Expand Down Expand Up @@ -66,3 +68,17 @@ def test_search_error_on_invalid_threshold(self):

with pytest.raises(SearchError, match=r"Invalid threshold: 0(\.0)?"):
self.searcher.search("test", 0.0)

def test_character_ngram_apply(self):
    """`CharacterNgrams.apply` returns the expected bigram features for "apple"."""
    produced = CharacterNgrams(n=2, endmarker="$").apply("apple")

    wanted = Counter(["$a1", "ap1", "pp1", "pl1", "le1", "e$1"])
    # Compared as multisets, so the order of the returned features is not asserted.
    assert Counter(produced) == wanted

def test_word_ngram_apply(self):
    """`WordNgrams.apply` returns the expected word-bigram features for "foo bar baz"."""
    produced = WordNgrams(n=2, splitter=" ", padder="#").apply("foo bar baz")

    wanted = Counter(["# foo1", "foo bar1", "bar baz1", "baz #1"])
    # Compared as multisets, so the order of the returned features is not asserted.
    assert Counter(produced) == wanted