Skip to content

Commit 7bc6d37

Browse files
authored
Added device parameter to allow usage with MPS (Apple Silicon) and Vulkan (#32)
* Bumps dependencies for compatibility with libtorch 2.2
* Adds device parameter to enable usage with MPS (Apple Silicon) and Vulkan
* Disables default features for rust-bert
1 parent 89d505c commit 7bc6d37

File tree

10 files changed

+25
-24
lines changed

10 files changed

+25
-24
lines changed

Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "sbert"
3-
version = "0.4.1"
3+
version = "0.5.0"
44
authors = ["Chady Dimachkie <cpcdoy@gmail.com>"]
55
edition = "2018"
66
description = "Rust implementation of Sentence Bert (SBert)"
@@ -16,15 +16,15 @@ log = "0.4"
1616
num_cpus = "1.13"
1717
prost = "0.9"
1818
rayon = "1.5"
19-
rust-bert = "0.21.0"
19+
rust-bert = { git = "https://github.com/guillaume-be/rust-bert", rev = "29f9a7a", default-features = false }
2020
rust_tokenizers = "7.0"
2121
serde = "1.0"
2222
strum = "0.23"
2323
strum_macros = "0.23"
24-
tch = "0.13.0"
24+
tch = "0.15.0"
2525
thiserror = "1.0"
26-
tokenizers = "0.11"
27-
torch-sys = "0.13.0"
26+
tokenizers = "0.15"
27+
torch-sys = "0.15.0"
2828

2929
[dev-dependencies]
3030
criterion = "0.3"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ You can use different versions of the models that use different tokenizers:
3333

3434
```Rust
3535
// To use Hugging Face tokenizer
36-
let sbert_model = SBertHF::new(home.to_str().unwrap());
36+
let sbert_model = SBertHF::new(home.to_str().unwrap(), None);
3737

3838
// To use Rust-tokenizers
39-
let sbert_model = SBertRT::new(home.to_str().unwrap());
39+
let sbert_model = SBertRT::new(home.to_str().unwrap(), None);
4040
```
4141

4242
Now, you can encode your sentences:

benches/bench_distilroberta.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ fn bench_distilroberta_rust_tokenizers_sentencepiece(c: &mut Criterion) {
1919
home.push("distilroberta_toxicity");
2020

2121
println!("Loading distilroberta ...");
22-
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home).unwrap();
22+
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home, None).unwrap();
2323

