Skip to content

Commit

Permalink
[rust] Fix camembert and distilbert model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Aug 15, 2024
1 parent e0519ac commit 5e96385
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
8 changes: 4 additions & 4 deletions extensions/tokenizers/rust/src/models/camembert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,14 +480,14 @@ pub struct CamembertModel {
impl CamembertModel {
pub fn load(vb: VarBuilder, config: &CamembertConfig) -> Result<Self> {
let (embeddings, encoder) = match (
BertEmbeddings::load(vb.pp("embeddings"), config),
BertEncoder::load(vb.pp("encoder"), config),
BertEmbeddings::load(vb.pp("roberta.embeddings"), config),
BertEncoder::load(vb.pp("roberta.encoder"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(vb.pp("camembert.embeddings".to_string()), config),
BertEncoder::load(vb.pp("camembert.encoder".to_string()), config),
BertEmbeddings::load(vb.pp("deberta.embeddings".to_string()), config),
BertEncoder::load(vb.pp("deberta.encoder".to_string()), config),
) {
(embeddings, encoder)
} else {
Expand Down
16 changes: 6 additions & 10 deletions extensions/tokenizers/rust/src/models/distilbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,13 @@ impl DistilBertModel {
Embeddings::load(vb.pp("embeddings"), config),
Transformer::load(vb.pp("transformer"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Ok(embeddings), Ok(transformer)) => (embeddings, transformer),
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
) {
(embeddings, encoder)
} else {
return Err(err);
}
if let (Ok(embeddings), Ok(transformer)) = (
Embeddings::load(vb.pp("distilbert.embeddings".to_string()), config),
Transformer::load(vb.pp("distilbert.transformer".to_string()), config),
) {
(embeddings, transformer)
} else {
return Err(err);
}
Expand Down

0 comments on commit 5e96385

Please sign in to comment.