Skip to content

Commit 8f4f55e

Browse files
committed
[ENH] recognize and flush new metadata keys to schema on local compaction
1 parent c26b0a8 commit 8f4f55e

File tree

4 files changed

+146
-22
lines changed

4 files changed

+146
-22
lines changed

chromadb/test/api/test_schema_e2e.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ def test_schema_defaults_enable_indexed_operations(
363363
# Ensure underlying schema persisted across fetches
364364
reloaded = client.get_collection(collection.name)
365365
assert reloaded.schema is not None
366-
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
366+
if not is_spann_disabled_mode:
367+
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
367368

368369

369370
def test_get_or_create_and_get_collection_preserve_schema(
@@ -541,7 +542,8 @@ def test_schema_persistence_with_custom_overrides(
541542
reloaded_client = client_factories.create_client_from_system()
542543
reloaded_collection = reloaded_client.get_collection(name=collection.name)
543544
assert reloaded_collection.schema is not None
544-
assert reloaded_collection.schema.serialize_to_json() == expected_schema_json
545+
if not is_spann_disabled_mode:
546+
assert reloaded_collection.schema.serialize_to_json() == expected_schema_json
545547

546548
fetched = reloaded_collection.get(where={"title": "Schema Persistence"})
547549
assert set(fetched["ids"]) == {"persist-1"}
@@ -784,7 +786,6 @@ def _expect_disabled_error(operation: Callable[[], Any]) -> None:
784786
_expect_disabled_error(operation)
785787

786788

787-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
788789
def test_schema_discovers_new_keys_after_compaction(
789790
client_factories: "ClientFactories",
790791
) -> None:
@@ -802,7 +803,8 @@ def test_schema_discovers_new_keys_after_compaction(
802803

803804
collection.add(ids=ids, documents=documents, metadatas=metadatas)
804805

805-
wait_for_version_increase(client, collection.name, initial_version)
806+
if not is_spann_disabled_mode:
807+
wait_for_version_increase(client, collection.name, initial_version)
806808

807809
reloaded = client.get_collection(collection.name)
808810
assert reloaded.schema is not None
@@ -828,7 +830,8 @@ def test_schema_discovers_new_keys_after_compaction(
828830
metadatas=upsert_metadatas,
829831
)
830832

831-
wait_for_version_increase(client, collection.name, next_version)
833+
if not is_spann_disabled_mode:
834+
wait_for_version_increase(client, collection.name, next_version)
832835

833836
post_upsert = client.get_collection(collection.name)
834837
assert post_upsert.schema is not None
@@ -852,7 +855,6 @@ def test_schema_discovers_new_keys_after_compaction(
852855
assert "discover_upsert" in persisted.schema.keys
853856

854857

855-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
856858
def test_schema_rejects_conflicting_discoverable_key_types(
857859
client_factories: "ClientFactories",
858860
) -> None:
@@ -868,7 +870,8 @@ def test_schema_rejects_conflicting_discoverable_key_types(
868870
documents = [f"doc {i}" for i in range(251)]
869871
collection.add(ids=ids, documents=documents, metadatas=metadatas)
870872

871-
wait_for_version_increase(client, collection.name, initial_version)
873+
if not is_spann_disabled_mode:
874+
wait_for_version_increase(client, collection.name, initial_version)
872875

873876
collection.upsert(
874877
ids=["conflict-bad"],
@@ -1029,7 +1032,6 @@ def test_schema_embedding_configuration_enforced(
10291032
assert "sparse_auto" not in numeric_metadata
10301033

10311034

1032-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
10331035
def test_schema_precedence_for_overrides_discoverables_and_defaults(
10341036
client_factories: "ClientFactories",
10351037
) -> None:
@@ -1054,7 +1056,9 @@ def test_schema_precedence_for_overrides_discoverables_and_defaults(
10541056

10551057
initial_version = get_collection_version(client, collection.name)
10561058
collection.add(ids=ids, documents=documents, metadatas=metadatas)
1057-
wait_for_version_increase(client, collection.name, initial_version)
1059+
1060+
if not is_spann_disabled_mode:
1061+
wait_for_version_increase(client, collection.name, initial_version)
10581062

10591063
schema_state = client.get_collection(collection.name).schema
10601064
assert schema_state is not None

rust/log/src/local_compaction_manager.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ impl Handler<BackfillMessage> for LocalCompactionManager {
140140
.sysdb
141141
.get_collection_with_segments(message.collection_id)
142142
.await?;
143+
let schema_previously_persisted = collection_and_segments.collection.schema.is_some();
143144
collection_and_segments
144145
.collection
145146
.reconcile_schema_with_config(KnnIndex::Hnsw)?;
@@ -206,14 +207,31 @@ impl Handler<BackfillMessage> for LocalCompactionManager {
206207
.begin()
207208
.await
208209
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
209-
metadata_writer
210+
let apply_outcome = metadata_writer
210211
.apply_logs(
211212
mt_data_chunk,
212213
collection_and_segments.metadata_segment.id,
214+
if schema_previously_persisted {
215+
collection_and_segments.collection.schema.clone()
216+
} else {
217+
None
218+
},
213219
&mut *tx,
214220
)
215221
.await
216222
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
223+
if schema_previously_persisted {
224+
if let Some(updated_schema) = apply_outcome.schema_update {
225+
metadata_writer
226+
.update_collection_schema(
227+
collection_and_segments.collection.collection_id,
228+
&updated_schema,
229+
&mut *tx,
230+
)
231+
.await
232+
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
233+
}
234+
}
217235
tx.commit()
218236
.await
219237
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;

rust/segment/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ roaring = { workspace = true }
1111
sea-query = { workspace = true }
1212
sea-query-binder = { workspace = true, features = ["sqlx-sqlite"] }
1313
serde = { workspace = true }
14+
serde_json = { workspace = true }
1415
sqlx = { workspace = true }
1516
serde-pickle = "1.2.0"
1617
tantivy = { workspace = true }

rust/segment/src/sqlite_metadata.rs

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ use chroma_error::{ChromaError, ErrorCodes};
77
use chroma_sqlite::{
88
db::SqliteDb,
99
helpers::{delete_metadata, update_metadata},
10-
table::{EmbeddingFulltextSearch, EmbeddingMetadata, Embeddings, MaxSeqId},
10+
table::{Collections, EmbeddingFulltextSearch, EmbeddingMetadata, Embeddings, MaxSeqId},
1111
};
1212
use chroma_types::{
1313
operator::{
1414
CountResult, Filter, GetResult, Limit, Projection, ProjectionOutput, ProjectionRecord, Scan,
1515
},
1616
plan::{Count, Get},
17-
BooleanOperator, Chunk, CompositeExpression, DocumentExpression, DocumentOperator, LogRecord,
18-
MetadataComparison, MetadataExpression, MetadataSetValue, MetadataValue,
19-
MetadataValueConversionError, Operation, OperationRecord, PrimitiveOperator, SegmentUuid,
20-
SetOperator, UpdateMetadataValue, Where, CHROMA_DOCUMENT_KEY,
17+
BooleanOperator, Chunk, CollectionUuid, CompositeExpression, DocumentExpression,
18+
DocumentOperator, LogRecord, MetadataComparison, MetadataExpression, MetadataSetValue,
19+
MetadataValue, MetadataValueConversionError, Operation, OperationRecord, PrimitiveOperator,
20+
Schema, SegmentUuid, SetOperator, UpdateMetadataValue, Where, CHROMA_DOCUMENT_KEY,
2121
};
2222
use sea_query::{
2323
Alias, DeleteStatement, Expr, ExprTrait, Func, InsertStatement, LikeExpr, OnConflict, Query,
@@ -41,6 +41,8 @@ pub enum SqliteMetadataError {
4141
SeaQuery(#[from] sea_query::error::Error),
4242
#[error(transparent)]
4343
Sqlx(#[from] sqlx::Error),
44+
#[error("Could not serialize schema: {0}")]
45+
SerializeSchema(#[from] serde_json::Error),
4446
}
4547

4648
impl ChromaError for SqliteMetadataError {
@@ -53,6 +55,10 @@ pub struct SqliteMetadataWriter {
5355
pub db: SqliteDb,
5456
}
5557

58+
pub struct ApplyLogsOutcome {
59+
pub schema_update: Option<Schema>,
60+
}
61+
5662
impl SqliteMetadataWriter {
5763
pub fn new(db: SqliteDb) -> Self {
5864
Self { db }
@@ -278,18 +284,66 @@ impl SqliteMetadataWriter {
278284
Ok(self.db.get_conn().begin().await?)
279285
}
280286

287+
pub async fn update_collection_schema<C>(
288+
&self,
289+
collection_id: CollectionUuid,
290+
schema: &Schema,
291+
tx: &mut C,
292+
) -> Result<(), SqliteMetadataError>
293+
where
294+
for<'connection> &'connection mut C: sqlx::Executor<'connection, Database = sqlx::Sqlite>,
295+
{
296+
let schema_str = serde_json::to_string(schema)?;
297+
let (sql, values) = Query::update()
298+
.table(Collections::Table)
299+
.value(Collections::SchemaStr, schema_str)
300+
.and_where(
301+
Expr::col((Collections::Table, Collections::Id)).eq(collection_id.to_string()),
302+
)
303+
.build_sqlx(SqliteQueryBuilder);
304+
sqlx::query_with(&sql, values).execute(&mut *tx).await?;
305+
Ok(())
306+
}
307+
308+
fn ensure_schema_for_update_value(
309+
schema: &mut Option<Schema>,
310+
key: &str,
311+
value: &UpdateMetadataValue,
312+
) -> bool {
313+
if key == CHROMA_DOCUMENT_KEY {
314+
return false;
315+
}
316+
match value {
317+
UpdateMetadataValue::None => false,
318+
_ => {
319+
if let Some(schema_mut) = schema.as_mut() {
320+
if let Ok(metadata_value) = MetadataValue::try_from(value) {
321+
return schema_mut
322+
.ensure_key_from_metadata(key, metadata_value.value_type());
323+
}
324+
}
325+
false
326+
}
327+
}
328+
}
329+
281330
pub async fn apply_logs<C>(
282331
&self,
283332
logs: Chunk<LogRecord>,
284333
segment_id: SegmentUuid,
334+
schema: Option<Schema>,
285335
tx: &mut C,
286-
) -> Result<(), SqliteMetadataError>
336+
) -> Result<ApplyLogsOutcome, SqliteMetadataError>
287337
where
288338
for<'connection> &'connection mut C: sqlx::Executor<'connection, Database = sqlx::Sqlite>,
289339
{
290340
if logs.is_empty() {
291-
return Ok(());
341+
return Ok(ApplyLogsOutcome {
342+
schema_update: None,
343+
});
292344
}
345+
let mut schema = schema;
346+
let mut schema_modified = false;
293347
let mut max_seq_id = u64::MIN;
294348
for (
295349
LogRecord {
@@ -323,6 +377,11 @@ impl SqliteMetadataWriter {
323377
Self::add_record(tx, segment_id, log_offset_unsigned, id.clone()).await?
324378
{
325379
if let Some(meta) = metadata_owned {
380+
for (key, value) in meta.iter() {
381+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
382+
schema_modified = true;
383+
}
384+
}
326385
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
327386
}
328387

@@ -336,6 +395,11 @@ impl SqliteMetadataWriter {
336395
Self::update_record(tx, segment_id, log_offset_unsigned, id.clone()).await?
337396
{
338397
if let Some(meta) = metadata_owned {
398+
for (key, value) in meta.iter() {
399+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
400+
schema_modified = true;
401+
}
402+
}
339403
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
340404
}
341405

@@ -351,6 +415,11 @@ impl SqliteMetadataWriter {
351415
.await?;
352416

353417
if let Some(meta) = metadata_owned {
418+
for (key, value) in meta.iter() {
419+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
420+
schema_modified = true;
421+
}
422+
}
354423
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
355424
}
356425

@@ -371,7 +440,9 @@ impl SqliteMetadataWriter {
371440

372441
Self::upsert_max_seq_id(tx, segment_id, max_seq_id).await?;
373442

374-
Ok(())
443+
Ok(ApplyLogsOutcome {
444+
schema_update: if schema_modified { schema } else { None },
445+
})
375446
}
376447
}
377448

@@ -910,7 +981,17 @@ mod tests {
910981
ref_seg.apply_logs(test_data.logs.clone(), metadata_seg_id);
911982
let mut tx = runtime.block_on(sqlite_seg_writer.begin()).expect("Should be able to start transaction");
912983
let data: Chunk<LogRecord> = Chunk::new(test_data.logs.clone().into());
913-
runtime.block_on(sqlite_seg_writer.apply_logs(data, metadata_seg_id, &mut *tx)).expect("Should be able to apply logs");
984+
runtime.block_on(sqlite_seg_writer.apply_logs(
985+
data,
986+
metadata_seg_id,
987+
test_data
988+
.collection_and_segments
989+
.collection
990+
.schema
991+
.clone(),
992+
&mut *tx,
993+
))
994+
.expect("Should be able to apply logs");
914995
runtime.block_on(tx.commit()).expect("Should be able to commit log");
915996

916997
let sqlite_seg_reader = SqliteMetadataReader {
@@ -938,7 +1019,17 @@ mod tests {
9381019
ref_seg.apply_logs(test_data.logs.clone(), metadata_seg_id);
9391020
let mut tx = runtime.block_on(sqlite_seg_writer.begin()).expect("Should be able to start transaction");
9401021
let data: Chunk<LogRecord> = Chunk::new(test_data.logs.clone().into());
941-
runtime.block_on(sqlite_seg_writer.apply_logs(data, metadata_seg_id, &mut *tx)).expect("Should be able to apply logs");
1022+
runtime.block_on(sqlite_seg_writer.apply_logs(
1023+
data,
1024+
metadata_seg_id,
1025+
test_data
1026+
.collection_and_segments
1027+
.collection
1028+
.schema
1029+
.clone(),
1030+
&mut *tx,
1031+
))
1032+
.expect("Should be able to apply logs");
9421033
runtime.block_on(tx.commit()).expect("Should be able to commit log");
9431034

9441035
let sqlite_seg_reader = SqliteMetadataReader {
@@ -1020,7 +1111,12 @@ mod tests {
10201111
.expect("Should be able to start transaction");
10211112
let data: Chunk<LogRecord> = Chunk::new(logs.into());
10221113
sqlite_seg_writer
1023-
.apply_logs(data, metadata_seg_id, &mut *tx)
1114+
.apply_logs(
1115+
data,
1116+
metadata_seg_id,
1117+
collection_and_segments.collection.schema.clone(),
1118+
&mut *tx,
1119+
)
10241120
.await
10251121
.expect("Should be able to apply logs");
10261122
tx.commit().await.expect("Should be able to commit log");
@@ -1140,7 +1236,12 @@ mod tests {
11401236
.expect("Should be able to start transaction");
11411237
let data: Chunk<LogRecord> = Chunk::new(logs.into());
11421238
sqlite_seg_writer
1143-
.apply_logs(data, metadata_seg_id, &mut *tx)
1239+
.apply_logs(
1240+
data,
1241+
metadata_seg_id,
1242+
collection_and_segments.collection.schema.clone(),
1243+
&mut *tx,
1244+
)
11441245
.await
11451246
.expect("Should be able to apply logs");
11461247
tx.commit().await.expect("Should be able to commit log");

0 commit comments

Comments (0)