2424
let text = "TTThis player needs tp be reported lolz.";
2525
c.bench_function(

benches/bench_sbert.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ fn bench_sbert_rust_tokenizers(c: &mut Criterion) {
2626
home.push("distiluse-base-multilingual-cased");
2727

2828
println!("Loading sbert ...");
29-
let sbert_model = SBertRT::new(home).unwrap();
29+
let sbert_model = SBertRT::new(home, None).unwrap();
3030

3131
let text = "TTThis player needs tp be reported lolz.";
3232
c.bench_function("Encode batch, safe sbert rust tokenizer, total 1", |b| {
@@ -53,7 +53,7 @@ fn bench_sbert_hugging_face_tokenizers(c: &mut Criterion) {
5353
home.push("distiluse-base-multilingual-cased");
5454

5555
println!("Loading sbert ...");
56-
let sbert_model = SBertHF::new(home).unwrap();
56+
let sbert_model = SBertHF::new(home, None).unwrap();
5757

5858
let text = "TTThis player needs tp be reported lolz.";
5959
c.bench_function(

src/layers/dense.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@ impl Config for DenseConfig {}
2626
pub struct Dense {
2727
linear: nn::Linear,
2828
_conf: DenseConfig,
29+
2930
}
3031

3132
impl Dense {
32-
pub fn new<P: Into<PathBuf>>(root: P) -> Result<Dense, Error> {
33+
pub fn new<P: Into<PathBuf>>(root: P, device: Device) -> Result<Dense, Error> {
3334
let dense_dir = root.into().join("2_Dense");
3435
log::info!("Loading conf {:?}", dense_dir);
3536

36-
let device = Device::cuda_if_available();
37-
//let device = Device::Cpu;
3837
let mut vs_dense = nn::VarStore::new(device);
3938

4039
let init_conf = nn::LinearConfig {

src/models/distilroberta.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl<T> DistilRobertaForSequenceClassification<T>
2323
where
2424
T: Tokenizer + Send + Sync,
2525
{
26-
pub fn new<P>(root: P) -> Result<Self, Error>
26+
pub fn new<P>(root: P, device: Option<Device>) -> Result<Self, Error>
2727
where
2828
P: Into<PathBuf>,
2929
{
@@ -36,7 +36,7 @@ where
3636

3737
let config = BertConfig::from_file(&config_file);
3838

39-
let device = Device::cuda_if_available();
39+
let device = device.unwrap_or(Device::cuda_if_available());
4040
log::info!("Using device {:?}", device);
4141

4242
let mut vs = nn::VarStore::new(device);

src/models/sbert.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ impl<T> SBert<T>
2727
where
2828
T: Tokenizer + Send + Sync,
2929
{
30-
pub fn new<P>(root: P) -> Result<Self, Error>
30+
pub fn new<P>(root: P, device: Option<Device>) -> Result<Self, Error>
3131
where
3232
P: Into<PathBuf>,
3333
{
@@ -44,11 +44,13 @@ where
4444
let nb_layers = config.n_layers as usize;
4545
let nb_heads = config.n_heads as usize;
4646

47+
let device = device.unwrap_or(Device::cuda_if_available());
48+
log::info!("Using device {:?}", device);
49+
4750
let pooling = Pooling::new(root.clone());
48-
let dense = Dense::new(root)?;
51+
let dense = Dense::new(root, device)?;
52+
4953

50-
let device = Device::cuda_if_available();
51-
log::info!("Using device {:?}", device);
5254

5355
let mut vs = nn::VarStore::new(device);
5456

src/tokenizers/hf_tokenizers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl Tokenizer for HFTokenizer {
6161
let stride = 0;
6262
let strategy = TruncationStrategy::LongestFirst;
6363
let direction = TruncationDirection::Right;
64-
tokenizer.with_truncation(Some(TruncationParams {
64+
let _ = tokenizer.with_truncation(Some(TruncationParams {
6565
max_length,
6666
stride,
6767
strategy,

tests/test_distilroberta.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ mod tests {
4949

5050
println!("Loading distilroberta ...");
5151
let before = Instant::now();
52-
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home).unwrap();
52+
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home, None).unwrap();
5353
println!("Elapsed time: {:.2?}", before.elapsed());
5454

5555
let mut texts = Vec::new();

tests/test_sbert.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ mod tests {
6666

6767
println!("Loading sbert ...");
6868
let before = Instant::now();
69-
let sbert_model = SBertRT::new(home).unwrap();
69+
let sbert_model = SBertRT::new(home, None).unwrap();
7070
println!("Elapsed time: {:.2?}", before.elapsed());
7171

7272
let mut texts = Vec::new();
@@ -104,7 +104,7 @@ mod tests {
104104

105105
println!("Loading sbert ...");
106106
let before = Instant::now();
107-
let sbert_model = SBertHF::new(home).unwrap();
107+
let sbert_model = SBertHF::new(home, None).unwrap();
108108
println!("Elapsed time: {:.2?}", before.elapsed());
109109

110110
let mut texts = Vec::new();
@@ -137,7 +137,7 @@ mod tests {
137137

138138
println!("Loading sbert ...");
139139
let before = Instant::now();
140-
let sbert_model = SBertHF::new(home).unwrap();
140+
let sbert_model = SBertHF::new(home, None).unwrap();
141141
println!("Elapsed time: {:.2?}", before.elapsed());
142142

143143
let mut texts = Vec::new();

0 commit comments

Comments
 (0)