146 changes: 57 additions & 89 deletions python/zvec/extension/sentence_transformer_embedding_function.py
@@ -18,6 +18,7 @@
import numpy as np

from ..common.constants import TEXT, DenseVectorType, SparseVectorType
from ..tool import require_module
from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction
from .sentence_transformer_function import SentenceTransformerFunctionBase

@@ -39,6 +40,9 @@ class DefaultLocalDenseEmbedding(
similarity tasks. It runs locally without requiring API keys.

Args:
model_name (Optional[str]): Model identifier or local path. Defaults to:
- ``"all-MiniLM-L6-v2"`` for Hugging Face
- ``"iic/nlp_gte_sentence-embedding_chinese-small"`` for ModelScope
model_source (Literal["huggingface", "modelscope"], optional): Model source.
- ``"huggingface"``: Use Hugging Face Hub (default, for international users)
- ``"modelscope"``: Use ModelScope (recommended for users in China)
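For orientation, here is a minimal usage sketch of the new `model_name` parameter. The import path is an assumption based on the file location (`python/zvec/extension/...`) and is not confirmed by this diff:

```python
# Illustrative sketch only: the import path is assumed from the repo layout,
# and embed() is assumed to follow the DenseEmbeddingFunction interface.
from zvec.extension.sentence_transformer_embedding_function import (
    DefaultLocalDenseEmbedding,
)

# With no model_name, the default depends on model_source:
# all-MiniLM-L6-v2 on Hugging Face.
dense = DefaultLocalDenseEmbedding()

# Or pin an explicit model and source, e.g. for users in China.
dense_cn = DefaultLocalDenseEmbedding(
    model_name="iic/nlp_gte_sentence-embedding_chinese-small",
    model_source="modelscope",
)

vector = dense.embed("hello world")  # dense vector, length == dense.dimension
```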
@@ -153,6 +157,7 @@ class DefaultLocalDenseEmbedding(

def __init__(
self,
model_name: Optional[str] = None,
model_source: Literal["huggingface", "modelscope"] = "huggingface",
device: Optional[str] = None,
normalize_embeddings: bool = True,
@@ -162,6 +167,9 @@ def __init__(
"""Initialize with all-MiniLM-L6-v2 model.

Args:
model_name (Optional[str]): Model identifier or local path. Defaults to:
- ``"all-MiniLM-L6-v2"`` for Hugging Face
- ``"iic/nlp_gte_sentence-embedding_chinese-small"`` for ModelScope
model_source (Literal["huggingface", "modelscope"]): Model source.
Defaults to "huggingface".
device (Optional[str]): Target device ("cpu", "cuda", "mps", or None).
@@ -176,11 +184,12 @@ def __init__(
ValueError: If model cannot be loaded.
"""
# Use different models based on source
if model_source == "modelscope":
# Use Chinese-optimized model for ModelScope (better for Chinese text)
model_name = "iic/nlp_gte_sentence-embedding_chinese-small"
else:
model_name = "all-MiniLM-L6-v2"
if model_name is None:
if model_source == "modelscope":
# Use Chinese-optimized model for ModelScope (better for Chinese text)
model_name = "iic/nlp_gte_sentence-embedding_chinese-small"
else:
model_name = "all-MiniLM-L6-v2"

# Initialize base class for model loading
SentenceTransformerFunctionBase.__init__(
@@ -197,6 +206,20 @@ def __init__(
# Store extra parameters
self._extra_params = kwargs

@property
def _get_model_class(self):
"""Get the Sentence Transformer class.

Returns:
class: SentenceTransformer, the class used for dense embeddings.

Raises:
ImportError: If required packages are not installed.
"""
sentence_transformers = require_module("sentence_transformers")

return sentence_transformers.SentenceTransformer
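`require_module` comes from `..tool` and is not shown in this diff; a plausible minimal implementation (an assumption, not the actual zvec code) would be:

```python
# Hypothetical sketch of require_module; the real helper lives in zvec.tool
# and may differ. It lazily imports an optional dependency with a clear error.
import importlib

def require_module(name: str):
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"Optional dependency '{name}' is required for this feature; "
            f"install it, e.g. `pip install {name}`."
        ) from exc
```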

@property
def dimension(self) -> int:
"""int: The expected dimensionality of the embedding vector."""
@@ -368,6 +391,8 @@ class DefaultLocalSparseEmbedding(
``SparseEmbeddingFunction``.

Args:
model_name (Optional[str]): Model identifier or local path. Defaults to
``"naver/splade-cocondenser-ensembledistil"`` if None.
model_source (Literal["huggingface", "modelscope"], optional): Model source.
Defaults to ``"huggingface"``. ModelScope support may vary for SPLADE models.
device (Optional[str], optional): Device to run the model on.
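A matching usage sketch for the sparse class, with the same caveat that the import path is assumed from the file layout:

```python
# Illustrative sketch only; import path assumed from the repo layout.
from zvec.extension.sentence_transformer_embedding_function import (
    DefaultLocalSparseEmbedding,
)

# SPLADE encodes queries and documents differently, so pick per role.
query_fn = DefaultLocalSparseEmbedding(encoding_type="query")
doc_fn = DefaultLocalSparseEmbedding(encoding_type="document")

sparse = query_fn.embed("what is a sparse embedding?")
# -> {token_id: weight, ...}, sorted by token id (see embed() below)
```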
@@ -589,6 +614,7 @@ def remove_from_cache(

def __init__(
self,
model_name: Optional[str] = None,
model_source: Literal["huggingface", "modelscope"] = "huggingface",
device: Optional[str] = None,
encoding_type: Literal["query", "document"] = "query",
@@ -597,6 +623,8 @@ def __init__(
"""Initialize with SPLADE model.

Args:
model_name (Optional[str]): Model identifier or local path. Defaults to
``"naver/splade-cocondenser-ensembledistil"`` if None.
model_source (Literal["huggingface", "modelscope"]): Model source.
Defaults to "huggingface".
device (Optional[str]): Target device ("cpu", "cuda", "mps", or None).
@@ -640,7 +668,8 @@ def __init__(
# Use publicly available SPLADE model (no gated access required)
# Note: naver/splade-v3 requires authentication, so we use the
# cocondenser-ensembledistil variant which is publicly accessible
model_name = "naver/splade-cocondenser-ensembledistil"
if model_name is None:
model_name = "naver/splade-cocondenser-ensembledistil"

# Initialize base class for model loading
SentenceTransformerFunctionBase.__init__(
@@ -656,6 +685,20 @@ def __init__(
# Load model to ensure it's available (will use cache if exists)
self._get_model()

@property
def _get_model_class(self):
"""Get the Sentence Transformer class based on the model source.

Returns:
class: SparseEncoder, the class used for SPLADE sparse embeddings.

Raises:
ImportError: If required packages are not installed.
"""
sentence_transformers = require_module("sentence_transformers")

return sentence_transformers.SparseEncoder

@property
def extra_params(self) -> dict:
"""dict: Extra parameters for model-specific customization."""
@@ -714,41 +757,19 @@ def embed(self, input: str) -> SparseVectorType:
model = self._get_model()

# Use appropriate encoding method based on type
if self._encoding_type == "document" and hasattr(model, "encode_document"):
if self._encoding_type == "document":
# Use document encoding
sparse_matrix = model.encode_document([input])
elif hasattr(model, "encode_query"):
else:
# Use query encoding (default)
sparse_matrix = model.encode_query([input])
else:
# Fallback: manual implementation for older sentence-transformers
return self._manual_sparse_encode(input)

# Convert sparse matrix to dictionary
# SPLADE returns shape [1, vocab_size] for single input

# Check if it's a sparse matrix (duck typing - has toarray method)
if hasattr(sparse_matrix, "toarray"):
# Sparse matrix (CSR/CSC/etc.) - convert to dense array
sparse_array = sparse_matrix[0].toarray().flatten()
sparse_dict = {
int(idx): float(val)
for idx, val in enumerate(sparse_array)
if val > 0
}
else:
# Dense array format (numpy array or similar)
if isinstance(sparse_matrix, np.ndarray):
sparse_array = sparse_matrix[0]
else:
sparse_array = sparse_matrix

sparse_dict = {
int(idx): float(val)
for idx, val in enumerate(sparse_array)
if val > 0
}

# decode() returns a list of (token_string, score) pairs for the non-zero
# dimensions; the token strings are then mapped back to vocabulary IDs
decoded = model.decode(sparse_matrix)[0]
token_strings, scores = zip(*decoded, strict=True)
token_ids = model.tokenizer.convert_tokens_to_ids(token_strings)
sparse_dict = dict(zip(token_ids, scores, strict=True))
# Sort by indices (keys) to ensure consistent ordering
return dict(sorted(sparse_dict.items()))
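The decode-based conversion above replaces roughly forty lines of manual sparse-matrix handling. A standalone sketch of the same idea, assuming the `SparseEncoder` API of sentence-transformers 5.x:

```python
# Standalone sketch of the decode path, assuming sentence-transformers >= 5.0
# (SparseEncoder with encode_query/encode_document and decode()).
from sentence_transformers import SparseEncoder

model = SparseEncoder("naver/splade-cocondenser-ensembledistil")
emb = model.encode_query(["what is a sparse embedding?"])

# decode() yields (token_string, score) pairs for the non-zero dimensions
decoded = model.decode(emb)[0]
tokens, scores = zip(*decoded, strict=True)

# Map token strings back to vocabulary ids to build the {id: weight} dict
ids = model.tokenizer.convert_tokens_to_ids(list(tokens))
sparse_vec = dict(sorted(zip(ids, map(float, scores), strict=True)))
```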

@@ -757,59 +778,6 @@ def embed(self, input: str) -> SparseVectorType:
raise
raise RuntimeError(f"Failed to generate sparse embedding: {e!s}") from e

def _manual_sparse_encode(self, input: str) -> SparseVectorType:
"""Fallback manual SPLADE encoding for older sentence-transformers.

Args:
input (str): Input text to encode.

Returns:
SparseVectorType: Sparse vector as dictionary.
"""
import torch

model = self._get_model()

# Tokenize input
features = model.tokenize([input])

# Move to correct device
features = {k: v.to(model.device) for k, v in features.items()}

# Forward pass with no gradient
with torch.no_grad():
embeddings = model.forward(features)

# Get logits from model output
# SPLADE models typically output 'token_embeddings'
if isinstance(embeddings, dict) and "token_embeddings" in embeddings:
logits = embeddings["token_embeddings"][0] # First batch item
elif hasattr(embeddings, "token_embeddings"):
logits = embeddings.token_embeddings[0]
# Fallback: try to get first value
elif isinstance(embeddings, dict):
logits = next(iter(embeddings.values()))[0]
else:
logits = embeddings[0]

# Apply SPLADE activation: log(1 + relu(x))
relu_log = torch.log(1 + torch.relu(logits))

# Max pooling over token dimension (reduce to vocab size)
if relu_log.dim() > 1:
sparse_vec, _ = torch.max(relu_log, dim=0)
else:
sparse_vec = relu_log

# Convert to sparse dictionary (only non-zero values)
sparse_vec_np = sparse_vec.cpu().numpy()
sparse_dict = {
int(idx): float(val) for idx, val in enumerate(sparse_vec_np) if val > 0
}

# Sort by indices (keys) to ensure consistent ordering
return dict(sorted(sparse_dict.items()))

def _get_model(self):
"""Load or retrieve the SPLADE model from class-level cache.

29 changes: 15 additions & 14 deletions python/zvec/extension/sentence_transformer_function.py
@@ -88,6 +88,18 @@ def device(self) -> str:
return str(model.device)
return self._device or "cpu"

@property
def _get_model_class(self):
"""Get the Sentence Transformer class.

Returns:
class: The Sentence Transformer class to use for loading models.

Raises:
ImportError: If required packages are not installed.
"""
raise NotImplementedError()

def _get_model(self):
"""Load or retrieve the Sentence Transformer model.

@@ -104,8 +116,6 @@ def _get_model(self):

# Load model
try:
sentence_transformers = require_module("sentence_transformers")

if self._model_source == "modelscope":
# Load from ModelScope
require_module("modelscope")
@@ -115,12 +125,13 @@ def _get_model(self):
model_dir = snapshot_download(self._model_name)

# Load from local path
self._model = sentence_transformers.SentenceTransformer(
self._model = self._get_model_class(
model_dir, device=self._device, trust_remote_code=True
)
else:
# Load from Hugging Face (default)
self._model = sentence_transformers.SentenceTransformer(
self._model = self._get_model_class(
self._model_name, device=self._device, trust_remote_code=True
)

@@ -138,13 +149,3 @@ def _get_model(self):
f"Failed to load Sentence Transformer model '{self._model_name}' "
f"from {self._model_source}: {e!s}"
) from e

def _is_sparse_model(self) -> bool:
"""Check if the loaded model is a sparse encoder (e.g., SPLADE).

Returns:
bool: True if model supports sparse encoding.
"""
model = self._get_model()
# Check if model has sparse encoding methods
return hasattr(model, "encode_query") or hasattr(model, "encode_document")
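Design note: the new `_get_model_class` property turns `_get_model` into a template method — download and caching stay in the base class while each subclass only picks the encoder class. A condensed sketch of the pattern (names mirror the PR; bodies are trimmed, and `_load` is a hypothetical stand-in for the real loader):

```python
# Condensed, hypothetical sketch of the template-method pattern in this PR;
# _load stands in for the real _get_model and omits caching and ModelScope.
class SentenceTransformerFunctionBase:
    def __init__(self, model_name: str, device: str | None = None):
        self._model_name = model_name
        self._device = device

    @property
    def _get_model_class(self):
        raise NotImplementedError()  # subclasses choose the encoder class

    def _load(self):
        # Loading is uniform once the subclass has chosen the class
        return self._get_model_class(
            self._model_name, device=self._device, trust_remote_code=True
        )

class DenseFunction(SentenceTransformerFunctionBase):
    @property
    def _get_model_class(self):
        from sentence_transformers import SentenceTransformer
        return SentenceTransformer

class SparseFunction(SentenceTransformerFunctionBase):
    @property
    def _get_model_class(self):
        from sentence_transformers import SparseEncoder
        return SparseEncoder
```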