Skip to content

Commit 1ae0a5b

Browse files
committed
[ENH] Add support for default space in create coll config
1 parent 5477bcf commit 1ae0a5b

17 files changed

+737
-63
lines changed

chromadb/api/async_fastapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ async def create_collection(
297297
) -> CollectionModel:
298298
"""Creates a collection"""
299299
config_json = (
300-
create_collection_configuration_to_json(configuration)
300+
create_collection_configuration_to_json(configuration, metadata)
301301
if configuration
302302
else None
303303
)

chromadb/api/collection_configuration.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from multiprocessing import cpu_count
1414
import warnings
15+
from chromadb.errors import InvalidSpaceError
1516

1617

1718
class HNSWConfiguration(TypedDict, total=False):
@@ -342,14 +343,16 @@ def load_create_collection_configuration_from_json(
342343

343344
def create_collection_configuration_to_json_str(
344345
config: CreateCollectionConfiguration,
346+
metadata: Optional[CollectionMetadata] = None,
345347
) -> str:
346348
"""Convert a CreateCollection configuration to a JSON-serializable string"""
347-
return json.dumps(create_collection_configuration_to_json(config))
349+
return json.dumps(create_collection_configuration_to_json(config, metadata))
348350

349351

350352
# TODO: make warnings prettier and add link to migration docs
351353
def create_collection_configuration_to_json(
352354
config: CreateCollectionConfiguration,
355+
metadata: Optional[CollectionMetadata] = None,
353356
) -> Dict[str, Any]:
354357
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
355358
ef_config: Dict[str, Any] | None = None
@@ -383,12 +386,47 @@ def create_collection_configuration_to_json(
383386
if ef.is_legacy():
384387
ef_config = {"type": "legacy"}
385388
else:
389+
if hnsw_config is None and spann_config is None:
390+
if metadata is None or metadata.get("hnsw:space") is None:
391+
# No HNSW/SPANN config and no legacy "hnsw:space" metadata: create an HNSW config using the embedding function's default space
392+
hnsw_config = CreateHNSWConfiguration(space=ef.default_space())
393+
394+
# If an HNSW or SPANN config exists but does not set a space, fill it in from the embedding function's default space
395+
if hnsw_config is not None and hnsw_config.get("space") is None:
396+
hnsw_config["space"] = ef.default_space()
397+
if spann_config is not None and spann_config.get("space") is None:
398+
spann_config["space"] = ef.default_space()
399+
400+
# Validate space compatibility with embedding function
401+
if hnsw_config is not None:
402+
if hnsw_config.get("space") not in ef.supported_spaces():
403+
raise InvalidSpaceError(
404+
f"space {hnsw_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}"
405+
)
406+
if spann_config is not None:
407+
if spann_config.get("space") not in ef.supported_spaces():
408+
raise InvalidSpaceError(
409+
f"space {spann_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}"
410+
)
411+
if (
412+
hnsw_config is None
413+
and spann_config is None
414+
and metadata is not None
415+
and metadata.get("hnsw:space") is not None
416+
):
417+
if metadata.get("hnsw:space") not in ef.supported_spaces():
418+
raise InvalidSpaceError(
419+
f"space {metadata.get('hnsw:space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}"
420+
)
421+
386422
ef_config = {
387423
"name": ef.name(),
388424
"type": "known",
389425
"config": ef.get_config(),
390426
}
391427
register_embedding_function(type(ef)) # type: ignore
428+
except InvalidSpaceError:
429+
raise
392430
except Exception as e:
393431
warnings.warn(
394432
f"legacy embedding function config: {e}",

chromadb/api/fastapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ def create_collection(
258258
json={
259259
"name": name,
260260
"metadata": metadata,
261-
"configuration": create_collection_configuration_to_json(configuration)
261+
"configuration": create_collection_configuration_to_json(
262+
configuration, metadata
263+
)
262264
if configuration
263265
else None,
264266
"get_or_create": get_or_create,

chromadb/api/rust.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def create_collection(
219219
)
220220
if configuration:
221221
configuration_json_str = create_collection_configuration_to_json_str(
222-
configuration
222+
configuration, metadata
223223
)
224224
else:
225225
configuration_json_str = None

chromadb/api/segment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def create_collection(
236236
name=name,
237237
metadata=metadata,
238238
configuration_json=create_collection_configuration_to_json(
239-
configuration or CreateCollectionConfiguration()
239+
configuration or CreateCollectionConfiguration(), metadata
240240
),
241241
tenant=tenant,
242242
database=database,

chromadb/db/impl/grpc/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
create_collection_configuration_to_json_str,
77
UpdateCollectionConfiguration,
88
update_collection_configuration_to_json_str,
9+
CollectionMetadata,
910
)
1011
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
1112
from chromadb.db.system import SysDB
@@ -327,7 +328,7 @@ def create_collection(
327328
id=id.hex,
328329
name=name,
329330
configuration_json_str=create_collection_configuration_to_json_str(
330-
configuration
331+
configuration, cast(CollectionMetadata, metadata)
331332
),
332333
metadata=to_proto_update_metadata(metadata) if metadata else None,
333334
dimension=dimension,

chromadb/db/mixins/sysdb.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
collection_configuration_to_json_str,
4444
overwrite_collection_configuration,
4545
update_collection_configuration_from_legacy_update_metadata,
46+
CollectionMetadata,
4647
)
4748

4849
logger = logging.getLogger(__name__)
@@ -311,7 +312,9 @@ def create_collection(
311312
collection = Collection(
312313
id=id,
313314
name=name,
314-
configuration_json=create_collection_configuration_to_json(configuration),
315+
configuration_json=create_collection_configuration_to_json(
316+
configuration, cast(CollectionMetadata, metadata)
317+
),
315318
metadata=metadata,
316319
dimension=dimension,
317320
tenant=tenant,
@@ -337,7 +340,9 @@ def create_collection(
337340
ParameterValue(self.uuid_to_db(collection["id"])),
338341
ParameterValue(collection["name"]),
339342
ParameterValue(
340-
create_collection_configuration_to_json_str(configuration)
343+
create_collection_configuration_to_json_str(
344+
configuration, cast(CollectionMetadata, metadata)
345+
)
341346
),
342347
ParameterValue(collection["dimension"]),
343348
# Get the database id for the database with the given name and tenant
@@ -941,7 +946,7 @@ def _insert_config_from_legacy_params(
941946
create_collection_config = CreateCollectionConfiguration()
942947
# Write the configuration into the database
943948
configuration_json_str = create_collection_configuration_to_json_str(
944-
create_collection_config
949+
create_collection_config, cast(CollectionMetadata, metadata)
945950
)
946951
q = (
947952
self.querybuilder()

chromadb/errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,17 @@ def name(cls) -> str:
173173
return "QuotaError"
174174

175175

176+
class InvalidSpaceError(ChromaError):
177+
@overrides
178+
def code(self) -> int:
179+
return 400
180+
181+
@classmethod
182+
@overrides
183+
def name(cls) -> str:
184+
return "InvalidSpaceError"
185+
186+
176187
error_types: Dict[str, Type[ChromaError]] = {
177188
"InvalidDimension": InvalidDimensionException,
178189
"InvalidArgumentError": InvalidArgumentError,
@@ -189,6 +200,7 @@ def name(cls) -> str:
189200
"UniqueConstraintError": UniqueConstraintError,
190201
"QuotaError": QuotaError,
191202
"InternalError": InternalError,
203+
"InvalidSpaceError": InvalidSpaceError,
192204
# Catch-all for any other errors
193205
"ChromaError": ChromaError,
194206
}

0 commit comments

Comments
 (0)