
feat(vectordb): adding qdrant vector db support #137


Open · wants to merge 2 commits into main
2 changes: 2 additions & 0 deletions .gitignore
@@ -42,3 +42,5 @@ ee/ui-component/.next

ui-component/notebook-storage/notebooks.json
ee/ui-component/package-lock.json

morphik.dev.toml
11 changes: 4 additions & 7 deletions core/api.py
@@ -51,7 +51,7 @@
from core.storage.local_storage import LocalStorage
from core.storage.s3_storage import S3Storage
from core.vector_store.multi_vector_store import MultiVectorStore
from core.vector_store.pgvector_store import PGVectorStore
from core.vector_store import vector_store_factory

# Initialize FastAPI app
logger = logging.getLogger(__name__)
@@ -173,9 +173,7 @@ async def lifespan(app_instance: FastAPI):
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")

vector_store = PGVectorStore(
uri=settings.POSTGRES_URI,
)
vector_store = vector_store_factory(settings)

# Initialize storage
match settings.STORAGE_PROVIDER:
@@ -260,7 +258,7 @@ async def lifespan(app_instance: FastAPI):
completion_model=completion_model,
cache_factory=cache_factory,
reranker=reranker,
enable_colpali=settings.ENABLE_COLPALI,
enable_colpali=settings.COLPALI_MODE != "off",
colpali_embedding_model=colpali_embedding_model,
colpali_vector_store=colpali_vector_store,
)
@@ -2050,8 +2048,7 @@ async def set_folder_rule(
except Exception as rule_apply_error:
last_error = rule_apply_error
logger.warning(
f"Metadata extraction attempt {retry_count + 1} failed: "
f"{rule_apply_error}"
f"Metadata extraction attempt {retry_count + 1} failed: {rule_apply_error}"
)
if retry_count == max_retries - 1: # Last attempt
logger.error(f"All {max_retries} metadata extraction attempts failed")
24 changes: 12 additions & 12 deletions core/config.py
@@ -101,11 +101,13 @@ class Settings(BaseSettings):
S3_BUCKET: Optional[str] = None

# Vector store configuration
VECTOR_STORE_PROVIDER: Literal["pgvector"]
VECTOR_STORE_PROVIDER: Literal["pgvector", "qdrant"]
VECTOR_STORE_DATABASE_NAME: Optional[str] = None
QDRANT_HOST: Optional[str] = None
QDRANT_PORT: int = 6333
QDRANT_HTTPS: bool = False

# Colpali configuration
ENABLE_COLPALI: bool
# Colpali embedding mode: off, local, or api
COLPALI_MODE: Literal["off", "local", "api"] = "local"

@@ -139,7 +141,8 @@ def get_settings() -> Settings:
load_dotenv(override=True)

# Load config.toml
with open("morphik.toml", "rb") as f:
cfg_path = os.environ.get("MORPHIK_CONFIG_PATH", "morphik.toml")
with open(cfg_path, "rb") as f:
config = tomli.load(f)

em = "'{missing_value}' needed if '{field}' is set to '{value}'"
@@ -281,14 +284,12 @@ def get_settings() -> Settings:
raise ValueError(f"Unknown storage provider selected: '{prov}'")

# load vector store config
vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]}
if vector_store_config["VECTOR_STORE_PROVIDER"] != "pgvector":
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
raise ValueError(f"Unknown vector store provider selected: '{prov}'")

if "POSTGRES_URI" not in os.environ:
msg = em.format(missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector")
raise ValueError(msg)
vector_store_config = {
"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"],
"QDRANT_HOST": config["vector_store"].get("qdrant_host"),
"QDRANT_PORT": config["vector_store"].get("qdrant_port", 6333),
"QDRANT_HTTPS": config["vector_store"].get("qdrant_https", False),
}

# load rules config
rules_config = {
@@ -303,7 +304,6 @@

# load morphik config
morphik_config = {
"ENABLE_COLPALI": config["morphik"]["enable_colpali"],
"COLPALI_MODE": config["morphik"].get("colpali_mode", "local"),
"MODE": config["morphik"].get("mode", "cloud"), # Default to "cloud" mode
# API domain for core server
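Reviewer note (illustrative, not part of the diff): a minimal sketch of the morphik.toml keys the updated loader reads for this feature, plus the new MORPHIK_CONFIG_PATH override. Only the keys touched by this PR are shown; host, port, and mode values are placeholders.

import os

import tomli

# Placeholder TOML mirroring the keys read by get_settings() above.
EXAMPLE_TOML = """
[vector_store]
provider = "qdrant"        # or "pgvector"
qdrant_host = "localhost"
qdrant_port = 6333
qdrant_https = false

[morphik]
colpali_mode = "local"     # "off", "local", or "api"; replaces the old enable_colpali boolean
"""

config = tomli.loads(EXAMPLE_TOML)
assert config["vector_store"]["provider"] == "qdrant"

# The loader now honours MORPHIK_CONFIG_PATH, so a dev override such as morphik.dev.toml
# (newly gitignored above) can be selected without editing morphik.toml:
os.environ["MORPHIK_CONFIG_PATH"] = "morphik.dev.toml"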
4 changes: 2 additions & 2 deletions core/services/document_service.py
@@ -244,7 +244,7 @@ async def retrieve_chunks(
chunks = await self.reranker.rerank(query, chunks)
chunks.sort(key=lambda x: x.score, reverse=True)
chunks = chunks[:k]
logger.debug(f"Reranked {k*10} chunks and selected the top {k}")
logger.debug(f"Reranked {k * 10} chunks and selected the top {k}")

# Combine multiple chunk sources if needed
chunks = await self._combine_multi_and_regular_chunks(
@@ -1210,7 +1210,7 @@ async def store_document_with_retry():
current_retry_delay *= 2
else:
logger.error(
f"All database connection attempts failed " f"after {max_retries} retries: {error_msg}"
f"All database connection attempts failed after {max_retries} retries: {error_msg}"
)
raise Exception("Failed to store document metadata after multiple retries")
else:
18 changes: 18 additions & 0 deletions core/vector_store/__init__.py
@@ -0,0 +1,18 @@
from core.config import Settings
from .base_vector_store import BaseVectorStore
from .pgvector_store import PGVectorStore
from .qdrant_store import QdrantVectorStore


def vector_store_factory(settings: Settings) -> BaseVectorStore:
prov = settings.VECTOR_STORE_PROVIDER
if prov == "pgvector":
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
return PGVectorStore(uri=settings.POSTGRES_URI)
elif prov == "qdrant":
if not settings.QDRANT_HOST:
raise ValueError("Qdrant host is required for qdrant store")
return QdrantVectorStore(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT, https=settings.QDRANT_HTTPS)
else:
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
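Usage sketch (illustrative, not part of the diff): how callers such as core/api.py and the ingestion worker resolve and initialize the configured store through this factory.

from core.config import get_settings
from core.vector_store import vector_store_factory


async def build_vector_store():
    settings = get_settings()                # provider and connection details come from morphik.toml
    store = vector_store_factory(settings)   # PGVectorStore or QdrantVectorStore
    if not await store.initialize():         # both stores expose an async initialize() returning a bool
        raise RuntimeError(f"Failed to initialize '{settings.VECTOR_STORE_PROVIDER}' vector store")
    return store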
1 change: 0 additions & 1 deletion core/vector_store/pgvector_store.py
@@ -201,7 +201,6 @@ async def initialize(self):

# Continue with the rest of the initialization
async with self.engine.begin() as conn:

# Check if vector_embeddings table exists
check_table_sql = """
SELECT EXISTS (
193 changes: 193 additions & 0 deletions core/vector_store/qdrant_store.py
@@ -0,0 +1,193 @@
import json
import logging
import uuid
from typing import List, Literal, Optional, Tuple, cast

from qdrant_client import AsyncQdrantClient, models

from core.models.chunk import DocumentChunk

from .base_vector_store import BaseVectorStore

logger = logging.getLogger(__name__)
QDRANT_COLLECTION_NAME = "vector_embeddings"


# Derive a deterministic UUIDv5 point ID from (document_id, chunk_number): re-ingesting a chunk
# upserts the same point, and get_chunks_by_id can recompute the ID without a lookup.
def _to_point_id(doc_id: str, chunk_number: int) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{chunk_number}.{doc_id}.internal"))


def _get_qdrant_distance(metric: Literal["cosine", "dotProduct"]) -> models.Distance:
match metric:
case "cosine":
return models.Distance.COSINE
case "dotProduct":
return models.Distance.DOT


class QdrantVectorStore(BaseVectorStore):
def __init__(self, host: str, port: int, https: bool) -> None:
from core.config import get_settings

settings = get_settings()

self.dimensions = settings.VECTOR_DIMENSIONS
self.collection_name = QDRANT_COLLECTION_NAME
self.distance = _get_qdrant_distance(settings.EMBEDDING_SIMILARITY_METRIC)
self.client = AsyncQdrantClient(
host=host,
port=port,
https=https,
)

async def _create_collection(self):
return await self.client.create_collection(
collection_name=self.collection_name,
vectors_config=models.VectorParams(
size=self.dimensions,
distance=self.distance,
on_disk=True,
),
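# INT8 scalar quantization keeps a compressed copy of each vector in RAM (always_ram=True)
# for fast scoring, while the full-precision vectors stay on disk (on_disk=True above).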
quantization_config=models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
type=models.ScalarType.INT8,
always_ram=True,
),
),
)

async def _check_collection_vector_size(self):
collection = await self.client.get_collection(self.collection_name)
params = collection.config.params
assert params.vectors is not None
vectors = cast(models.VectorParams, params.vectors)
if vectors.size != self.dimensions:
msg = f"Vector dimensions changed from {vectors.size} to {self.dimensions}. Recreating the collection is required and would delete all existing vector data."
logger.error(msg)
raise ValueError(msg)
return True

async def initialize(self):
logger.info("Initializing Qdrant vector collection")
try:
if not await self.client.collection_exists(self.collection_name):
logger.info("Collection does not exist; creating Qdrant collection")
await self._create_collection()
else:
await self._check_collection_vector_size()

await self.client.create_payload_index(
self.collection_name,
"document_id",
models.PayloadSchemaType.UUID,
)
return True
except Exception as e:
logger.error(f"Error initializing Qdrant store: {str(e)}")
return False

async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
try:
batch = [
models.PointStruct(
id=_to_point_id(chunk.document_id, chunk.chunk_number),
vector=cast(List[float], chunk.embedding),
payload={
"document_id": chunk.document_id,
"chunk_number": chunk.chunk_number,
"content": chunk.content,
"metadata": json.dumps(chunk.metadata) if chunk.metadata is not None else "{}",
},
)
for chunk in chunks
]
await self.client.upsert(collection_name=self.collection_name, points=batch)
return True, [cast(str, p.id) for p in batch]
except Exception as e:
logger.error(f"Error storing embeddings: {str(e)}")
return False, []

async def query_similar(
self,
query_embedding: List[float],
k: int,
doc_ids: Optional[List[str]] = None,
) -> List[DocumentChunk]:
try:
query = None
if doc_ids is not None:
query = models.Filter(
must=models.FieldCondition(
key="document_id",
match=models.MatchAny(any=doc_ids),
),
)

resp = await self.client.query_points(
self.collection_name,
query=query_embedding,
limit=k,
query_filter=query,
with_payload=True,
)
return [
DocumentChunk(
document_id=p.payload["document_id"],
chunk_number=p.payload["chunk_number"],
content=p.payload["content"],
embedding=[],
metadata=json.loads(p.payload["metadata"]),
score=p.score,
)
for p in resp.points
if p.payload is not None
]
except Exception as e:
logger.error(f"Error querying similar chunks: {str(e)}")
return []

async def get_chunks_by_id(
self,
chunk_identifiers: List[Tuple[str, int]],
) -> List[DocumentChunk]:
try:
if not chunk_identifiers:
return []

ids = [_to_point_id(doc_id, chunk_number) for (doc_id, chunk_number) in chunk_identifiers]
resp = await self.client.retrieve(
self.collection_name,
ids=ids,
)
return [
DocumentChunk(
document_id=p.payload["document_id"],
chunk_number=p.payload["chunk_number"],
content=p.payload["content"],
embedding=[],
metadata=json.loads(p.payload["metadata"]),
score=0,
)
for p in resp
if p.payload is not None
]
except Exception as e:
logger.error(f"Error retrieving chunks by ID: {str(e)}")
return []

async def delete_chunks_by_document_id(self, document_id: str) -> bool:
try:
await self.client.delete(
self.collection_name,
points_selector=models.Filter(
must=models.FieldCondition(
key="document_id",
match=models.MatchValue(value=document_id),
),
),
)
return True
except Exception as e:
logger.error(f"Error deleting chunks for document {document_id}: {str(e)}")
return False
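Reviewer note (illustrative, not part of the diff): a minimal round-trip against a local Qdrant instance. It assumes a loadable morphik.toml (the constructor calls get_settings()), a placeholder vector size of 1536 matching VECTOR_DIMENSIONS, and that DocumentChunk accepts the fields the store reads above; adjust to the actual model.

import asyncio
import uuid

from core.models.chunk import DocumentChunk
from core.vector_store.qdrant_store import QdrantVectorStore


async def main():
    store = QdrantVectorStore(host="localhost", port=6333, https=False)
    await store.initialize()  # creates the "vector_embeddings" collection on first run

    doc_id = str(uuid.uuid4())  # document_id is indexed with a UUID payload schema above
    chunk = DocumentChunk(
        document_id=doc_id,
        chunk_number=0,
        content="hello qdrant",
        embedding=[0.1] * 1536,  # placeholder; must match settings.VECTOR_DIMENSIONS
        metadata={"source": "example"},
    )
    stored, ids = await store.store_embeddings([chunk])
    hits = await store.query_similar([0.1] * 1536, k=1, doc_ids=[doc_id])
    print(stored, ids, [c.content for c in hits])


asyncio.run(main())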
15 changes: 8 additions & 7 deletions core/workers/ingestion_worker.py
@@ -25,7 +25,7 @@
from core.storage.local_storage import LocalStorage
from core.storage.s3_storage import S3Storage
from core.vector_store.multi_vector_store import MultiVectorStore
from core.vector_store.pgvector_store import PGVectorStore
from core.vector_store import vector_store_factory

# Enterprise routing helpers
from ee.db_router import get_database_for_app, get_vector_store_for_app
@@ -71,7 +71,7 @@ async def get_document_with_retry(document_service, document_id, auth, max_retri
try:
doc = await document_service.db.get_document(document_id, auth)
if doc:
logger.debug(f"Successfully retrieved document {document_id} on attempt {attempt+1}")
logger.debug(f"Successfully retrieved document {document_id} on attempt {attempt + 1}")
return doc

# Document not found but no exception raised
@@ -221,7 +221,7 @@ async def process_ingestion_job(
file_content = file_content.read()
download_time = time.time() - download_start
phase_times["download_file"] = download_time
logger.info(f"File download took {download_time:.2f}s for {len(file_content)/1024/1024:.2f}MB")
logger.info(f"File download took {download_time:.2f}s for {len(file_content) / 1024 / 1024:.2f}MB")

# 4. Parse file to text
parse_start = time.time()
@@ -417,9 +417,10 @@ async def process_ingestion_job(
# Only process if it's an image chunk - pass the image content to the rule
if chunk_obj.metadata.get("is_image", False):
# Get metadata *and* the potentially modified chunk
chunk_rule_metadata, processed_chunk = (
await document_service.rules_processor.process_chunk_rules(chunk_obj, image_rules)
)
(
chunk_rule_metadata,
processed_chunk,
) = await document_service.rules_processor.process_chunk_rules(chunk_obj, image_rules)
processed_chunks_multivector.append(processed_chunk)
# Aggregate the metadata extracted from this chunk
aggregated_chunk_metadata.update(chunk_rule_metadata)
@@ -602,7 +603,7 @@ async def startup(ctx):

# Initialize vector store
logger.info("Initializing primary vector store...")
vector_store = PGVectorStore(uri=settings.POSTGRES_URI)
vector_store = vector_store_factory(settings)
success = await vector_store.initialize()
if success:
logger.info("Primary vector store initialization successful")