Skip to content

Commit fcdf9ac

Browse files
committed
Added n_ctx,n_batch,n_ubatch parameter into LlamaEmbedding
1 parent fb08475 commit fcdf9ac

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

llama_cpp/llama_embedding.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,23 @@ class LlamaEmbedding(Llama):
3636
using NumPy for optimal performance and compatibility with various vector databases.
3737
"""
3838

39-
def __init__(self, model_path: str, pooling_type: int = LLAMA_POOLING_TYPE_UNSPECIFIED, n_gpu_layers: int = 0, **kwargs):
39+
def __init__(
40+
self,
41+
model_path: str,
42+
n_ctx: int = 1024,
43+
n_batch: int = 512,
44+
n_ubatch: int = 512,
45+
pooling_type: int = LLAMA_POOLING_TYPE_UNSPECIFIED,
46+
n_gpu_layers: int = 0,
47+
**kwargs):
4048
"""
4149
Initialize the embedding model with enforced configuration.
4250
4351
Args:
4452
model_path: Path to the GGUF model file.
53+
n_ctx: Text context, 0 = from model
54+
n_batch: Prompt processing maximum batch size
55+
n_ubatch: Physical batch size
4556
pooling_type: The pooling strategy used by the model.
4657
- Use `LLAMA_POOLING_TYPE_RANK` (4) for Reranker models.
4758
- Use `LLAMA_POOLING_TYPE_UNSPECIFIED` (-1) to let the model metadata decide (for standard embeddings).
@@ -51,15 +62,16 @@ def __init__(self, model_path: str, pooling_type: int = LLAMA_POOLING_TYPE_UNSPE
5162
**kwargs: Additional arguments passed to the Llama base class (e.g., n_batch, n_ctx, verbose).
5263
"""
5364
kwargs["embedding"] = True
65+
kwargs["n_gpu_layers"] = n_gpu_layers
66+
kwargs["n_ctx"] = n_ctx
67+
kwargs["n_batch"] = n_batch
68+
kwargs["n_ubatch"] = n_ubatch
5469

5570
# Enable Unified KV Cache (Crucial for Batching)
5671
# This allows us to assign arbitrary seq_ids in a batch, enabling the parallel /
5772
# encoding of multiple unrelated documents without "invalid seq_id" errors.
5873
kwargs["kv_unified"] = True
5974

60-
# Number of model layers to offload to GPU.
61-
kwargs["n_gpu_layers"] = n_gpu_layers
62-
6375
# Set pooling type
6476
kwargs["pooling_type"] = pooling_type
6577

0 commit comments

Comments
 (0)