-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag_engine.py
More file actions
2508 lines (2161 loc) · 106 KB
/
Copy pathrag_engine.py
File metadata and controls
2508 lines (2161 loc) · 106 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
RAG engine for session transcripts.
Embeds conversation turns via mlx-embeddings. Stores vectors in Milvus — either
a remote Standalone instance (via SESSIONFLOW_MILVUS_URI) or embedded Milvus Lite
at ~/.sessionflow/milvus.db (fallback).
Full-text search via SQLite FTS5 sidecar for hybrid search (vector + keyword).
Results merged with Reciprocal Rank Fusion (RRF).
Each turn is tagged with a project_root field, enabling per-project or cross-project search.
Supports multiple embedding models via SESSIONFLOW_MODEL env var (default: embeddinggemma).
"""
import hashlib
import heapq
import json
import math
import os
import re
from datetime import datetime, timedelta, timezone
from pathlib import Path
# Block all HuggingFace network access at runtime.
# Models must be pre-downloaded via setup.sh / download-model.sh.
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
from pymilvus.exceptions import MilvusException
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, List, Dict, NamedTuple, Optional
import asyncio
import logging
import sys
import threading
import time
from fts_hybrid import FTSIndex, fts_backfill_required, rrf_merge
from embedding_control import (
EmbeddingIdentity,
get_embedding_budget,
_env_float,
_env_int,
)
from provider_adapters import (
LEGAL_PROVIDERS,
LEGAL_SORT_BY,
LEGAL_SOURCE_KINDS,
default_provider_metadata,
is_valid_issue_token,
)
import secret_redaction
# Recency-ranking defaults. Their consumer is outside this view — presumably the
# weight blends a recency score into the final rank and the decay is in days;
# confirm against the search path before relying on either interpretation.
RECENCY_WEIGHT_DEFAULT = 0.3
RECENCY_DECAY_DAYS_DEFAULT = 7
# Recency value assigned to a result with no timestamp (consumer not visible
# here — presumably a neutral midpoint; verify).
MISSING_TIMESTAMP_RECENCY = 0.5
# Underscore-prefixed per-result ranking scores (RRF / raw / semantic / recency).
# NOTE(review): presumably stripped from results before they are returned —
# confirm in the ranking code.
_RANKING_SCRATCH_KEYS = ("_rrf_score", "_score", "_semantic_score", "_recency_score")
logger = logging.getLogger("sessionflow.milvus")
# Issue-ID extraction (SESF-25): technical-standard prefixes that match the
# issue-token regex but are never issue references. Dropped during extraction.
# NOTE(review): the prefix compared here is the text before the first "-", so
# tokens written without a dash before the number (e.g. "MD5-1", "BASE64-2",
# "IPV6-3") yield prefixes "MD5"/"BASE64"/"IPV6" and are NOT filtered.
_ISSUE_ID_PREFIX_DENYLIST = frozenset(
    {"UTF", "SHA", "HTTP", "HTTPS", "ISO", "RFC", "IPV", "MD", "BASE"}
)
# Milvus VARCHAR field length the extracted ids are stored in.
_ISSUE_ID_FIELD_MAX = 4096
# Default number of timeline entries returned by get_issue_timeline (SESF-25/26).
DEFAULT_TIMELINE_LIMIT = 50
# Generous FTS fetch window for the timeline fallback so the chronological slice
# isn't biased by BM25 rank (the older matching turn may be outside the top-N).
_TIMELINE_FTS_FETCH_CAP = 500
# Observability threshold: warn when one issue matches an unexpectedly large
# structured set. Memory is bounded (SESF-34: rows stream through `_OldestN`
# rather than draining into a list), so this is now a scan-cost heads-up — narrow
# with date_from/date_to to shrink the iterator window. A server-side cap is still
# impossible: Milvus query_iterator order is undefined and it can't sort by a
# scalar, so any first-N truncation would drop arbitrary (not newest) rows.
_TIMELINE_ROWS_WARN = 50000
# Shared Milvus output_fields for vector search and recency listing. Includes
# ``issue_ids`` so SESF-25 issue tags propagate through ``_row_to_result``.
_SEARCH_OUTPUT_FIELDS = [
    "document", "doc_id", "session_id", "transcript_file",
    "turn_index", "timestamp", "git_branch", "chunk_type",
    "project_root", "logical_session_id", "provider",
    "source_kind", "source_class", "source_id", "source_path",
    "issue_ids",
]
class FtsBackfillTransientError(Exception):
    """Retryable failure raised by ``backfill_fts`` (SESF-38).

    Raised for transient Milvus errors and schema-drift races during an FTS
    backfill. The FTS heal worker treats it as "try again on a later cadence
    tick" rather than as a terminal failure. The originating exception is kept
    on ``__cause__``.
    """


# Guards FTS heal runs so only one executes at a time. ``backfill_fts`` acquires
# it without blocking, so an overlapping attempt (for example a second cadence
# tick) returns early instead of repeating the same work.
_fts_backfill_lock = threading.Lock()
# Compiled once at import. re.ASCII confines \d to 0-9, \b to ASCII word
# boundaries, and [A-Z] under IGNORECASE to the 52 ASCII letters. Without it,
# Unicode digits (e.g. Arabic-Indic) and letters such as U+212A KELVIN SIGN
# would be accepted as issue ids.
_ISSUE_ID_RE = re.compile(r"\b[A-Z][A-Z0-9]+-\d+\b", re.IGNORECASE | re.ASCII)


def _extract_issue_ids(text: str) -> str:
    """Extract issue references (e.g. ``SESF-25``) from a turn's text.

    Matches the issue-token regex ``\\b[A-Z][A-Z0-9]+-\\d+\\b``
    case-insensitively in ASCII mode, canonicalizes each match to upper case,
    drops tokens whose prefix (the text before the first ``-``) is in
    ``_ISSUE_ID_PREFIX_DENYLIST`` (UTF-8, SHA-256, HTTP-2, ...), and
    deduplicates the survivors in first-seen order.

    ASCII mode is deliberate. Python's default Unicode mode would also accept
    non-ASCII digits and a few non-ASCII letters. Those ids never compare equal
    to their ASCII spelling. Their UTF-8 encoding can also exceed the cap below,
    which is counted in characters while Milvus VARCHAR limits are bytes.

    Args:
        text: Raw turn text to scan.

    Returns:
        A delimiter-wrapped, comma-joined string of issue ids with a leading
        and trailing comma (e.g. ``",SESF-25,SESF-26,"``). Returns ``""`` when
        no issue token is found, or when ``text`` is empty or not a ``str``.
        The result is capped at ``_ISSUE_ID_FIELD_MAX`` characters so a Milvus
        insert cannot overflow the storage field. If the next id would exceed
        the cap, extraction stops and logs one warning.
    """
    if not text or not isinstance(text, str):
        return ""
    # dict preserves insertion order, so it doubles as an ordered set. Only the
    # matched tokens are upper-cased, never a copy of the whole turn text.
    seen: dict[str, None] = {}
    for match in _ISSUE_ID_RE.finditer(text):
        token = match.group(0).upper()
        if token.split("-", 1)[0] in _ISSUE_ID_PREFIX_DENYLIST:
            continue
        seen.setdefault(token, None)
    if not seen:
        return ""
    kept: list[str] = []
    length = 1  # the leading "," delimiter
    for token in seen:
        grown = length + len(token) + 1  # token plus its trailing ","
        if grown > _ISSUE_ID_FIELD_MAX:
            logger.warning(
                "issue-id list truncated at %d chars (field cap %d)",
                length,
                _ISSUE_ID_FIELD_MAX,
            )
            break
        kept.append(token)
        length = grown
    # If even the first token overflows the cap, nothing is kept. Return "" rather
    # than a bare "," (which is neither empty nor a valid comma-wrapped list).
    if not kept:
        return ""
    return "," + ",".join(kept) + ","
def _truncate_utf8(text: str, max_bytes: int) -> str:
    """Clip ``text`` so that its UTF-8 encoding is at most ``max_bytes`` long.

    Milvus sizes VARCHAR fields in bytes, not characters. Slicing the Python
    string to ``max_bytes`` characters can therefore still overflow the field
    once multibyte codepoints appear. This function cuts the encoded bytes
    instead. A codepoint split by the cut is dropped rather than emitted
    half-encoded.
    """
    raw = text.encode("utf-8")
    if len(raw) > max_bytes:
        return raw[:max_bytes].decode("utf-8", errors="ignore")
    return text
def _is_remote_uri(uri: str) -> bool:
    """Tell whether ``uri`` addresses a remote Milvus Standalone server.

    Only the lowercase ``http://`` and ``https://`` schemes count as remote;
    anything else is treated as a local Milvus Lite file path.
    """
    return uri.startswith(("http://", "https://"))
# --- Model registry ---
# Per-model embedding settings, keyed by the short name accepted in the
# SESSIONFLOW_MODEL environment variable. The prefixes are prepended to query
# and document text respectively before embedding.
_MODEL_REGISTRY = {
    "modernbert": {
        "model_id": "nomic-ai/modernbert-embed-base",
        "embed_dim": 768,
        "max_tokens": 8192,
        "search_prefix": "search_query: ",
        "document_prefix": "search_document: ",
        "cache_subdir": "models--nomic-ai--modernbert-embed-base",
    },
    "embeddinggemma": {
        "model_id": "mlx-community/embeddinggemma-300m-bf16",
        "embed_dim": 768,
        "max_tokens": 2048,
        "search_prefix": "task: search result | query: ",
        "document_prefix": "title: none | text: ",
        "cache_subdir": "models--mlx-community--embeddinggemma-300m-bf16",
    },
}
# Resolve the active model once at import; an unknown name aborts the import.
_MODEL_NAME = os.getenv("SESSIONFLOW_MODEL", "embeddinggemma").lower()
_MODEL_CFG = _MODEL_REGISTRY.get(_MODEL_NAME)
if _MODEL_CFG is None:
    raise ValueError(
        f"Unknown model '{_MODEL_NAME}'. "
        f"Valid options: {', '.join(_MODEL_REGISTRY.keys())}"
    )
_EMBED_DIM = _MODEL_CFG["embed_dim"]
_MODEL_ID = _MODEL_CFG["model_id"]
_MODEL_CACHE = Path.home().joinpath(".cache", "huggingface", "hub", _MODEL_CFG["cache_subdir"])
_SEARCH_PREFIX = _MODEL_CFG["search_prefix"]
_DOCUMENT_PREFIX = _MODEL_CFG["document_prefix"]
COLLECTION_NAME = "sessions"
# --- Model identity check ---
# Stamp file recording which model built the index (see _check_model_identity).
_IDENTITY_FILE = Path.home().joinpath(".sessionflow", "model_identity.json")
def _check_model_identity(db_path: Optional[str] = None):
    """Verify that the active model matches what was used to build the index.

    On first run, writes ``model_identity.json`` with the active model name.

    On later runs, if the stamped model differs from ``SESSIONFLOW_MODEL`` and
    the collection at ``db_path`` holds at least one row, this raises instead
    of letting incompatible vectors mix. When the stamp differs but the index
    is empty, or no ``db_path`` is given, the stamp is overwritten with the
    active model.

    If the index cannot be inspected (connection or query failure), the
    existing stamp is kept and a warning is logged. Treating "unknown" as
    "empty" would overwrite the stamp and permanently hide a real mismatch from
    every later startup.

    Args:
        db_path: Milvus Lite file path or remote URI holding the index.

    Raises:
        RuntimeError: The stamped model differs and the index has data.
    """
    _IDENTITY_FILE.parent.mkdir(parents=True, exist_ok=True)
    if _IDENTITY_FILE.exists():
        stored = json.loads(_IDENTITY_FILE.read_text())
        stored_model = stored.get("model_name", "")
        if stored_model and stored_model != _MODEL_NAME:
            # Check whether the index actually has data before raising.
            has_data = False
            if db_path:
                client = None
                try:
                    client = MilvusClient(db_path)
                    if client.has_collection(COLLECTION_NAME):
                        rows = client.query(
                            collection_name=COLLECTION_NAME,
                            filter="",
                            limit=1,
                            output_fields=["id"],
                        )
                        has_data = len(rows) > 0
                except Exception as exc:
                    logger.warning(
                        "Model identity check could not inspect index at %s (%s); "
                        "keeping stamped model %r",
                        db_path,
                        _scrub_exception(exc),
                        stored_model,
                    )
                    return
                finally:
                    # Always release the probe client, even when the query fails.
                    if client is not None:
                        try:
                            client.close()
                        except Exception:
                            pass
            if has_data:
                raise RuntimeError(
                    f"Model mismatch: index was built with '{stored_model}' but "
                    f"SESSIONFLOW_MODEL is '{_MODEL_NAME}'. "
                    f"Run cleanup.py reset or clear the index before switching models."
                )
            # Index is empty — safe to overwrite the stamp.
    # Stamp the current model.
    _IDENTITY_FILE.write_text(json.dumps({"model_name": _MODEL_NAME}))
def get_model_name() -> str:
    """Short registry name of the active embedding model.

    Example values are ``'modernbert'`` or ``'embeddinggemma'``.
    """
    active = _MODEL_NAME
    return active
def get_embedding_identity() -> Dict[str, object]:
    """Describe the active local embedding identity for health/status output.

    Returns a dict with ``embedding_provider``, ``model_name``, ``dimension``,
    ``collection_name`` and ``created_at``. If
    ``EmbeddingIdentity.current_local()`` rejects the configuration with
    ``ValueError``, a warning is logged and a placeholder dict is returned
    instead. That placeholder carries the error text under ``error``.
    """
    try:
        identity = EmbeddingIdentity.current_local()
    except ValueError as exc:
        logger.warning("Invalid embedding identity: %s", exc)
        return {
            "embedding_provider": "unknown",
            "model_name": "unknown",
            "dimension": None,
            "collection_name": COLLECTION_NAME,
            "created_at": "",
            "error": str(exc),
        }
    # The status keys are exactly the identity's attribute names.
    reported = (
        "embedding_provider",
        "model_name",
        "dimension",
        "collection_name",
        "created_at",
    )
    return {attr: getattr(identity, attr) for attr in reported}
# Lazily-populated MLX state. The model and tokenizer are filled in by get_model().
# The load/generate functions and the mlx.core module are filled in by
# _load_mlx_runtime(), so importing this module never imports MLX.
_mlx_model = _mlx_tokenizer = _mlx_load = _mlx_generate = _mlx_core = None
def _load_mlx_runtime():
    """Return ``(load, generate, mlx.core)``, importing MLX on first use.

    The import is deferred so that tests and status paths which never embed
    cannot crash at import time on machines without MLX.
    """
    global _mlx_load, _mlx_generate, _mlx_core
    if _mlx_load is not None and _mlx_generate is not None and _mlx_core is not None:
        return _mlx_load, _mlx_generate, _mlx_core
    from mlx_embeddings.utils import load as mlx_load, generate as mlx_generate
    import mlx.core as mx
    _mlx_load, _mlx_generate, _mlx_core = mlx_load, mlx_generate, mx
    return _mlx_load, _mlx_generate, _mlx_core
def get_model():
    """Return ``(model, tokenizer)``, loading the MLX embedding model once.

    Raises:
        RuntimeError: The model is not present in the local HuggingFace cache.
            Downloads are blocked at runtime, so it must be fetched beforehand.
    """
    global _mlx_model, _mlx_tokenizer
    if _mlx_model is None:
        if not _MODEL_CACHE.exists():
            raise RuntimeError(
                f"Embedding model not cached at {_MODEL_CACHE}. "
                f"Run ./setup.sh or ./download-model.sh to download it."
            )
        print(f"Loading {_MODEL_ID} via mlx-embeddings...", file=sys.stderr)
        loader = _load_mlx_runtime()[0]
        _mlx_model, _mlx_tokenizer = loader(_MODEL_ID)
        print(f"{_MODEL_ID} ready ({_EMBED_DIM} dims, {_MODEL_CFG['max_tokens']} token context)", file=sys.stderr)
    return _mlx_model, _mlx_tokenizer
def _needs_input_remap() -> bool:
    """Whether the active model must be called as ``model(input_ids, ...)`` directly.

    This works around mlx-embeddings gemma3_text models: their ``__call__``
    expects ``inputs`` while the tokenizer produces ``input_ids``. Such models
    are identified purely by ``"gemma"`` appearing in the model name.
    """
    return _MODEL_NAME.find("gemma") != -1
def embed_texts(texts: List[str], is_query: bool = False) -> List[List[float]]:
    """Embed ``texts`` with the configured model, adding its model-specific prefix.

    Args:
        texts: Raw strings to embed.
        is_query: Use the model's search-query prefix instead of its
            document prefix.

    Returns:
        One embedding (list of floats) per input text, in input order. An empty
        input returns ``[]`` without loading the model.
    """
    if not texts:
        # Nothing to embed: skip the model load and the empty-batch tokenizer call.
        return []
    model, tokenizer = get_model()
    _, mlx_generate, mx = _load_mlx_runtime()
    prefix = _SEARCH_PREFIX if is_query else _DOCUMENT_PREFIX
    prefixed = [prefix + t for t in texts]
    try:
        if _needs_input_remap():
            # gemma3_text models expect (inputs, attention_mask) not (input_ids, ...)
            encoded = tokenizer.batch_encode_plus(
                prefixed, return_tensors="mlx", padding=True,
                truncation=True, max_length=_MODEL_CFG["max_tokens"],
            )
            output = model(encoded["input_ids"], attention_mask=encoded.get("attention_mask"))
        else:
            output = mlx_generate(model, tokenizer, texts=prefixed,
                                  max_length=_MODEL_CFG["max_tokens"])
        return output.text_embeds.tolist()
    finally:
        # Release MLX cached buffers even when the forward pass raises.
        mx.clear_cache()
# --- Milvus client management ---
# Server-mode cache of open Milvus clients, keyed by the db path / URI passed to
# _get_persistent_client(); emptied by close_server_mode().
_persistent_clients: Dict[str, MilvusClient] = {}
# SQLite FTS5 sidecar (table "turns_fts") backing the keyword half of hybrid
# search. NOTE(review): the list presumably names the metadata columns stored
# alongside each turn's text — confirm in fts_hybrid.FTSIndex.
_fts = FTSIndex("turns_fts", [
    "session_id", "git_branch", "turn_index", "timestamp", "chunk_type",
    "project_root", "logical_session_id", "provider", "source_kind",
    "source_class", "source_id", "source_path", "issue_ids",
])
# Created by init_server_mode() and reset to None by close_server_mode(); None
# outside server mode.
_write_lock: Optional[asyncio.Lock] = None
_embed_semaphore: Optional[asyncio.Semaphore] = None
# Dedicated single-worker executor so every MLX/Metal call runs on the same OS
# thread. The asyncio semaphore already serializes calls in time, but the
# default executor can rotate workers between calls — and MLX command-buffer
# state is not safe to migrate across threads. See SESF-8.
_embed_executor: Optional[ThreadPoolExecutor] = None
# True between init_server_mode() and close_server_mode(); milvus_client() uses
# it to choose persistent vs per-call clients.
_server_mode = False
def init_server_mode(db_path: Optional[str] = None):
    """Switch the module into long-lived HTTP server mode.

    First validates the model identity stamp against ``db_path``, so a model
    mismatch raises before any module state is touched. It then creates:

    - the asyncio write lock,
    - a single-permit embedding semaphore,
    - the one-thread MLX executor.

    Finally it flags both this module and the FTS index as running in server
    mode.
    """
    global _write_lock, _embed_semaphore, _embed_executor, _server_mode
    _check_model_identity(db_path=db_path)
    _write_lock, _embed_semaphore = asyncio.Lock(), asyncio.Semaphore(1)
    _embed_executor = ThreadPoolExecutor(
        max_workers=1,
        thread_name_prefix="mlx-embed",
    )
    _server_mode = True
    _fts.set_server_mode(True)
    print(f"Server mode initialized (model: {_MODEL_NAME})", file=sys.stderr)
def close_server_mode():
    """Close all persistent clients (Milvus + FTS) and reset server mode.

    Each Milvus client is closed best-effort; a failing close is logged with a
    scrubbed message and skipped. The following are reset in a ``finally``:

    - the concurrency primitives,
    - the embedding executor,
    - the ``_server_mode`` flag.

    An exception from the FTS close therefore still leaves the module in CLI
    mode with no live executor. That exception is then re-raised to the caller.
    """
    global _write_lock, _embed_semaphore, _embed_executor, _server_mode
    try:
        for path, client in list(_persistent_clients.items()):
            try:
                client.close()
                logger.info("Closed Milvus client: %s", path)
            except Exception as e:
                logger.warning("Error closing Milvus client %s: %s", path, _scrub_exception(e))
        _persistent_clients.clear()
        _fts.close_all()
    finally:
        # Nil the semaphore and lock before shutting the executor down. A
        # coroutine that wakes up during shutdown then sees None and takes the
        # CLI fallback path, instead of enqueueing work onto a torn-down executor.
        _embed_semaphore = None
        _write_lock = None
        if _embed_executor is not None:
            _embed_executor.shutdown(wait=True)
        _embed_executor = None
        _server_mode = False
def _get_persistent_client(db_path: str) -> MilvusClient:
    """Return the cached server-mode client for ``db_path``, reconnecting as needed.

    A cached client is health-checked with ``has_collection``. If the probe
    raises, the stale client is closed (best-effort) and evicted, and a fresh
    connection is opened in its place. Local Milvus Lite paths get their parent
    directory created and a relaxed gRPC keepalive.

    Raises:
        Exception: Whatever ``MilvusClient`` raised while connecting, after an
            error is logged with the message scrubbed.
    """
    cached = _persistent_clients.get(db_path)
    if cached is not None:
        try:
            cached.has_collection(COLLECTION_NAME)
        except Exception as e:
            logger.warning("Stale Milvus client for %s: %s — reconnecting", db_path, _scrub_exception(e))
            try:
                cached.close()
            except Exception:
                pass
            del _persistent_clients[db_path]
        else:
            return cached
    remote = _is_remote_uri(db_path)
    if not remote:
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
    try:
        if remote:
            # Remote Milvus Standalone — default gRPC settings are fine.
            fresh = MilvusClient(db_path)
        else:
            # Milvus Lite rejects the default 10s keepalive as too_many_pings
            # (GOAWAY/ENHANCE_YOUR_CALM), so ping every 120s instead.
            fresh = MilvusClient(
                db_path,
                grpc_options={
                    "grpc.keepalive_time_ms": 120_000,
                    "grpc.keepalive_timeout_ms": 20_000,
                },
            )
        _persistent_clients[db_path] = fresh
        logger.info("Opened client: %s", db_path)
    except Exception as e:
        logger.error("Failed to connect to Milvus at %s: %s", db_path, _scrub_exception(e))
        raise
    return fresh
def _resolve_db_path(db_path: Optional[str]) -> str:
    """Return ``db_path`` unchanged, rejecting ``None`` and the empty string.

    Raises:
        ValueError: ``db_path`` is falsy.
    """
    if db_path:
        return db_path
    raise ValueError("db_path is required. Global index is at ~/.sessionflow/milvus.db")
def _expected_schema_fields() -> List[FieldSchema]:
    """Source-of-truth Milvus field list for the sessions collection.

    Used by both _ensure_collection() (create path) and _detect_schema_drift()
    (startup validation) so the two can't drift out of sync.

    Returns:
        A fresh list of ``FieldSchema`` objects, in the order used when the
        collection is created. ``id`` is a caller-supplied INT64 primary key
        (``auto_id=False``). ``vector`` is sized by the active model's
        ``_EMBED_DIM``.
    """
    return [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=_EMBED_DIM),
        # Turn text; 65535 is the VARCHAR cap the document is written against
        # (see _truncate_utf8 for the byte-vs-character caveat).
        FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="session_id", dtype=DataType.VARCHAR, max_length=128),
        FieldSchema(name="logical_session_id", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="provider", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="source_kind", dtype=DataType.VARCHAR, max_length=96),
        FieldSchema(name="source_class", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="transcript_file", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="turn_index", dtype=DataType.INT64),
        # Timestamps are stored as strings, not a numeric/datetime type.
        FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="git_branch", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="chunk_type", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="project_root", dtype=DataType.VARCHAR, max_length=512),
        # Comma-wrapped id list produced by _extract_issue_ids, capped to this length.
        FieldSchema(name="issue_ids", dtype=DataType.VARCHAR, max_length=_ISSUE_ID_FIELD_MAX),
    ]
def detect_schema_drift(client: MilvusClient) -> List[str]:
    """Diff the live collection's field names against ``_expected_schema_fields()``.

    Only field NAMES are diffed today. pymilvus's describe_collection output
    shape varies across Milvus Lite vs Standalone, and we have not hit a case
    where a same-named field changed dtype/length silently.

    Returns:
        Entries ``"missing:<name>"`` (expected but absent), followed by entries
        ``"extra:<name>"`` (present but unexpected). Each group is sorted by
        name. The list is empty when there is no drift, when the collection
        does not exist, or when ``describe_collection`` fails. In the last case
        a scrubbed notice is printed to stderr.
    """
    if not client.has_collection(COLLECTION_NAME):
        return []
    try:
        info = client.describe_collection(COLLECTION_NAME)
    except Exception as exc:
        # Scrub like every other Milvus error path in this module (AC-17):
        # server/connection error text is not guaranteed to be secret-free.
        print(
            f"Schema drift check skipped: describe_collection failed: {_scrub_exception(exc)}",
            file=sys.stderr,
        )
        return []
    expected = {f.name for f in _expected_schema_fields()}
    actual: set[str] = set()
    for field in info.get("fields", []) or []:
        # Field entries may be dicts or objects depending on the Milvus flavour.
        if isinstance(field, dict):
            name = field.get("name")
        else:
            name = getattr(field, "name", None)
        if name:
            actual.add(name)
    missing = sorted(expected - actual)
    extra = sorted(actual - expected)
    return [f"missing:{n}" for n in missing] + [f"extra:{n}" for n in extra]
def migrate_schema(client: MilvusClient, db_path: str = "") -> None:
    """Drop the sessions collection (if present) and recreate it with the current schema.

    DESTRUCTIVE: every indexed turn is lost. This is the explicit recovery path
    behind `python cleanup.py migrate-schema` and the auto-migrate opt-in
    (SESSIONFLOW_AUTO_MIGRATE_SCHEMA=1).
    """
    collection_exists = client.has_collection(COLLECTION_NAME)
    if collection_exists:
        notice = (
            f"Dropping collection {COLLECTION_NAME!r} for schema migration "
            "(all indexed turns will be lost)"
        )
        print(notice, file=sys.stderr)
        client.drop_collection(COLLECTION_NAME)
    _create_collection(client, db_path)
    # SESF-40: drop_collection already clears pymilvus's data-path schema cache,
    # so this call is currently redundant. Routing every schema mutation through
    # the canonical invalidation hook keeps a future in-place add_collection_field
    # path from leaving a warm persistent client failing with code=65535
    # "field not exist".
    _invalidate_schema_cache(client)
def _invalidate_schema_cache(client: MilvusClient) -> None:
    """Invalidate pymilvus's process-global data-path schema cache for the collection.

    SESF-40: pymilvus 2.6 caches collection schemas in a class-level LRU
    (``GlobalCache.schema``, keyed by endpoint/db/collection and consumed by
    insert/upsert/search/hybrid_search via ``_get_schema``). It is invalidated on
    ``drop_collection`` but NOT on ``add_collection_field``, so an in-place field
    add can leave a long-lived persistent client raising ``code=65535 "field not
    exist"`` until LRU eviction or process restart — a fresh ``MilvusClient`` does
    not clear it (the cache is a singleton). Any schema-mutating op must funnel
    through here.

    No public schema-cache refresh API exists, so this reaches into pymilvus
    internals; every step is guarded and degrades to a logged no-op on a pymilvus
    upgrade rather than breaking a real migration.

    Args:
        client: Client whose connection endpoint identifies the cache entry.

    Never raises: an import failure is logged at DEBUG and any other failure at
    WARNING.
    """
    # Private pymilvus module: its absence only means there is nothing to invalidate.
    try:
        from pymilvus.client.cache import GlobalCache
    except Exception as exc:  # pragma: no cover - pymilvus internals moved
        logger.debug("Schema cache invalidation skipped (no GlobalCache): %s", exc)
        return
    try:
        # Private accessor: the cache key includes the server address this client uses.
        endpoint = client._get_connection().server_address
        # SessionFlow uses the default database; SchemaCache normalizes "" -> "default".
        GlobalCache.schema.invalidate(endpoint, "", COLLECTION_NAME)
    except Exception as exc:  # pragma: no cover - defense in depth
        # NOTE(review): unlike other Milvus error paths here, this logs the raw
        # exception text without _scrub_exception — confirm that is acceptable.
        logger.warning("Schema cache invalidation failed for %s: %s", COLLECTION_NAME, exc)
def _ensure_collection(client: MilvusClient, db_path: str = "") -> None:
    """Create the sessions collection when missing; refuse to run on schema drift.

    SESF-11: this used to be create-if-missing only. Adding a field to
    `_expected_schema_fields()` therefore broke every insert with
    DataNotMatchException against the pre-existing collection. The outcomes are
    now:

    - collection absent → create it.
    - collection present, no drift → nothing to do.
    - collection present, drift confirmed by two reads:
        - with SESSIONFLOW_AUTO_MIGRATE_SCHEMA=1 (or true/yes/on), drop and
          recreate it;
        - otherwise raise RuntimeError pointing the operator at
          `python cleanup.py migrate-schema`.
    """
    if not client.has_collection(COLLECTION_NAME):
        _create_collection(client, db_path)
        return
    # SESF-38 AC-3: drift must show up on two independent describe_collection
    # reads before either branch below fires. A stale first read that clears on
    # the second must neither raise nor migrate.
    drift: List[str] = []
    for _read in range(2):
        drift = detect_schema_drift(client)
        if not drift:
            return
    opt_in = os.getenv("SESSIONFLOW_AUTO_MIGRATE_SCHEMA", "").lower()
    if opt_in not in {"1", "true", "yes", "on"}:
        raise RuntimeError(
            f"Milvus collection {COLLECTION_NAME!r} schema is out of date "
            f"(drift={drift}). First try the non-destructive option: restart the "
            f"server — a transient describe can clear on a fresh read. If drift "
            f"persists, recover with one of these DESTRUCTIVE options (both lose "
            f"all turns): run `python cleanup.py migrate-schema` to drop and "
            f"recreate it (destructive — all turns lost), or set "
            f"SESSIONFLOW_AUTO_MIGRATE_SCHEMA=1 to migrate on startup "
            f"(destructive — all turns lost)."
        )
    print(
        f"SESSIONFLOW_AUTO_MIGRATE_SCHEMA detected schema drift {drift!r}; "
        "dropping and recreating (all turns lost).",
        file=sys.stderr,
    )
    migrate_schema(client, db_path)
def _create_collection(client: MilvusClient, db_path: str = "") -> None:
    """Create the sessions collection with a COSINE vector index.

    The index type depends on where ``db_path`` points:

    - Remote Standalone URIs get an HNSW index, and the collection is loaded
      explicitly afterwards.
    - Local Milvus Lite paths get FLAT.
    """
    remote = _is_remote_uri(db_path)
    print(f"Creating collection: {COLLECTION_NAME} (dim={_EMBED_DIM})", file=sys.stderr)
    index_params = client.prepare_index_params()
    if remote:
        # Standalone supports HNSW — O(log n) search vs O(n) FLAT.
        index_params.add_index(
            field_name="vector",
            index_type="HNSW",
            metric_type="COSINE",
            params={"M": 16, "efConstruction": 256},
        )
    else:
        # Milvus Lite silently ignores non-FLAT indexes.
        index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
    client.create_collection(
        collection_name=COLLECTION_NAME,
        schema=CollectionSchema(fields=_expected_schema_fields()),
        index_params=index_params,
    )
    print(f"Collection created: {COLLECTION_NAME}", file=sys.stderr)
    if not remote:
        return
    # Standalone needs an explicit load before query/dedup paths work;
    # create_collection with index_params does not auto-load.
    client.load_collection(collection_name=COLLECTION_NAME)
    print(f"Collection loaded: {COLLECTION_NAME}", file=sys.stderr)
@contextmanager
def milvus_client_for_migration(db_path: Optional[str] = None):
    """Yield a Milvus client that skips ``_ensure_collection``; close it on exit.

    SESF-11: ``_ensure_collection`` refuses to start on schema drift, but
    `cleanup.py migrate-schema` exists precisely to repair that drift. This
    bypass is therefore for migration code paths only and MUST NOT be used
    elsewhere.
    """
    target = _resolve_db_path(db_path)
    if not _is_remote_uri(target):
        Path(target).parent.mkdir(parents=True, exist_ok=True)
    migration_client = MilvusClient(target)
    try:
        yield migration_client
    finally:
        migration_client.close()
@contextmanager
def milvus_client(db_path: Optional[str] = None):
    """Yield a Milvus client with the sessions collection ensured.

    In server mode the cached persistent client is reused and left open. In
    CLI mode a fresh client is opened and is always closed on exit. This also
    covers the case where ``_ensure_collection`` raises (e.g. RuntimeError on
    schema drift), which previously leaked the client.

    Raises:
        ValueError: ``db_path`` is empty.
        RuntimeError: Schema drift detected by ``_ensure_collection``.
    """
    path = _resolve_db_path(db_path)
    if _server_mode:
        client = _get_persistent_client(path)
        _ensure_collection(client, path)
        yield client
        return
    if not _is_remote_uri(path):
        Path(path).parent.mkdir(parents=True, exist_ok=True)
    client = MilvusClient(path)
    try:
        _ensure_collection(client, path)
        yield client
    finally:
        client.close()
# --- Secret redaction guard (SESF-41) ---
# Truthy values for the SESSIONFLOW_REDACT on/off flag (boolean idiom, rag_engine
# precedent at the SESSIONFLOW_AUTO_MIGRATE_SCHEMA read). Compared after
# strip().lower(); anything else (including "") means "off".
_REDACT_TRUE = {"1", "true", "yes", "on"}
# Accepted SESSIONFLOW_REDACT_MODE values; anything else falls back to "report".
# "enforce" rewrites turn text, "report" only counts and logs detections.
_REDACT_MODES = {"enforce", "report"}
# Durable, process-lifetime per-rule detection counts surfaced via get_stats under
# the "redaction" key (AC-10). Rule names only — never a secret value (AC-18).
# Guarded by _redaction_lock: the check-then-set update is a read-modify-write that
# would race under concurrent ingestion despite the GIL.
_redaction_counters: Dict[str, int] = {}
_redaction_lock = threading.Lock()
# mtime-keyed cache for the operator allowlist so a hot backfill path does not
# re-read + re-compile the file on every add_turns batch. {path: (mtime, patterns)}.
# NOTE(review): the key written into entry[0] is defined by load_allowlist —
# keep this comment in sync with it.
_allowlist_cache: Dict[str, tuple] = {}
def _redaction_settings() -> tuple[bool, str, Optional[str]]:
    """Read the secret-redaction configuration from the environment (AC-11/12/13).

    Returns:
        ``(enabled, mode, allowlist_path)``.

        - ``enabled`` defaults to True when ``SESSIONFLOW_REDACT`` is unset.
          When set, it is True only for a truthy value (case- and
          whitespace-insensitive).
        - ``mode`` comes from ``SESSIONFLOW_REDACT_MODE``. It falls back to
          ``"report"`` for unset or unknown values.
        - ``allowlist_path`` is ``SESSIONFLOW_REDACT_ALLOWLIST`` as given, or
          None.
    """
    flag = os.getenv("SESSIONFLOW_REDACT")
    if flag is None:
        enabled = True
    else:
        enabled = flag.strip().lower() in _REDACT_TRUE
    requested_mode = os.getenv("SESSIONFLOW_REDACT_MODE", "report").strip().lower()
    mode = requested_mode if requested_mode in _REDACT_MODES else "report"
    allowlist_path = os.getenv("SESSIONFLOW_REDACT_ALLOWLIST")
    return enabled, mode, allowlist_path
def load_allowlist(path: Optional[str]) -> List[re.Pattern]:
    """Load operator allowlist regex patterns from ``path`` (one per line).

    Impure on purpose: keeps file I/O out of the pure ``secret_redaction`` module
    (D-4, AC-16). Blank lines and ``#`` comments are ignored. Invalid patterns
    are skipped with a warning; the pattern text itself is not logged.

    Results are cached per path and reused while the file's nanosecond mtime
    and size are unchanged. A coarse (e.g. 1-second) float mtime alone would
    miss an edit made within the same tick and serve stale patterns. The cached
    list object is shared between calls, so callers must not mutate it.

    Args:
        path: Allowlist file path, or a falsy value for "no allowlist".

    Returns:
        Compiled patterns. Returns an empty list when ``path`` is falsy or the
        file cannot be stat'ed or read.
    """
    if not path:
        return []
    try:
        st = os.stat(path)
    except OSError as exc:
        logger.warning("Could not read redaction allowlist %s: %s", path, exc)
        return []
    cache_key = (st.st_mtime_ns, st.st_size)
    cached = _allowlist_cache.get(path)
    if cached is not None and cached[0] == cache_key:
        return cached[1]
    patterns: List[re.Pattern] = []
    try:
        with open(path, "r", encoding="utf-8") as handle:
            for line in handle:
                stripped = line.strip()
                if not stripped or stripped.startswith("#"):
                    continue
                try:
                    patterns.append(re.compile(stripped))
                except re.error as exc:
                    logger.warning("Skipping invalid redaction allowlist pattern: %s", exc)
    except OSError as exc:
        logger.warning("Could not read redaction allowlist %s: %s", path, exc)
        return []
    _allowlist_cache[path] = (cache_key, patterns)
    return patterns
def _apply_redaction(turns: List[Dict]) -> None:
    """Redact secrets in ``turns`` in place before embed/store (SESF-41 hook).

    Runs once over the already-deduped turns. All three durable sinks and the
    async wrapper are therefore covered with no per-Provider change
    (AC-1/2/3). It runs before ``_extract_issue_ids``, so issue IDs survive.

    Behavior by configuration:

    - Disabled: no-op (AC-12).
    - ``report`` mode: detections are counted and logged, and the text is left
      untouched (AC-10).
    - ``enforce`` mode: ``turn["text"]`` is also replaced with the redacted
      text.

    Only rule names reach the log and the process-wide counters (AC-18).
    """
    enabled, mode, allowlist_path = _redaction_settings()
    if not enabled:
        return
    allowlist = load_allowlist(allowlist_path)
    enforce = mode == "enforce"
    per_rule: Dict[str, int] = {}
    for turn in turns:
        new_text, detections = secret_redaction.redact(
            turn.get("text", ""), mode=mode, allowlist=allowlist
        )
        for detection in detections:
            rule = detection.rule_name
            per_rule[rule] = per_rule.get(rule, 0) + 1
        if enforce:
            turn["text"] = new_text
    if not per_rule:
        return
    with _redaction_lock:
        for rule, count in per_rule.items():
            _redaction_counters[rule] = _redaction_counters.get(rule, 0) + count
    summary = ", ".join(f"{rule}={count}" for rule, count in sorted(per_rule.items()))
    logger.info("Redaction (%s mode) detected: %s", mode, summary)
def _redaction_status() -> Dict:
    """Report the operator-facing redaction status (AC-10).

    Returns a dict with ``enabled`` and ``mode`` from the current settings and
    ``counts``, a plain-dict copy of the per-rule counters taken under the lock
    so callers never hold a reference to the shared counters.
    """
    enabled, mode, _allowlist_path = _redaction_settings()
    with _redaction_lock:
        snapshot = {**_redaction_counters}
    return {"enabled": enabled, "mode": mode, "counts": snapshot}
def _scrub_exception(exc: BaseException) -> str:
    """Render ``exc`` as text with any secret value masked, for safe logging (AC-17).

    Always redacts in ``enforce`` mode, whatever the configured redaction mode,
    and discards the list of detections.
    """
    message = str(exc)
    scrubbed_text = secret_redaction.redact(message, mode="enforce")[0]
    return scrubbed_text
def _scrub_exception_args(exc: BaseException) -> None:
"""Redact every string arg on ``exc`` in place, preserving non-string args (AC-17).
Scrubbing each string arg (rather than collapsing to a single message) keeps
status codes and other structured metadata in ``exc.args[1:]`` intact while
ensuring no secret survives in any stringified form of the re-raised exception.
"""
if exc.args:
exc.args = tuple(
secret_redaction.redact(arg, mode="enforce")[0] if isinstance(arg, str) else arg
for arg in exc.args
)
# --- Core operations ---
def add_turns(turns: List[Dict], db_path: Optional[str] = None) -> int:
    """Insert conversation turn chunks into Milvus. Dedup by doc_id.

    Each turn dict should have:
        text, doc_id, session_id, transcript_file, turn_index,
        timestamp, git_branch, chunk_type
    Optional keys (logical_session_id, provider, source_kind, source_class,
    source_id, source_path, project_root) fall back to the session/transcript
    values or to ``default_provider_metadata()``.

    Pipeline:
      1. Dedup: repeated doc_ids within ``turns`` are dropped (first occurrence
         wins), then doc_ids already present in the collection are dropped.
      2. Redaction: surviving turns have secrets redacted in place (SESF-41).
      3. Embedding: runs in budget-controlled batches. If the budget defers a
         batch, that batch and every later one are left out of this call.
      4. Storage: rows are inserted into Milvus and, when ``db_path`` is given,
         mirrored into the FTS5 sidecar on a best-effort basis.

    Returns:
        The number of rows inserted into Milvus; 0 when nothing was new or
        every batch was deferred.

    Raises:
        Any exception from ``embed_texts``, re-raised after its string args
        have been scrubbed of secrets (SESF-41 AC-17).
    """
    if not turns:
        return 0

    # Two turns sharing a doc_id would hash to the same primary key
    # (_pk_from_doc_id) and both be inserted; keep only the first.
    unique_turns: List[Dict] = []
    seen_doc_ids = set()
    for turn in turns:
        doc_id = turn["doc_id"]
        if doc_id in seen_doc_ids:
            continue
        seen_doc_ids.add(doc_id)
        unique_turns.append(turn)

    # Dedup against stored rows. The doc_id is escaped so an embedded quote or
    # trailing backslash cannot break out of the filter literal (SESF-33).
    existing_ids = set()
    with milvus_client(db_path) as client:
        for turn in unique_turns:
            doc_id = turn["doc_id"]
            try:
                results = client.query(
                    collection_name=COLLECTION_NAME,
                    filter=f'doc_id == "{_escape_filter_scalar(str(doc_id))}"',
                    limit=1,
                    output_fields=["doc_id"],
                )
            except Exception as e:
                # Best effort: a doc_id whose existence can't be checked is treated as new.
                logger.warning("Dedup check failed for doc_id %s: %s", doc_id, _scrub_exception(e))
                continue
            if results:
                existing_ids.add(doc_id)
    new_turns = [t for t in unique_turns if t["doc_id"] not in existing_ids]
    if not new_turns:
        return 0

    # SESF-41: ingestion-time secret redaction guard. One hook over the deduped
    # turns rewrites turn["text"] in place, covering all three durable sinks
    # (embedding, Milvus document, FTS content) and add_turns_async with no
    # per-Provider change, and runs before _extract_issue_ids so issue IDs survive.
    _apply_redaction(new_turns)

    # Embed texts in local, resource-controlled batches. Query embedding stays
    # untouched in search(); this path is ingestion/backfill only.
    budget = get_embedding_budget()
    all_embeddings = []
    for batch in budget.split_batches(new_turns):
        texts = [t["text"] for t in batch]
        estimated_chars = sum(len(text) for text in texts)
        decision = budget.before_batch(batch_size=len(batch), estimated_chars=estimated_chars)
        if not decision.allowed and decision.retry_after_seconds > 0:
            time.sleep(decision.retry_after_seconds)
            decision = budget.before_batch(batch_size=len(batch), estimated_chars=estimated_chars)
        if not decision.allowed:
            logger.info("Embedding batch deferred: %s", decision.reason)
            break
        started = time.monotonic()
        try:
            embeddings = embed_texts(texts, is_query=False)
        except Exception as e:
            budget.after_batch(time.monotonic() - started, 0, error=e)
            # SESF-41 AC-17: scrub the exception's string args before the bare
            # re-raise so every upstream site that later stringifies it is already
            # clean, while preserving any structured status-code args.
            _scrub_exception_args(e)
            raise
        budget.after_batch(time.monotonic() - started, len(batch))
        all_embeddings.extend(embeddings)

    # Only the prefix of turns that actually got embeddings is stored.
    new_turns = new_turns[:len(all_embeddings)]
    embeddings = all_embeddings
    if not new_turns:
        return 0

    provider_defaults = default_provider_metadata()
    # Extract once per (already redacted) turn; shared by Milvus rows and FTS records.
    issue_ids = [_extract_issue_ids(t.get("text", "")) for t in new_turns]
    data = []
    for turn, emb, turn_issue_ids in zip(new_turns, embeddings, issue_ids):
        # Stable hash: SHA-256 truncated to int64. Python's hash() is
        # randomized per process, so the same doc_id would get different
        # primary keys across server restarts.
        int_id = _pk_from_doc_id(turn["doc_id"])
        data.append({
            "id": int_id,
            "vector": emb,
            "document": _truncate_utf8(turn["text"], 65535),
            "doc_id": turn["doc_id"],
            "session_id": turn.get("session_id", ""),
            "logical_session_id": turn.get("logical_session_id", turn.get("session_id", "")),
            "provider": turn.get("provider", provider_defaults["provider"]),
            "source_kind": turn.get("source_kind", provider_defaults["source_kind"]),
            "source_class": turn.get("source_class", provider_defaults["source_class"]),
            "source_id": turn.get("source_id", ""),
            "source_path": turn.get("source_path", turn.get("transcript_file", "")),
            "transcript_file": turn.get("transcript_file", ""),
            "turn_index": turn.get("turn_index", 0),
            "timestamp": turn.get("timestamp", ""),
            "git_branch": turn.get("git_branch", ""),
            "chunk_type": turn.get("chunk_type", "turn"),
            "project_root": turn.get("project_root", ""),
            "issue_ids": turn_issue_ids,
        })
    with milvus_client(db_path) as client:
        client.insert(collection_name=COLLECTION_NAME, data=data)

    # Dual-write into FTS5 sidecar (non-fatal on failure).
    if db_path:
        try:
            fts_records = [{
                "doc_id": t["doc_id"],
                "content": t["text"],
                "session_id": t.get("session_id", ""),
                "logical_session_id": t.get("logical_session_id", t.get("session_id", "")),
                "provider": t.get("provider", provider_defaults["provider"]),
                "source_kind": t.get("source_kind", provider_defaults["source_kind"]),
                "source_class": t.get("source_class", provider_defaults["source_class"]),
                "source_id": t.get("source_id", ""),
                "source_path": t.get("source_path", t.get("transcript_file", "")),
                "git_branch": t.get("git_branch", ""),
                "turn_index": t.get("turn_index", 0),
                "timestamp": t.get("timestamp", ""),
                "chunk_type": t.get("chunk_type", "turn"),
                "project_root": t.get("project_root", ""),
                "issue_ids": turn_issue_ids,
            } for t, turn_issue_ids in zip(new_turns, issue_ids)]
            fts_conn = _fts.connection(db_path)
            try:
                _fts.insert(fts_conn, fts_records)
            finally:
                # Close even when insert fails so the ephemeral handle is not leaked.
                _fts.close_ephemeral(fts_conn)
        except Exception as e:
            # SESF-41 AC-17: the FTS payload holds Turn content, so scrub the exception
            # text before logging so no secret fragment can echo through the warning.
            logger.warning("FTS insert failed (non-fatal): %s", _scrub_exception(e))
    return len(data)
def _escape_filter_scalar(value: str) -> str:
"""Escape a string value for use in a Milvus boolean-expression filter literal.
Milvus filter literals are C-style double-quoted strings (e.g. field ==
"value"). Per the Milvus expression grammar (Plan.g4), an embedded
double-quote is written as backslash-quote and a literal backslash as a
doubled backslash; Milvus does NOT honor ""-doubling.
Rules:
- NUL bytes are never valid in identifiers or scalar values; reject them
outright so a malformed input cannot truncate the filter expression.
- Each backslash is doubled, then each double-quote becomes backslash-quote.
Order matters: backslashes are escaped first, otherwise the backslash
introduced when escaping a quote would itself be doubled. This also stops
a trailing backslash from escaping the literal's closing quote and
consuming the rest of the filter expression (SESF-33).
- A literal newline or carriage return is not a valid character inside a
Milvus string literal (the grammar's DoubleSChar excludes them), so each
is rewritten to its escape form (backslash-n / backslash-r); tabs are
likewise escaped for consistency. These run after the backslash doubling
so the escape backslash they introduce is not itself re-doubled.
"""
if "\x00" in value:
raise ValueError("Filter scalar value must not contain NUL bytes")
value = value.replace("\\", "\\\\")
value = value.replace('"', '\\"')
return value.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
def _issue_id_containment_token(issue_id: str) -> str:
"""Normalize an issue id into a safe ``%,TOKEN,%`` containment token.
The token is matched as ``%,TOKEN,%`` against the comma-wrapped ``issue_ids``
field. Surrounding whitespace is stripped so ``" sesf-25 "`` can't yield a
never-matching ``%, SESF-25 ,%``; ``%``/``_`` are stripped so a malformed id
can't broaden the match into a wildcard scan; a valid token
(``[A-Z][A-Z0-9]+-\\d+``) contains none of these, so this is a no-op for
legitimate input. NUL bytes are rejected outright (mirroring
``_escape_filter_scalar``) since the FTS path consumes this token without
going through that guard. Shared by the Milvus filter and the FTS filter so
both halves of hybrid search stay in lockstep (SESF-32).
"""
if "\x00" in issue_id:
raise ValueError("issue_id must not contain NUL bytes")
return issue_id.strip().upper().replace("%", "").replace("_", "")
def _row_to_result(entity: Dict, defaults: Dict, distance: float = 1.0) -> Dict:
"""Map a Milvus entity dict to the standard internal result format.
Shared by ``search`` (vector hit entities) and ``_recent_listing`` (query
result rows) so any new schema field added to the collection propagates to