
Add Dense layer in 2_Dense/ modules #660

Open · wants to merge 15 commits into main
Changes from all commits
10 changes: 7 additions & 3 deletions backends/candle/src/layers/linear.rs
@@ -68,15 +68,19 @@ impl Linear {
),
}
} else {
- let w = match x.dims() {
- &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
- _ => self.weight.t()?,
+ let (x, w) = match x.dims() {
+ &[bsize, _, _] => (x.clone(), self.weight.broadcast_left(bsize)?.t()?),
+ // Metal devices require contiguous tensors for 2D matrix multiplication apparently
+ _ if matches!(x.device(), Device::Metal(_)) => (x.contiguous()?, self.weight.t()?),
+ _ => (x.clone(), self.weight.t()?),
};
let x = x.matmul(&w)?;

let x = match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}?;

if let Some(act) = &self.act {
match act {
HiddenAct::Gelu => x.gelu(),
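To see the guard in isolation: a minimal sketch against the `candle` Tensor API (`Device`, `Tensor::contiguous`, `Tensor::t`, `Tensor::matmul`); the `project` helper and its shapes are illustrative, not part of this PR.

use candle::{Device, Result, Tensor};

// Compute `x @ weight^T` for `x: [batch, in]` and `weight: [out, in]`.
fn project(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
    // Inputs that arrive as strided views are not contiguous, and Metal's 2D
    // matmul kernel expects contiguous buffers, so copy the activations first.
    let x = if matches!(x.device(), Device::Metal(_)) {
        x.contiguous()?
    } else {
        x.clone()
    };
    x.matmul(&weight.t()?)
}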
64 changes: 61 additions & 3 deletions backends/candle/src/lib.rs
@@ -11,9 +11,10 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
- BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
- JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
- ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
+ BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
+ GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig,
+ Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
+ Qwen3Config, Qwen3Model,
};
#[cfg(feature = "cuda")]
use crate::models::{
@@ -114,13 +115,15 @@ enum Config {
pub struct CandleBackend {
device: Device,
model: Box<dyn Model + Send>,
dense: Option<Box<dyn DenseLayer + Send>>,
}

impl CandleBackend {
pub fn new(
model_path: &Path,
dtype: String,
model_type: ModelType,
dense_path: Option<&Path>,
) -> Result<Self, BackendError> {
// Default files
let default_safetensors = model_path.join("model.safetensors");
@@ -468,9 +471,50 @@ impl CandleBackend {
}
};

// If `2_Dense/model.safetensors` or `2_Dense/pytorch_model.bin` is among the downloaded artifacts,
// create a Dense block and provide it to the `CandleBackend`; otherwise, pass None
let dense = if let Some(dense_path) = dense_path {
let dense_safetensors = dense_path.join("model.safetensors");
let dense_pytorch = dense_path.join("pytorch_model.bin");

if dense_safetensors.exists() || dense_pytorch.exists() {
let dense_config_path = dense_path.join("config.json");

let dense_config_str =
std::fs::read_to_string(&dense_config_path).map_err(|err| {
BackendError::Start(format!(
"Unable to read `{dense_path:?}/config.json` file: {err:?}",
))
})?;
let dense_config: DenseConfig =
serde_json::from_str(&dense_config_str).map_err(|err| {
BackendError::Start(format!(
"Unable to parse `{dense_path:?}/config.json`: {err:?}",
))
})?;

let dense_vb = if dense_safetensors.exists() {
unsafe {
VarBuilder::from_mmaped_safetensors(&[dense_safetensors], dtype, &device)
}
.s()?
} else {
VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
};

Some(Box::new(Dense::load(dense_vb, &dense_config).s()?)
as Box<dyn DenseLayer + Send>)
} else {
None
}
} else {
None
};

Ok(Self {
device,
model: model?,
dense,
})
}
}
@@ -507,6 +551,19 @@ impl Backend for CandleBackend {
// Run forward
let (pooled_embeddings, raw_embeddings) = self.model.embed(batch).e()?;

// Apply dense layer if available
let pooled_embeddings = match pooled_embeddings {
None => None,
Some(pooled_embeddings) => {
let pooled_embeddings = if let Some(ref dense) = self.dense {
dense.forward(&pooled_embeddings).e()?
} else {
pooled_embeddings
};
Some(pooled_embeddings)
}
};

// Device => Host data transfer
let pooled_embeddings = match pooled_embeddings {
None => vec![],
@@ -540,6 +597,7 @@ impl Backend for CandleBackend {
let batch_size = batch.len();

let results = self.model.predict(batch).e()?;

let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;

let mut predictions =
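Downstream, the new `dense_path` argument threads through the constructor like this; a hypothetical call site assuming `ModelType` and `Pool` from `text-embeddings-backend-core`, with illustrative paths, dtype, and pooling.

use std::path::Path;

let model_root = Path::new("/data/sentence-transformers/LaBSE");
// `2_Dense/` sits next to the model weights in the sentence-transformers layout.
let dense_root = model_root.join("2_Dense");
let backend = CandleBackend::new(
    model_root,
    "float32".to_string(),
    ModelType::Embedding(Pool::Cls),
    Some(dense_root.as_path()),
)?;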
75 changes: 75 additions & 0 deletions backends/candle/src/models/dense.rs
@@ -0,0 +1,75 @@
use crate::layers::Linear;
use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize, PartialEq)]
/// The activation functions in `2_Dense/config.json` are defined as PyTorch imports
pub enum DenseActivation {
#[serde(rename = "torch.nn.modules.activation.Tanh")]
/// e.g. https://huggingface.co/sentence-transformers/LaBSE/blob/main/2_Dense/config.json
Tanh,
#[serde(rename = "torch.nn.modules.linear.Identity")]
/// e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5/blob/main/2_Dense/config.json
Identity,
}

impl DenseActivation {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
match self {
Self::Tanh => x.tanh(),
Self::Identity => Ok(x.clone()),
}
}
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct DenseConfig {
in_features: usize,
out_features: usize,
bias: bool,
activation_function: Option<DenseActivation>,
}

pub trait DenseLayer {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
}

#[derive(Debug)]
pub struct Dense {
linear: Linear,
activation: DenseActivation,
span: tracing::Span,
}

impl Dense {
pub fn load(vb: VarBuilder, config: &DenseConfig) -> Result<Self> {
let weight = vb.get((config.out_features, config.in_features), "linear.weight")?;
let bias = if config.bias {
Some(vb.get(config.out_features, "linear.bias")?)
} else {
None
};
let linear = Linear::new(weight, bias, None);

let activation = config
.activation_function
.clone()
.unwrap_or(DenseActivation::Identity);

Ok(Self {
linear,
activation,
span: tracing::span!(tracing::Level::TRACE, "dense"),
})
}
}

impl DenseLayer for Dense {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let hidden_states = self.linear.forward(hidden_states)?;
self.activation.forward(&hidden_states)
}
}
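Tying the pieces together, a hedged sketch of the load-and-apply path: the JSON mirrors a typical `2_Dense/config.json` (the values match LaBSE's), while `dense_path`, `device`, and the `pooled` tensor are assumed to already be in scope.

use candle::DType;
use candle_nn::VarBuilder;

// A typical `2_Dense/config.json` payload: a 768 -> 768 projection with bias and Tanh.
let raw = r#"{
    "in_features": 768,
    "out_features": 768,
    "bias": true,
    "activation_function": "torch.nn.modules.activation.Tanh"
}"#;
let config: DenseConfig = serde_json::from_str(raw)?;

// `dense_path` points at the downloaded `2_Dense/` directory.
let vb = unsafe {
    VarBuilder::from_mmaped_safetensors(&[dense_path.join("model.safetensors")], DType::F32, &device)?
};
let dense = Dense::load(vb, &config)?;

// `pooled` is `[batch, 768]`; the output keeps that shape, squashed through tanh.
let projected = dense.forward(&pooled)?;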
2 changes: 2 additions & 0 deletions backends/candle/src/models/mod.rs
@@ -5,6 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

mod bert;
mod dense;
mod distilbert;
mod jina;
mod jina_code;
@@ -49,6 +50,7 @@ mod qwen3;

pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
pub use dense::{Dense, DenseConfig, DenseLayer};
pub use distilbert::{DistilBertConfig, DistilBertModel};
#[allow(unused_imports)]
pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};
35 changes: 35 additions & 0 deletions backends/candle/tests/common.rs
@@ -106,6 +106,7 @@ pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>)
pub fn download_artifacts(
model_id: &'static str,
revision: Option<&'static str>,
dense_path: Option<&'static str>,
) -> Result<PathBuf> {
let mut builder = ApiBuilder::from_env().with_progress(false);

@@ -140,6 +141,40 @@
vec![p]
}
};

// Download dense path files if specified
if let Some(dense_path) = dense_path {
let dense_config_path = format!("{}/config.json", dense_path);
match api_repo.get(&dense_config_path) {
Ok(_) => tracing::info!("Downloaded dense config: {}", dense_config_path),
Err(err) => tracing::warn!(
"Could not download dense config {}: {}",
dense_config_path,
err
),
}

// Try to download dense model files (safetensors first, then pytorch)
let dense_safetensors_path = format!("{}/model.safetensors", dense_path);
match api_repo.get(&dense_safetensors_path) {
Ok(_) => tracing::info!("Downloaded dense safetensors: {}", dense_safetensors_path),
Err(_) => {
tracing::warn!("Dense safetensors not found. Trying pytorch_model.bin");
let dense_pytorch_path = format!("{}/pytorch_model.bin", dense_path);
match api_repo.get(&dense_pytorch_path) {
Ok(_) => {
tracing::info!("Downloaded dense pytorch model: {}", dense_pytorch_path)
}
Err(err) => tracing::warn!(
"Could not download dense pytorch model {}: {}",
dense_pytorch_path,
err
),
}
}
}
}

let model_root = model_files[0].parent().unwrap().to_path_buf();
Ok(model_root)
}
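And a hypothetical test invocation of the extended helper; the model id is illustrative, and `2_Dense` follows the sentence-transformers layout.

let model_root = download_artifacts("sentence-transformers/LaBSE", None, Some("2_Dense"))?;
// hf-hub mirrors the repository layout, so the dense artifacts land under the model root.
let dense_root = model_root.join("2_Dense");
assert!(dense_root.join("config.json").exists());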