Skip to content

Commit 8a6b8ae

Browse files
authored
feat(bindings): expose extractor for direct usage in Python (#47)
1 parent 9a1999f commit 8a6b8ae

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

examples/python_basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ def main():
1515
extractor = CharacterNgrams(n=2, endmarker="$")
1616
print("Using extractor: CharacterNgrams(n=2, endmarker='$')")
1717

18+
sample_embedding = extractor.apply("Some text")
19+
print(f"Sample embedding for 'Some text': {sample_embedding}")
20+
1821
# Choose a similarity measure.
1922
# Options: Cosine(), Dice(), Jaccard(), Overlap(), ExactMatch()
2023
measure = Cosine()

src/python/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ impl PyCharacterNgrams {
3636
fn new(n: usize, endmarker: &str) -> Self {
3737
Self(CharacterNgrams::new(n, endmarker))
3838
}
39+
40+
fn apply(&self, text: &str) -> Vec<String> {
41+
let mut interner = lasso::Rodeo::default();
42+
let features = self.0.features(text, &mut interner);
43+
44+
features
45+
.into_iter()
46+
.map(|spur| interner.resolve(&spur).to_string())
47+
.collect()
48+
}
3949
}
4050

4151
#[pyclass(name = "WordNgrams")]
@@ -48,6 +58,16 @@ impl PyWordNgrams {
4858
fn new(n: usize, splitter: &str, padder: &str) -> Self {
4959
Self(WordNgrams::new(n, splitter, padder))
5060
}
61+
62+
fn apply(&self, text: &str) -> Vec<String> {
63+
let mut interner = lasso::Rodeo::default();
64+
let features = self.0.features(text, &mut interner);
65+
66+
features
67+
.into_iter()
68+
.map(|spur| interner.resolve(&spur).to_string())
69+
.collect()
70+
}
5171
}
5272

5373
// Wrapper for Measure trait

tests/python/test_bindings.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import pytest
2+
from collections import Counter
3+
24
from simstring_rust.database import HashDb
35
from simstring_rust.errors import SearchError
4-
from simstring_rust.extractors import CharacterNgrams
6+
from simstring_rust.extractors import CharacterNgrams, WordNgrams
57
from simstring_rust.measures import Cosine
68
from simstring_rust.searcher import Searcher
79

@@ -66,3 +68,17 @@ def test_search_error_on_invalid_threshold(self):
6668

6769
with pytest.raises(SearchError, match=r"Invalid threshold: 0(\.0)?"):
6870
self.searcher.search("test", 0.0)
71+
72+
def test_character_ngram_apply(self):
73+
extractor = CharacterNgrams(n=2, endmarker="$")
74+
features = extractor.apply("apple")
75+
76+
expected = ["$a1", "ap1", "pp1", "pl1", "le1", "e$1"]
77+
assert Counter(features) == Counter(expected)
78+
79+
def test_word_ngram_apply(self):
80+
extractor = WordNgrams(n=2, splitter=" ", padder="#")
81+
features = extractor.apply("foo bar baz")
82+
83+
expected = ["# foo1", "foo bar1", "bar baz1", "baz #1"]
84+
assert Counter(features) == Counter(expected)

0 commit comments

Comments
 (0)