
Commit 68e155a

Refactor: extract embedding logic into a LlamaEmbedding class, add rerank support, and fix parallel batching
- Decoupled embedding and rerank logic into `llama_embedding.py`.
- Implemented streaming batching for constant memory usage.
- Fixed parallel batching errors (such as "multiple embeddings in a single call") by enabling `kv_unified`.
- Added native `rank()` support for reranker models.
- Added advanced normalization support (Euclidean, Taxicab, MaxInt16).
- Added `array` and `json+` output formats for raw vector access.

The legacy embedding implementation in `llama.py` is now superseded by this optimized approach.

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent a11d97a
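For context, a minimal migration sketch from the deprecated `Llama` methods to the new class (the model path is hypothetical; `LlamaEmbedding` and its methods are defined in `llama_cpp/llama_embedding.py` below):

    from llama_cpp.llama_embedding import LlamaEmbedding

    # Before: Llama(model_path=..., embedding=True).embed([...])  (now deprecated)
    # After (hypothetical GGUF path):
    emb = LlamaEmbedding(model_path="./models/bge-m3-Q8_0.gguf")
    vectors = emb.embed(["first document", "second document"])  # List[List[float]], L2-normalized by default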

File tree

2 files changed: 358 additions, 0 deletions


llama_cpp/llama.py

Lines changed: 11 additions & 0 deletions
@@ -1031,6 +1031,11 @@ def create_embedding(
         Returns:
             An embedding object.
         """
+        warnings.warn(
+            "The `create_embedding` method in the `Llama` class is deprecated. "
+            "Please migrate to `LlamaEmbedding.create_embedding` for better efficiency.",
+            DeprecationWarning,
+        )
         model_name: str = model if model is not None else self.model_path

         input = input if isinstance(input, list) else [input]
@@ -1075,6 +1080,12 @@ def embed(
         Returns:
             A list of embeddings
         """
+        warnings.warn(
+            "The `embed` method in the `Llama` class is deprecated and will be removed in future versions. "
+            "Please use the `LlamaEmbedding` class from the `llama_embedding` module for optimized performance and reranking support.",
+            DeprecationWarning,
+        )
+
         n_embd = self.n_embd()
         n_batch = self.n_batch
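`DeprecationWarning` is hidden by default in many Python setups; a quick standard-library snippet to surface the new warnings while migrating (model path hypothetical):

    import warnings
    from llama_cpp import Llama

    warnings.simplefilter("always", DeprecationWarning)  # print every occurrence

    llm = Llama(model_path="./models/bge-m3-Q8_0.gguf", embedding=True)
    llm.embed("hello")  # now visibly emits the DeprecationWarning added above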

llama_cpp/llama_embedding.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
import numpy as np
from typing import Union, List, Optional, Dict, Any, Tuple
import llama_cpp.llama_cpp as llama_cpp
from .llama_types import Embedding
from .llama import Llama
# Pooling types from .llama_cpp
from .llama_cpp import (
    LLAMA_POOLING_TYPE_UNSPECIFIED,
    LLAMA_POOLING_TYPE_NONE,
    LLAMA_POOLING_TYPE_MEAN,
    LLAMA_POOLING_TYPE_CLS,
    LLAMA_POOLING_TYPE_LAST,
    LLAMA_POOLING_TYPE_RANK,  # Specifically for reranking models
)

# Normalization modes for embedding vectors
# See: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
NORM_MODE_NONE = -1
NORM_MODE_MAX_INT16 = 0
NORM_MODE_TAXICAB = 1
NORM_MODE_EUCLIDEAN = 2

# TODO(JamePeng): Needs more extensive testing with various embedding and reranking models.
class LlamaEmbedding(Llama):
    """
    A specialized class for high-performance text embedding and reranking.
    Inherits from the base Llama class but is optimized for vector operations.

    Key features:
    1. Auto-configuration: automatically sets embedding=True.
    2. Streaming batching: handles massive datasets without running out of memory (OOM).
    3. Native reranking support: handles `LLAMA_POOLING_TYPE_RANK` models (such as BGE-Reranker),
       correctly identifying classification heads to output scalar relevance scores instead of
       high-dimensional vectors.
    4. Advanced normalization: implements MaxInt16, Taxicab (L1), and Euclidean (L2) normalization
       strategies using NumPy for performance and compatibility with various vector databases.
    """

    def __init__(self, model_path: str, pooling_type: int = LLAMA_POOLING_TYPE_UNSPECIFIED, **kwargs):
        """
        Initialize the embedding model with enforced configuration.

        Args:
            model_path: Path to the GGUF model file.
            pooling_type: The pooling strategy used by the model.
                - Use `LLAMA_POOLING_TYPE_RANK` (4) for reranker models.
                - Use `LLAMA_POOLING_TYPE_UNSPECIFIED` (-1) to let the model metadata decide
                  (for standard embedding models).
            **kwargs: Additional arguments passed to the Llama base class
                (e.g., n_gpu_layers, n_batch, n_ctx).
        """
        kwargs["embedding"] = True

        # Enable the unified KV cache (crucial for batching).
        # This lets us assign arbitrary seq_ids in a batch, enabling parallel
        # encoding of multiple unrelated documents without "invalid seq_id" errors.
        kwargs["kv_unified"] = True

        # Set pooling type
        kwargs["pooling_type"] = pooling_type

        super().__init__(model_path=model_path, **kwargs)

        if self.verbose:
            print(f"LlamaEmbedding initialized with pooling_type: {self.pooling_type()}")

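Two typical constructions, as a sketch (model paths are hypothetical; the pooling constant is the one imported at the top of this file):

    # Standard embedding model: let the GGUF metadata decide the pooling strategy.
    embedder = LlamaEmbedding("./models/bge-m3-Q8_0.gguf")

    # Reranker (cross-encoder): force RANK pooling explicitly.
    reranker = LlamaEmbedding(
        "./models/bge-reranker-v2-m3-Q8_0.gguf",
        pooling_type=LLAMA_POOLING_TYPE_RANK,
    )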
    def _normalize_vector(self, vector: List[float], mode: int) -> List[float]:
        """
        Apply mathematical normalization to a vector.
        Uses NumPy for performance.
        """
        if mode == NORM_MODE_NONE:
            return vector

        arr = np.array(vector, dtype=np.float32)

        # Mode 0: max absolute int16 -> 32760 * x_i / max|x_i|
        if mode == NORM_MODE_MAX_INT16:
            max_abs = np.max(np.abs(arr))
            if max_abs == 0:
                return vector
            return ((arr / max_abs) * 32760.0).tolist()

        # Mode 1: Taxicab (L1 norm) -> x_i / sum|x_i|
        elif mode == NORM_MODE_TAXICAB:
            norm = np.sum(np.abs(arr))
            if norm == 0:
                return vector
            return (arr / norm).tolist()

        # Mode 2: Euclidean (L2 norm) -> x_i / sqrt(sum x_i^2)
        elif mode == NORM_MODE_EUCLIDEAN:
            norm = np.linalg.norm(arr)
            if norm == 0:
                return vector
            return (arr / norm).tolist()

        # Mode > 2: p-norm
        elif mode > 2:
            norm = np.sum(np.abs(arr) ** mode) ** (1.0 / mode)
            if norm == 0:
                return vector
            return (arr / norm).tolist()

        return vector

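A quick sanity check of the normalization math (illustrative values only, not part of the commit):

    import numpy as np

    v = [3.0, 4.0]
    # L2 (Euclidean): v / sqrt(3^2 + 4^2) = [0.6, 0.8], unit length
    assert np.allclose(np.array(v) / np.linalg.norm(v), [0.6, 0.8])
    # L1 (Taxicab): v / (|3| + |4|) = [3/7, 4/7]
    assert np.allclose(np.array(v) / np.sum(np.abs(v)), [3 / 7, 4 / 7])
    # MaxInt16: 32760 * v / max|v| = [24570.0, 32760.0]
    assert np.allclose(32760.0 * np.array(v) / np.max(np.abs(v)), [24570.0, 32760.0])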
    def embed(
        self,
        input: Union[str, List[str], List[List[int]]],
        normalize: int = NORM_MODE_EUCLIDEAN,
        truncate: bool = True,
        separator: Optional[str] = None,
        return_count: bool = False,
    ) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
        ctx = self._ctx.ctx
        n_batch = self.n_batch
        n_ctx = self._n_ctx
        n_ubatch = self.context_params.n_ubatch

        if self.verbose:
            print(f"n_batch={n_batch}, n_ubatch={n_ubatch}, n_ctx={n_ctx}")

        # Determine whether we are in rerank mode
        try:
            current_pooling = self.pooling_type()
        except AttributeError:
            current_pooling = LLAMA_POOLING_TYPE_UNSPECIFIED
        is_rank = current_pooling == LLAMA_POOLING_TYPE_RANK
        logits_all = current_pooling == llama_cpp.LLAMA_POOLING_TYPE_NONE

        # Determine the output dimension
        if is_rank:
            out_dim = llama_cpp.llama_model_n_cls_out(self._model.model)
        else:
            out_dim = self.n_embd()

        if self.verbose:
            mode_str = "RANK (Score)" if is_rank else "EMBED (Vector)"
            print(f"LlamaEmbedding Debug: Mode={mode_str} | Output Dimension={out_dim}")

        # Preprocess input
        inputs: List[Union[str, List[int]]] = []
        is_single = False

        if isinstance(input, str):
            if separator:
                inputs = input.split(separator)
                is_single = False
            else:
                inputs = [input]
                is_single = True
        else:
            inputs = input
            is_single = False

        # Reset context and batch
        if self.verbose:
            llama_cpp.llama_perf_context_reset(ctx)
        self._batch.reset()
        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)

        # Initialize state variables
        results: List[Any] = []
        batch_seq_lens: List[int] = []
        total_tokens_processed = 0

        # --- Decode the current batch ---
        def _decode_batch():
            nonlocal batch_seq_lens
            if not batch_seq_lens:
                return

            self._ctx.decode(self._batch)

            for i in range(len(batch_seq_lens)):
                ptr = llama_cpp.llama_get_embeddings_seq(ctx, i)
                data = ptr[:out_dim]

                if not is_rank:
                    data = self._normalize_vector(data, normalize)

                if is_rank and len(data) == 1:
                    results.append(data[0])
                else:
                    results.append(data)

            self._batch.reset()
            llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
            batch_seq_lens = []

        # Main streaming loop
        idx_in_batch = 0

        for item in inputs:
            # Tokenize
            tokens: List[int] = []
            if isinstance(item, list) and (not item or isinstance(item[0], int)):
                tokens = item
            elif isinstance(item, str):
                tokens = self.tokenize(item.encode("utf-8"))
            else:
                raise ValueError("Input item must be str or List[int]")

            # Truncate
            if truncate and len(tokens) > n_ctx:
                tokens = tokens[:n_ctx]

            n_tokens = len(tokens)
            total_tokens_processed += n_tokens

            if n_tokens == 0:
                results.append(0.0 if is_rank else [])
                continue

            # Check batch capacity
            if (self._batch.n_tokens() + n_tokens > n_batch) or (idx_in_batch >= n_ubatch):
                _decode_batch()
                idx_in_batch = 0

            # Add to batch
            self._batch.add_sequence(tokens, idx_in_batch, logits_all=logits_all)
            batch_seq_lens.append(n_tokens)
            idx_in_batch += 1

        # Process remaining items
        _decode_batch()

        if self.verbose:
            llama_cpp.llama_perf_context_print(ctx)

        final_result = results[0] if is_single else results

        if return_count:
            return final_result, total_tokens_processed

        return final_result

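A usage sketch for the streaming path, reusing the hypothetical `embedder` from the constructor sketch above:

    # Batches are decoded as they fill, so memory stays bounded for large inputs.
    vecs, n_tokens = embedder.embed(
        ["llama.cpp supports GGUF models.", "Embeddings map text to vectors."],
        normalize=NORM_MODE_EUCLIDEAN,
        return_count=True,
    )

    # A single string plus a separator is split into multiple embeddings.
    chunks = embedder.embed("part one\n---\npart two", separator="\n---\n")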
    def rank(self, query: str, documents: List[str]) -> List[float]:
        """
        Calculate relevance scores for a list of documents against a query using a reranking model.

        This method constructs the prompt structure ([BOS] Query [SEP] Doc [EOS])
        typically used by cross-encoders to estimate similarity.

        Args:
            query: The search query string.
            documents: A list of candidate document strings to be scored.

        Returns:
            A list of float scores, where higher values indicate greater relevance.
        """
        if self.pooling_type() != LLAMA_POOLING_TYPE_RANK:
            raise ValueError(
                f"Model pooling_type is {self.pooling_type()}, but LLAMA_POOLING_TYPE_RANK is required."
            )

        # Prepare special tokens
        sep_id = self.token_sep()
        if sep_id == -1:
            sep_id = self.token_eos()
        eos_id = self.token_eos()

        # Pre-process the query
        q_tokens = self.tokenize(query.encode("utf-8"), add_bos=True, special=True)
        # Remove an automatically appended EOS token from the query,
        # because the separator and document tokens must come after it.
        if q_tokens and q_tokens[-1] == eos_id:
            q_tokens.pop()

        # Construct batch inputs
        batch_inputs: List[List[int]] = []
        for doc in documents:
            d_tokens = self.tokenize(doc.encode("utf-8"), add_bos=False, special=True)
            full_seq = q_tokens + [sep_id] + d_tokens
            # Ensure the sequence ends with an EOS token to mark the end of inference.
            if not full_seq or full_seq[-1] != eos_id:
                full_seq.append(eos_id)
            batch_inputs.append(full_seq)

        # Use NORM_MODE_NONE because rerankers output raw logits/scores,
        # not vectors that need normalization.
        return self.embed(batch_inputs, normalize=NORM_MODE_NONE)

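Continuing with the hypothetical `reranker` constructed earlier:

    scores = reranker.rank(
        "What do pandas eat?",
        ["Giant pandas feed almost entirely on bamboo.", "Paris is the capital of France."],
    )
    # Higher score means more relevant; sort candidate documents by score, descending.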
    def create_embedding(
        self,
        input: Union[str, List[str]],
        model: Optional[str] = None,
        normalize: int = NORM_MODE_EUCLIDEAN,
        output_format: str = "json",
    ) -> Union[Dict[str, Any], List[float], List[List[float]]]:
        """
        High-level API compatible with the OpenAI format.

        Args:
            output_format:
                - 'json': OpenAI-style dict (default)
                - 'json+': OpenAI-style dict plus a cosine-similarity matrix
                - 'array': raw Python list (List[float] or List[List[float]])
        """
        model_name = model if model is not None else self.model_path

        # Normalize input to a list
        inputs_list = [input] if isinstance(input, str) else input

        # Generate embeddings (and get the token count)
        embeddings, token_count = self.embed(
            inputs_list,
            normalize=normalize,
            return_count=True,
        )

        if output_format == "array":
            return embeddings

        # Structure the OpenAI-style response ('json' or 'json+').
        # Ensure embeddings is a list of vectors for iteration
        # (if a single flat vector came back, wrap it for the loop).
        iter_embeddings = [embeddings] if embeddings and isinstance(embeddings[0], float) else embeddings

        data: List[Embedding] = [
            {
                "object": "embedding",
                "embedding": emb,
                "index": idx,
            }
            for idx, emb in enumerate(iter_embeddings)
        ]

        response = {
            "object": "list",
            "data": data,
            "model": model_name,
            "usage": {
                "prompt_tokens": token_count,  # Input consumption
                "completion_tokens": 0,  # The embedding task generates no text, so this is 0.
                "total_tokens": token_count,  # Total consumption = input + output
            },
        }

        # Calculate the cosine-similarity matrix (optimized via NumPy),
        # only if output_format is 'json+' and we have multiple vectors.
        if output_format == "json+" and len(embeddings) > 1 and isinstance(embeddings[0], list):
            try:
                # Embeddings are already L2-normalized if normalize == NORM_MODE_EUCLIDEAN.
                mat = np.array(embeddings)

                # Safety check: force-normalize if not already done,
                # to ensure cosine similarity rather than a raw dot product.
                if normalize != NORM_MODE_EUCLIDEAN:
                    norm = np.linalg.norm(mat, axis=1, keepdims=True)
                    # Avoid division by zero
                    norm[norm == 0] = 1e-10
                    mat = mat / norm

                # Matrix multiplication: A @ A.T
                sim_matrix = np.dot(mat, mat.T)
                response["cosineSimilarity"] = sim_matrix.tolist()
            except Exception as e:
                if self.verbose:
                    print(f"Warning: Failed to calculate similarity matrix: {e}")

        return response
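Finally, a sketch of the 'json+' format, again with the hypothetical `embedder` (the `cosineSimilarity` key is the one attached in the code above):

    resp = embedder.create_embedding(
        ["bamboo forests", "panda habitats", "stock markets"],
        output_format="json+",
    )
    print(resp["usage"]["prompt_tokens"])   # tokens consumed by the inputs
    for row in resp["cosineSimilarity"]:    # 3x3 symmetric matrix, ~1.0 on the diagonal
        print([f"{s:.2f}" for s in row])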
