Skip to content

Commit ed52a09

Browse files
committed
[ENH] recognize and flush new metadata keys to schema on local compaction
1 parent 0f1292e commit ed52a09

File tree

4 files changed

+153
-22
lines changed

4 files changed

+153
-22
lines changed

chromadb/test/api/test_schema_e2e.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -363,7 +363,8 @@ def test_schema_defaults_enable_indexed_operations(
363363
# Ensure underlying schema persisted across fetches
364364
reloaded = client.get_collection(collection.name)
365365
assert reloaded.schema is not None
366-
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
366+
if not is_spann_disabled_mode:
367+
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
367368

368369

369370
def test_get_or_create_and_get_collection_preserve_schema(
@@ -541,7 +542,8 @@ def test_schema_persistence_with_custom_overrides(
541542
reloaded_client = client_factories.create_client_from_system()
542543
reloaded_collection = reloaded_client.get_collection(name=collection.name)
543544
assert reloaded_collection.schema is not None
544-
assert reloaded_collection.schema.serialize_to_json() == expected_schema_json
545+
if not is_spann_disabled_mode:
546+
assert reloaded_collection.schema.serialize_to_json() == expected_schema_json
545547

546548
fetched = reloaded_collection.get(where={"title": "Schema Persistence"})
547549
assert set(fetched["ids"]) == {"persist-1"}
@@ -784,7 +786,6 @@ def _expect_disabled_error(operation: Callable[[], Any]) -> None:
784786
_expect_disabled_error(operation)
785787

786788

787-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
788789
def test_schema_discovers_new_keys_after_compaction(
789790
client_factories: "ClientFactories",
790791
) -> None:
@@ -802,7 +803,8 @@ def test_schema_discovers_new_keys_after_compaction(
802803

803804
collection.add(ids=ids, documents=documents, metadatas=metadatas)
804805

805-
wait_for_version_increase(client, collection.name, initial_version)
806+
if not is_spann_disabled_mode:
807+
wait_for_version_increase(client, collection.name, initial_version)
806808

807809
reloaded = client.get_collection(collection.name)
808810
assert reloaded.schema is not None
@@ -828,7 +830,8 @@ def test_schema_discovers_new_keys_after_compaction(
828830
metadatas=upsert_metadatas,
829831
)
830832

831-
wait_for_version_increase(client, collection.name, next_version)
833+
if not is_spann_disabled_mode:
834+
wait_for_version_increase(client, collection.name, next_version)
832835

833836
post_upsert = client.get_collection(collection.name)
834837
assert post_upsert.schema is not None
@@ -852,7 +855,6 @@ def test_schema_discovers_new_keys_after_compaction(
852855
assert "discover_upsert" in persisted.schema.keys
853856

854857

855-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
856858
def test_schema_rejects_conflicting_discoverable_key_types(
857859
client_factories: "ClientFactories",
858860
) -> None:
@@ -868,7 +870,8 @@ def test_schema_rejects_conflicting_discoverable_key_types(
868870
documents = [f"doc {i}" for i in range(251)]
869871
collection.add(ids=ids, documents=documents, metadatas=metadatas)
870872

871-
wait_for_version_increase(client, collection.name, initial_version)
873+
if not is_spann_disabled_mode:
874+
wait_for_version_increase(client, collection.name, initial_version)
872875

873876
collection.upsert(
874877
ids=["conflict-bad"],
@@ -1029,7 +1032,6 @@ def test_schema_embedding_configuration_enforced(
10291032
assert "sparse_auto" not in numeric_metadata
10301033

10311034

1032-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
10331035
def test_schema_precedence_for_overrides_discoverables_and_defaults(
10341036
client_factories: "ClientFactories",
10351037
) -> None:
@@ -1054,7 +1056,9 @@ def test_schema_precedence_for_overrides_discoverables_and_defaults(
10541056

10551057
initial_version = get_collection_version(client, collection.name)
10561058
collection.add(ids=ids, documents=documents, metadatas=metadatas)
1057-
wait_for_version_increase(client, collection.name, initial_version)
1059+
1060+
if not is_spann_disabled_mode:
1061+
wait_for_version_increase(client, collection.name, initial_version)
10581062

10591063
schema_state = client.get_collection(collection.name).schema
10601064
assert schema_state is not None

rust/log/src/local_compaction_manager.rs

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ impl Handler<BackfillMessage> for LocalCompactionManager {
140140
.sysdb
141141
.get_collection_with_segments(message.collection_id)
142142
.await?;
143+
let schema_previously_persisted = collection_and_segments.collection.schema.is_some();
143144
collection_and_segments
144145
.collection
145146
.reconcile_schema_with_config(KnnIndex::Hnsw)?;
@@ -206,14 +207,31 @@ impl Handler<BackfillMessage> for LocalCompactionManager {
206207
.begin()
207208
.await
208209
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
209-
metadata_writer
210+
let apply_outcome = metadata_writer
210211
.apply_logs(
211212
mt_data_chunk,
212213
collection_and_segments.metadata_segment.id,
214+
if schema_previously_persisted {
215+
collection_and_segments.collection.schema.clone()
216+
} else {
217+
None
218+
},
213219
&mut *tx,
214220
)
215221
.await
216222
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
223+
if schema_previously_persisted {
224+
if let Some(updated_schema) = apply_outcome.schema_update {
225+
metadata_writer
226+
.update_collection_schema(
227+
collection_and_segments.collection.collection_id,
228+
&updated_schema,
229+
&mut *tx,
230+
)
231+
.await
232+
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
233+
}
234+
}
217235
tx.commit()
218236
.await
219237
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;

rust/segment/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ roaring = { workspace = true }
1111
sea-query = { workspace = true }
1212
sea-query-binder = { workspace = true, features = ["sqlx-sqlite"] }
1313
serde = { workspace = true }
14+
serde_json = { workspace = true }
1415
sqlx = { workspace = true }
1516
serde-pickle = "1.2.0"
1617
tantivy = { workspace = true }

rust/segment/src/sqlite_metadata.rs

Lines changed: 120 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,17 +7,17 @@ use chroma_error::{ChromaError, ErrorCodes};
77
use chroma_sqlite::{
88
db::SqliteDb,
99
helpers::{delete_metadata, update_metadata},
10-
table::{EmbeddingFulltextSearch, EmbeddingMetadata, Embeddings, MaxSeqId},
10+
table::{Collections, EmbeddingFulltextSearch, EmbeddingMetadata, Embeddings, MaxSeqId},
1111
};
1212
use chroma_types::{
1313
operator::{
1414
CountResult, Filter, GetResult, Limit, Projection, ProjectionOutput, ProjectionRecord, Scan,
1515
},
1616
plan::{Count, Get},
17-
BooleanOperator, Chunk, CompositeExpression, DocumentExpression, DocumentOperator, LogRecord,
18-
MetadataComparison, MetadataExpression, MetadataSetValue, MetadataValue,
19-
MetadataValueConversionError, Operation, OperationRecord, PrimitiveOperator, SegmentUuid,
20-
SetOperator, UpdateMetadataValue, Where, CHROMA_DOCUMENT_KEY,
17+
BooleanOperator, Chunk, CollectionUuid, CompositeExpression, DocumentExpression,
18+
DocumentOperator, LogRecord, MetadataComparison, MetadataExpression, MetadataSetValue,
19+
MetadataValue, MetadataValueConversionError, Operation, OperationRecord, PrimitiveOperator,
20+
Schema, SegmentUuid, SetOperator, UpdateMetadataValue, Where, CHROMA_DOCUMENT_KEY,
2121
};
2222
use sea_query::{
2323
Alias, DeleteStatement, Expr, ExprTrait, Func, InsertStatement, LikeExpr, OnConflict, Query,
@@ -41,6 +41,8 @@ pub enum SqliteMetadataError {
4141
SeaQuery(#[from] sea_query::error::Error),
4242
#[error(transparent)]
4343
Sqlx(#[from] sqlx::Error),
44+
#[error("Could not serialize schema: {0}")]
45+
SerializeSchema(#[from] serde_json::Error),
4446
}
4547

4648
impl ChromaError for SqliteMetadataError {
@@ -53,6 +55,11 @@ pub struct SqliteMetadataWriter {
5355
pub db: SqliteDb,
5456
}
5557

58+
/// Summary returned by `SqliteMetadataWriter::apply_logs` after a chunk of
/// log records has been written.
pub struct ApplyLogsOutcome {
59+
/// The schema that was handed to `apply_logs`, returned only when
/// `apply_logs` flagged it as modified while walking the records' metadata;
/// `None` when nothing was flagged (or when the chunk was empty).
pub schema_update: Option<Schema>,
60+
/// Largest log offset seen among the applied records; `None` when no
/// record was processed (e.g. an empty chunk).
pub max_seq_id: Option<u64>,
61+
}
62+
5663
impl SqliteMetadataWriter {
5764
pub fn new(db: SqliteDb) -> Self {
5865
Self { db }
@@ -278,19 +285,69 @@ impl SqliteMetadataWriter {
278285
Ok(self.db.get_conn().begin().await?)
279286
}
280287

288+
/// Overwrites the stored schema string of one collection.
///
/// The schema is serialized to JSON and written to the `SchemaStr` column of
/// the `Collections` table for the row whose `Id` equals `collection_id`.
/// The statement runs on the supplied executor `tx`; this function neither
/// begins nor commits a transaction itself, so the caller decides when the
/// write becomes durable.
///
/// # Errors
///
/// Returns `SqliteMetadataError::SerializeSchema` if JSON serialization fails,
/// or the `sqlx` error (converted via `?`) if executing the UPDATE fails.
pub async fn update_collection_schema<C>(
289+
&self,
290+
collection_id: CollectionUuid,
291+
schema: &Schema,
292+
tx: &mut C,
293+
) -> Result<(), SqliteMetadataError>
294+
where
295+
for<'connection> &'connection mut C: sqlx::Executor<'connection, Database = sqlx::Sqlite>,
296+
{
297+
// Serialize first so a serialization failure aborts before any SQL is issued.
let schema_str = serde_json::to_string(schema)?;
298+
// Parameterized UPDATE: both the schema string and the id are bound values.
let (sql, values) = Query::update()
299+
.table(Collections::Table)
300+
.value(Collections::SchemaStr, schema_str)
301+
.and_where(
302+
Expr::col((Collections::Table, Collections::Id)).eq(collection_id.to_string()),
303+
)
304+
.build_sqlx(SqliteQueryBuilder);
305+
// NOTE(review): the number of affected rows is not checked, so an unknown
// collection id is presumably a silent no-op — confirm this is intended.
sqlx::query_with(&sql, values).execute(&mut *tx).await?;
306+
Ok(())
307+
}
308+
309+
/// Registers a metadata key with the in-memory schema, if there is one.
///
/// Returns `false` without touching the schema when:
/// - `key` is `CHROMA_DOCUMENT_KEY`,
/// - `value` is `UpdateMetadataValue::None`,
/// - `schema` is `None`, or
/// - `value` cannot be converted into a `MetadataValue`.
///
/// Otherwise it forwards the key and the value's type to
/// `Schema::ensure_key_from_metadata` and returns that call's result.
/// `apply_logs` treats a `true` result as "the schema was modified".
/// NOTE(review): presumably `ensure_key_from_metadata` returns `true` only
/// when the key was newly added — confirm against its definition.
fn ensure_schema_for_update_value(
310+
schema: &mut Option<Schema>,
311+
key: &str,
312+
value: &UpdateMetadataValue,
313+
) -> bool {
314+
// The document key is skipped here and never reaches the schema.
if key == CHROMA_DOCUMENT_KEY {
315+
return false;
316+
}
317+
match value {
318+
// A `None` update value is skipped rather than used to register a key.
UpdateMetadataValue::None => false,
319+
_ => {
320+
if let Some(schema_mut) = schema.as_mut() {
321+
// Conversion failures are ignored on purpose here: the key is
// simply not registered and `false` is returned below.
if let Ok(metadata_value) = MetadataValue::try_from(value) {
322+
return schema_mut
323+
.ensure_key_from_metadata(key, metadata_value.value_type());
324+
}
325+
}
326+
false
327+
}
328+
}
329+
}
330+
281331
pub async fn apply_logs<C>(
282332
&self,
283333
logs: Chunk<LogRecord>,
284334
segment_id: SegmentUuid,
335+
schema: Option<Schema>,
285336
tx: &mut C,
286-
) -> Result<(), SqliteMetadataError>
337+
) -> Result<ApplyLogsOutcome, SqliteMetadataError>
287338
where
288339
for<'connection> &'connection mut C: sqlx::Executor<'connection, Database = sqlx::Sqlite>,
289340
{
290341
if logs.is_empty() {
291-
return Ok(());
342+
return Ok(ApplyLogsOutcome {
343+
schema_update: None,
344+
max_seq_id: None,
345+
});
292346
}
347+
let mut schema = schema;
348+
let mut schema_modified = false;
293349
let mut max_seq_id = u64::MIN;
350+
let mut saw_log = false;
294351
for (
295352
LogRecord {
296353
log_offset,
@@ -307,6 +364,7 @@ impl SqliteMetadataWriter {
307364
) in logs.iter()
308365
{
309366
let log_offset_unsigned = (*log_offset).try_into()?;
367+
saw_log = true;
310368
max_seq_id = max_seq_id.max(log_offset_unsigned);
311369
let mut metadata_owned = metadata.clone();
312370
if let Some(doc) = document {
@@ -323,6 +381,11 @@ impl SqliteMetadataWriter {
323381
Self::add_record(tx, segment_id, log_offset_unsigned, id.clone()).await?
324382
{
325383
if let Some(meta) = metadata_owned {
384+
for (key, value) in meta.iter() {
385+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
386+
schema_modified = true;
387+
}
388+
}
326389
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
327390
}
328391

@@ -336,6 +399,11 @@ impl SqliteMetadataWriter {
336399
Self::update_record(tx, segment_id, log_offset_unsigned, id.clone()).await?
337400
{
338401
if let Some(meta) = metadata_owned {
402+
for (key, value) in meta.iter() {
403+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
404+
schema_modified = true;
405+
}
406+
}
339407
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
340408
}
341409

@@ -351,6 +419,11 @@ impl SqliteMetadataWriter {
351419
.await?;
352420

353421
if let Some(meta) = metadata_owned {
422+
for (key, value) in meta.iter() {
423+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
424+
schema_modified = true;
425+
}
426+
}
354427
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
355428
}
356429

@@ -371,7 +444,12 @@ impl SqliteMetadataWriter {
371444

372445
Self::upsert_max_seq_id(tx, segment_id, max_seq_id).await?;
373446

374-
Ok(())
447+
let max_seq_id = if saw_log { Some(max_seq_id) } else { None };
448+
449+
Ok(ApplyLogsOutcome {
450+
schema_update: if schema_modified { schema } else { None },
451+
max_seq_id,
452+
})
375453
}
376454
}
377455

@@ -910,7 +988,17 @@ mod tests {
910988
ref_seg.apply_logs(test_data.logs.clone(), metadata_seg_id);
911989
let mut tx = runtime.block_on(sqlite_seg_writer.begin()).expect("Should be able to start transaction");
912990
let data: Chunk<LogRecord> = Chunk::new(test_data.logs.clone().into());
913-
runtime.block_on(sqlite_seg_writer.apply_logs(data, metadata_seg_id, &mut *tx)).expect("Should be able to apply logs");
991+
runtime.block_on(sqlite_seg_writer.apply_logs(
992+
data,
993+
metadata_seg_id,
994+
test_data
995+
.collection_and_segments
996+
.collection
997+
.schema
998+
.clone(),
999+
&mut *tx,
1000+
))
1001+
.expect("Should be able to apply logs");
9141002
runtime.block_on(tx.commit()).expect("Should be able to commit log");
9151003

9161004
let sqlite_seg_reader = SqliteMetadataReader {
@@ -938,7 +1026,17 @@ mod tests {
9381026
ref_seg.apply_logs(test_data.logs.clone(), metadata_seg_id);
9391027
let mut tx = runtime.block_on(sqlite_seg_writer.begin()).expect("Should be able to start transaction");
9401028
let data: Chunk<LogRecord> = Chunk::new(test_data.logs.clone().into());
941-
runtime.block_on(sqlite_seg_writer.apply_logs(data, metadata_seg_id, &mut *tx)).expect("Should be able to apply logs");
1029+
runtime.block_on(sqlite_seg_writer.apply_logs(
1030+
data,
1031+
metadata_seg_id,
1032+
test_data
1033+
.collection_and_segments
1034+
.collection
1035+
.schema
1036+
.clone(),
1037+
&mut *tx,
1038+
))
1039+
.expect("Should be able to apply logs");
9421040
runtime.block_on(tx.commit()).expect("Should be able to commit log");
9431041

9441042
let sqlite_seg_reader = SqliteMetadataReader {
@@ -1020,7 +1118,12 @@ mod tests {
10201118
.expect("Should be able to start transaction");
10211119
let data: Chunk<LogRecord> = Chunk::new(logs.into());
10221120
sqlite_seg_writer
1023-
.apply_logs(data, metadata_seg_id, &mut *tx)
1121+
.apply_logs(
1122+
data,
1123+
metadata_seg_id,
1124+
collection_and_segments.collection.schema.clone(),
1125+
&mut *tx,
1126+
)
10241127
.await
10251128
.expect("Should be able to apply logs");
10261129
tx.commit().await.expect("Should be able to commit log");
@@ -1140,7 +1243,12 @@ mod tests {
11401243
.expect("Should be able to start transaction");
11411244
let data: Chunk<LogRecord> = Chunk::new(logs.into());
11421245
sqlite_seg_writer
1143-
.apply_logs(data, metadata_seg_id, &mut *tx)
1246+
.apply_logs(
1247+
data,
1248+
metadata_seg_id,
1249+
collection_and_segments.collection.schema.clone(),
1250+
&mut *tx,
1251+
)
11441252
.await
11451253
.expect("Should be able to apply logs");
11461254
tx.commit().await.expect("Should be able to commit log");

0 commit comments

Comments
 (0)