
Commit b3cc7c7

Merge #1104
1104: Add composite embedders and pooling for hf models r=brunoocasali a=nnethercott

# Pull Request

## Related issue

Fixes #1099

## What does this PR do?

- Adds a new embedder of type `CompositeEmbedder` and adds `pooling: PoolingType` to `HuggingFaceEmbedder`s
- Adds a new pytest fixture enabling the experimental feature "compositeEmbedders" and a basic test to make sure the client can configure a composite embedder through a POST to `/indexes/{index_uid}/settings/embedders`

## PR checklist

Please check if your PR fulfills the following requirements:

- [x] Does this PR fix an existing issue, or have you listed the changes applied in the PR description (and why they are needed)?
- [x] Have you read the contributing guidelines?
- [x] Have you made sure that the title is accurate and descriptive of the changes?

Thank you so much for contributing to Meilisearch!

## Summary by CodeRabbit

- **New Features**
  - Added support for composite embedders, enabling the use of different embedders for indexing and searching.
  - Introduced new pooling options for HuggingFace embedders to customize how token embeddings are aggregated.
- **Tests**
  - Added tests to validate the configuration and behavior of composite embedders.
  - Introduced a fixture to toggle the composite embedders feature during testing.

Co-authored-by: nnethercott <nathaniel@deepomatic.com>
Co-authored-by: Nate Nethercott <53127799+nnethercott@users.noreply.github.com>
2 parents 3f787ce + fd0d285 commit b3cc7c7
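For context, here is a minimal end-to-end sketch of what this PR enables from the Python client. It is not part of the diff: the server URL, API key, and index name are placeholder assumptions, and the `compositeEmbedders` experimental feature must already be enabled on the server (as the new fixture below does).

```python
import meilisearch
from meilisearch.models.embedders import HuggingFaceEmbedder

# Placeholder connection details and index name.
client = meilisearch.Client("http://localhost:7700", "masterKey")
index = client.index("movies")

# Reuse one HuggingFace configuration for both sides, serialized to the
# camelCase payload the /settings/embedders route expects.
hf = HuggingFaceEmbedder().model_dump(by_alias=True, exclude_none=True)

task = index.update_embedders(
    {
        "composite": {
            "source": "composite",
            "searchEmbedder": hf,
            "indexingEmbedder": hf,
        }
    }
)
index.wait_for_task(task.task_uid)
```

This mirrors the new test added in `tests/settings/test_settings_embedders.py` below.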

File tree

4 files changed: +136 additions, -1 deletion

- meilisearch/index.py
- meilisearch/models/embedders.py
- tests/conftest.py
- tests/settings/test_settings_embedders.py

meilisearch/index.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
 from meilisearch.errors import version_error_hint_message
 from meilisearch.models.document import Document, DocumentsResults
 from meilisearch.models.embedders import (
+    CompositeEmbedder,
     Embedders,
     EmbedderType,
     HuggingFaceEmbedder,
@@ -977,6 +978,8 @@ def get_settings(self) -> Dict[str, Any]:
                     embedders[k] = HuggingFaceEmbedder(**v)
                 elif v.get("source") == "rest":
                     embedders[k] = RestEmbedder(**v)
+                elif v.get("source") == "composite":
+                    embedders[k] = CompositeEmbedder(**v)
                 else:
                     embedders[k] = UserProvidedEmbedder(**v)
@@ -1934,6 +1937,8 @@ def get_embedders(self) -> Embedders | None:
                 embedders[k] = OllamaEmbedder(**v)
             elif source == "rest":
                 embedders[k] = RestEmbedder(**v)
+            elif source == "composite":
+                embedders[k] = CompositeEmbedder(**v)
             elif source == "userProvided":
                 embedders[k] = UserProvidedEmbedder(**v)
             else:
@@ -1977,6 +1982,8 @@ def update_embedders(self, body: Union[MutableMapping[str, Any], None]) -> TaskI
                 embedders[k] = OllamaEmbedder(**v)
             elif source == "rest":
                 embedders[k] = RestEmbedder(**v)
+            elif source == "composite":
+                embedders[k] = CompositeEmbedder(**v)
             elif source == "userProvided":
                 embedders[k] = UserProvidedEmbedder(**v)
             else:
```
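Not from the PR, just a quick sketch of what the new `source == "composite"` branches buy you when reading settings back (assumes `index` is an `Index` that already has an embedder named `"composite"` configured):

```python
from meilisearch.models.embedders import CompositeEmbedder

# get_embedders() now returns a typed CompositeEmbedder instead of falling
# through to the generic UserProvidedEmbedder branch.
embedders = index.get_embedders()
composite = embedders.embedders["composite"]

assert isinstance(composite, CompositeEmbedder)
print(type(composite.search_embedder), type(composite.indexing_embedder))
```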

meilisearch/models/embedders.py

Lines changed: 62 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from enum import Enum
 from typing import Any, Dict, Optional, Union

 from camel_converter.pydantic_base import CamelBase
@@ -20,6 +21,24 @@ class Distribution(CamelBase):
     sigma: float


+class PoolingType(str, Enum):
+    """Pooling strategies for HuggingFaceEmbedder.
+
+    Attributes
+    ----------
+    USE_MODEL : str
+        Use the model's default pooling strategy.
+    FORCE_MEAN : str
+        Force mean pooling over the token embeddings.
+    FORCE_CLS : str
+        Use the [CLS] token embedding as the sentence representation.
+    """
+
+    USE_MODEL = "useModel"
+    FORCE_MEAN = "forceMean"
+    FORCE_CLS = "forceCls"
+
+
 class OpenAiEmbedder(CamelBase):
     """OpenAI embedder configuration.
@@ -79,6 +98,8 @@ class HuggingFaceEmbedder(CamelBase):
         Describes the natural distribution of search results
     binary_quantized: Optional[bool]
        Once set to true, irreversibly converts all vector dimensions to 1-bit values
+    pooling: Optional[PoolingType]
+        Configures how individual tokens are merged into a single embedding
     """

     source: str = "huggingFace"
@@ -90,6 +111,7 @@
     document_template_max_bytes: Optional[int] = None  # Default to 400
     distribution: Optional[Distribution] = None
     binary_quantized: Optional[bool] = None
+    pooling: Optional[PoolingType] = PoolingType.USE_MODEL


 class OllamaEmbedder(CamelBase):
@@ -191,13 +213,53 @@ class UserProvidedEmbedder(CamelBase):
     binary_quantized: Optional[bool] = None


+class CompositeEmbedder(CamelBase):
+    """Composite embedder configuration.
+
+    Parameters
+    ----------
+    source: str
+        The embedder source, must be "composite"
+    indexing_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]
+    search_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]"""
+
+    source: str = "composite"
+    search_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]
+    indexing_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]
+
+
 # Type alias for the embedder union type
 EmbedderType = Union[
     OpenAiEmbedder,
     HuggingFaceEmbedder,
     OllamaEmbedder,
     RestEmbedder,
     UserProvidedEmbedder,
+    CompositeEmbedder,
 ]
```
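The new models compose like any other of the client's pydantic settings objects. A small illustration, not from the PR: the field values here are arbitrary, and keyword construction uses the camelCase aliases, matching how the client itself instantiates `CompositeEmbedder(**v)`.

```python
from meilisearch.models.embedders import (
    CompositeEmbedder,
    HuggingFaceEmbedder,
    PoolingType,
    UserProvidedEmbedder,
)

# Override the default useModel strategy with forced mean pooling.
hf = HuggingFaceEmbedder(pooling=PoolingType.FORCE_MEAN)

# Different embedders for searching and indexing.
composite = CompositeEmbedder(
    searchEmbedder=UserProvidedEmbedder(dimensions=384),
    indexingEmbedder=hf,
)

# by_alias=True dumps the camelCase keys the Meilisearch API expects;
# the pooling field carries the str-enum value "forceMean".
payload = composite.model_dump(by_alias=True, exclude_none=True)
print(payload)
```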

tests/conftest.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -274,3 +274,20 @@ def new_embedders():
         "default": UserProvidedEmbedder(dimensions=1).model_dump(by_alias=True),
         "open_ai": OpenAiEmbedder().model_dump(by_alias=True),
     }
+
+
+@fixture
+def enable_composite_embedders():
+    requests.patch(
+        f"{common.BASE_URL}/experimental-features",
+        headers={"Authorization": f"Bearer {common.MASTER_KEY}"},
+        json={"compositeEmbedders": True},
+        timeout=10,
+    )
+    yield
+    requests.patch(
+        f"{common.BASE_URL}/experimental-features",
+        headers={"Authorization": f"Bearer {common.MASTER_KEY}"},
+        json={"compositeEmbedders": False},
+        timeout=10,
+    )
```
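Tests opt into the experimental feature simply by requesting this fixture, as the new test below does (the test name here is just a placeholder):

```python
@pytest.mark.usefixtures("enable_composite_embedders")
def test_something_with_composite_embedders(empty_index):
    index = empty_index()
    ...
```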

tests/settings/test_settings_embedders.py

Lines changed: 50 additions & 1 deletion
```diff
@@ -1,6 +1,14 @@
 # pylint: disable=redefined-outer-name

-from meilisearch.models.embedders import OpenAiEmbedder, UserProvidedEmbedder
+import pytest
+
+from meilisearch.models.embedders import (
+    CompositeEmbedder,
+    HuggingFaceEmbedder,
+    OpenAiEmbedder,
+    PoolingType,
+    UserProvidedEmbedder,
+)


 def test_get_default_embedders(empty_index):
@@ -97,6 +105,7 @@ def test_huggingface_embedder_format(empty_index):
     assert embedders.embedders["huggingface"].distribution.mean == 0.5
     assert embedders.embedders["huggingface"].distribution.sigma == 0.1
     assert embedders.embedders["huggingface"].binary_quantized is False
+    assert embedders.embedders["huggingface"].pooling is PoolingType.USE_MODEL


 def test_ollama_embedder_format(empty_index):
@@ -183,3 +192,43 @@ def test_user_provided_embedder_format(empty_index):
     assert embedders.embedders["user_provided"].distribution.mean == 0.5
     assert embedders.embedders["user_provided"].distribution.sigma == 0.1
     assert embedders.embedders["user_provided"].binary_quantized is False
+
+
+@pytest.mark.usefixtures("enable_composite_embedders")
+def test_composite_embedder_format(empty_index):
+    """Tests that CompositeEmbedder embedder has the required fields and proper format."""
+    index = empty_index()
+
+    embedder = HuggingFaceEmbedder().model_dump(by_alias=True, exclude_none=True)
+
+    # create composite embedder
+    composite_embedder = {
+        "composite": {
+            "source": "composite",
+            "searchEmbedder": embedder,
+            "indexingEmbedder": embedder,
+        }
+    }
+
+    response = index.update_embedders(composite_embedder)
+    update = index.wait_for_task(response.task_uid)
+    embedders = index.get_embedders()
+    assert update.status == "succeeded"
+
+    assert embedders.embedders["composite"].source == "composite"
+
+    # ensure serialization roundtrips nicely
+    assert isinstance(embedders.embedders["composite"], CompositeEmbedder)
+    assert isinstance(embedders.embedders["composite"].search_embedder, HuggingFaceEmbedder)
+    assert isinstance(embedders.embedders["composite"].indexing_embedder, HuggingFaceEmbedder)
+
+    # ensure search_embedder has no document_template
+    assert getattr(embedders.embedders["composite"].search_embedder, "document_template") is None
+    assert (
+        getattr(
+            embedders.embedders["composite"].search_embedder,
+            "document_template_max_bytes",
+        )
+        is None
+    )
+    assert getattr(embedders.embedders["composite"].indexing_embedder, "document_template")
```
