
Add a Vector Database Service to allow stages to read and write to VDBs #1225

Merged
Changes from 1 commit
Commits (25)
5d7b3cb
Added milvus vdb prototype impl
bsuryadevara Sep 26, 2023
4807f3d
Added milvus vdb prototype impl
bsuryadevara Sep 26, 2023
b1f94fb
Added llamaindex and langchain prototypes
bsuryadevara Sep 27, 2023
d912645
doc updates
bsuryadevara Sep 27, 2023
4ecd37f
updates to milvus vd service
bsuryadevara Sep 30, 2023
c18125a
updated search and upsert functions
bsuryadevara Oct 2, 2023
a6ef60e
Added write_to_vector_db stage
bsuryadevara Oct 3, 2023
7389542
Added tests to get started
bsuryadevara Oct 3, 2023
3a31cee
Added tests to get started
bsuryadevara Oct 3, 2023
4cfba55
Added MilvusClient extension class to support missing functions
bsuryadevara Oct 4, 2023
b83f517
Added tests for Milvus vector database service
bsuryadevara Oct 4, 2023
b7fee57
Added tests for Milvus vector database service
bsuryadevara Oct 4, 2023
cde18b2
Added tests for Milvus vector database service
bsuryadevara Oct 4, 2023
c9316c0
Added milvus lite to pipeline tests
bsuryadevara Oct 9, 2023
36f1f18
Added tests with milvus lite
bsuryadevara Oct 11, 2023
2f24cc2
Updated Milvus VDB tests
bsuryadevara Oct 11, 2023
9670c97
Merge remote-tracking branch 'upstream/branch-23.11' into 1177-fea-ad…
bsuryadevara Oct 11, 2023
e4b8a02
Updated Milvus VDB tests
bsuryadevara Oct 11, 2023
a5e742e
Added tests with milvus lite
bsuryadevara Oct 11, 2023
3d0e01b
Renamed a file
bsuryadevara Oct 11, 2023
cd52a5f
Feedback changes
bsuryadevara Oct 12, 2023
5ce3402
Feedback changes
bsuryadevara Oct 12, 2023
9e6989a
Removed register stage decorator
bsuryadevara Oct 12, 2023
cf327b5
Ignore pymilvus in the docs
bsuryadevara Oct 13, 2023
a6a6f43
Update variable names
bsuryadevara Oct 13, 2023
Added milvus vdb prototype impl
bsuryadevara committed Sep 26, 2023
commit 5d7b3cbc24ec285d16e1017a6b8182019e86b75c
1 change: 1 addition & 0 deletions docker/conda/environments/cuda11.8_dev.yml
@@ -108,3 +108,4 @@ dependencies:
- pip:
# Add additional dev dependencies here
- pytest-kafka==0.6.0
- pymilvus==2.3.1
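
A quick sanity check (an illustrative sketch, not part of this diff) that the new pin resolves after rebuilding the dev environment:

# Hypothetical check that the pymilvus pin added above is importable after
# rebuilding the conda environment; uses only the standard library.
from importlib.metadata import version

assert version("pymilvus").startswith("2.3"), version("pymilvus")
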
204 changes: 204 additions & 0 deletions morpheus/controllers/milvus_vector_db_controller.py
@@ -0,0 +1,204 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
import time

import pandas as pd
from pymilvus import BulkInsertState
from pymilvus import Collection
from pymilvus import CollectionSchema
from pymilvus import DataType
from pymilvus import FieldSchema
from pymilvus import connections
from pymilvus import utility

from morpheus.controllers.vector_db_controller import VectorDBController
from morpheus.controllers.vector_db_controller import with_mutex

logger = logging.getLogger(__name__)


class MilvusVectorDBController(VectorDBController):
    """
    Milvus implementation of the vector database controller. Maps the generic
    VectorDBController operations onto pymilvus collections.
    """

def __init__(self, host: str, port: str, alias: str = "default", pool_size: int = 1, **kwargs):
self._data_type_dict = {
"int": DataType.INT64,
"bool": DataType.BOOL,
"float": DataType.FLOAT,
"double": DataType.DOUBLE,
"binary_vector": DataType.BINARY_VECTOR,
"float_vector": DataType.FLOAT_VECTOR
}
self._alias = alias
connections.connect(host=host, port=port, alias=self._alias, pool_size=pool_size, **kwargs)

def has_collection(self, name) -> bool:
return utility.has_collection(name)

def list_collections(self) -> list[str]:
return utility.list_collections()

def _create_index(self, collection, field_name, index_params) -> None:
collection.create_index(field_name=field_name, index_params=index_params)

def _create_schema_field(self, field_conf: dict):
dtype = self._data_type_dict[field_conf["dtype"].lower()]
dim = field_conf.get("dim", None)

if (dtype == DataType.BINARY_VECTOR or dtype == DataType.FLOAT_VECTOR):
if not dim:
raise ValueError(f"Dimensions for {dtype} should not be None")
if not isinstance(dim, int):
raise ValueError(f"Dimensions for {dtype} should be an integer")

field_schema = FieldSchema(name=field_conf["name"],
dtype=dtype,
description=field_conf.get("description", ""),
is_primary=field_conf["is_primary"],
dim=dim)
return field_schema

@with_mutex("_mutex")
def create_collection(self, collection_config):
collection_conf = collection_config.get("collection_conf")
collection_name = collection_conf.get("name")
index_conf = collection_conf.get("index_conf", None)
partition_conf = collection_conf.get("partition_conf", None)

schema_conf = collection_conf.get("schema_conf")
schema_fields_conf = schema_conf.get("schema_fields")

if not self.has_collection(collection_name):

            if len(schema_fields_conf) == 0:
                raise ValueError("Cannot create collection: the provided schema_fields configuration is empty.")

schema_fields = [self._create_schema_field(field_conf=field_conf) for field_conf in schema_fields_conf]

schema = CollectionSchema(fields=schema_fields,
auto_id=schema_conf.get("auto_id", False),
description=schema_conf.get("description", ""))
collection = Collection(name=collection_name,
schema=schema,
using=self._alias,
shards_num=collection_conf.get("shards", 2),
consistency_level=collection_conf.get("consistency_level", "Strong"))

if partition_conf:
# Iterate over each partition configuration
for part in partition_conf:
collection.create_partition(part["name"], description=part.get("description", ""))
if index_conf:
self._create_index(collection=collection,
field_name=index_conf["field_name"],
index_params=index_conf["index_params"])

@with_mutex("_mutex")
def insert(self, name, data, **kwargs):

partition_name = kwargs.get("partition_name", "_default")

if isinstance(data, list):
if not self.has_collection(name):
raise ValueError(f"Collection {name} doesn't exist.")
collection = Collection(name=name)

collection.insert(data, partition_name=partition_name)
collection.flush()

# TODO (Bhargav): Load input data from a file
# if isinstance(data, str):
# task_id = utility.do_bulk_insert(collection_name=name,
# partition_name=kwargs.get("partition_name", None),
# files=[data])
#
# while True:
# time.sleep(2)
# state = utility.get_bulk_insert_state(task_id=task_id)
# if state.state == BulkInsertState.ImportFailed or state.state == BulkInsertState.ImportFailedAndCleaned:
# raise Exception(f"The task {state.task_id} failed, reason: {state.failed_reason}")

# if state.state >= BulkInsertState.ImportCompleted:
# break

elif isinstance(data, pd.DataFrame):
collection_conf = kwargs.get("collection_conf")
index_conf = collection_conf.get("index_conf", None)
params = collection_conf.get("params", {})

collection, _ = Collection.construct_from_dataframe(
collection_conf["name"],
data,
primary_field=collection_conf["primary_field"],
auto_id=collection_conf.get("auto_id", False),
description=collection_conf.get("description", None),
partition_name=partition_name,
**params
)

if index_conf:
self._create_index(collection=collection,
field_name=index_conf["field_name"],
index_params=index_conf["index_params"])

collection.flush()
else:
raise ValueError("Unsupported data type for insertion.")

@with_mutex("_mutex")
    def search(self, name, query=None, **kwargs):
        # Pop load-only options so they are not forwarded to query()/search().
        is_partition_load = kwargs.pop("is_partition_load", False)

        collection = Collection(name=name)

        try:
            if is_partition_load:
                partitions = kwargs.pop("partitions")
                collection.load(partitions)
            else:
                collection.load()

if query:
result = collection.query(expr=query, **kwargs)
else:
result = collection.search(**kwargs)

return result

except Exception as exec_info:
raise RuntimeError(f"Error while performing search: {exec_info}") from exec_info
finally:
collection.release()

    @with_mutex("_mutex")
    def drop(self, name, **kwargs):

        resource = kwargs.get("type", "collection")

        collection = Collection(name=name)
        if resource == "index":
            collection.drop_index()
        elif resource == "partition":
            partition_name = kwargs["partition_name"]
            collection.drop_partition(partition_name)
        else:
            collection.drop()

@with_mutex("_mutex")
def close(self):
connections.remove_connection(alias=self._alias)
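
For orientation, a rough usage sketch of the controller above; the host/port, collection name, schema and index settings are illustrative placeholders chosen to match the configuration keys read by create_collection, insert and search (they are not part of this diff):

# Hypothetical usage of the MilvusVectorDBController prototype defined above.
# Connection details, collection name, schema and index params are placeholders.
from morpheus.controllers.milvus_vector_db_controller import MilvusVectorDBController

controller = MilvusVectorDBController(host="localhost", port="19530")

controller.create_collection({
    "collection_conf": {
        "name": "test_collection",
        "schema_conf": {
            "auto_id": False,
            "schema_fields": [
                {"name": "id", "dtype": "int", "is_primary": True},
                {"name": "embedding", "dtype": "float_vector", "is_primary": False, "dim": 3},
            ],
        },
        "index_conf": {
            "field_name": "embedding",
            "index_params": {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}},
        },
    }
})

# Column-wise insert: one list per schema field, following pymilvus' list layout.
controller.insert("test_collection", [[1, 2, 3], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])

# Query by boolean expression; the collection is loaded and released internally.
results = controller.search("test_collection", query="id > 0", output_fields=["id"])

controller.close()
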
29 changes: 29 additions & 0 deletions morpheus/controllers/qdrant_vector_db_controller.py
@@ -0,0 +1,29 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from qdrant_openapi_client.qdrant_client import QdrantClient

from morpheus.controllers.vector_db_controller import VectorDBController


class QdrantVectorDBController(VectorDBController):
    """
    Placeholder Qdrant implementation of the vector database controller; only
    insert and search are stubbed out so far.
    """

def __init__(self, api_url='http://localhost:6333'):
self.client = QdrantClient(api_url=api_url)

def insert(self, name, data, **kwargs):
pass

def search(self, name, query=None, **kwargs):
pass
123 changes: 123 additions & 0 deletions morpheus/controllers/vector_db_controller.py
@@ -0,0 +1,123 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import threading
from abc import ABC
from abc import abstractmethod


def with_mutex(lock_name):
    """
    Decorator that executes the wrapped method while holding the lock stored in
    the named attribute (e.g. "_mutex") of the instance.
    """

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with getattr(args[0], lock_name):
                return func(*args, **kwargs)

        return wrapper

    return decorator


class VectorDBController(ABC):
    """
    Abstract base class for vector database (document store) implementations.
    """

    _mutex = threading.Lock()

    @abstractmethod
    def insert(self, name, data, **kwargs):
        """
        Insert data into the named collection.
        """

    @abstractmethod
    def search(self, name, query=None, **kwargs):
        """
        Query or search the named collection and return the matching results.
        """

    @abstractmethod
    def drop(self, name, **kwargs):
        """
        Drop a collection, index or partition from the vector database.
        """

# @abstractmethod
# def update(self, vector: list[float], vector_id: str):
# """
# Update an existing vector in the vector document store.

# Parameters:
# - vector (List[float]): The updated vector data.
# - vector_id (str): The unique identifier of the vector to update.

# Returns:
# - None

# Raises:
# - RuntimeError: If an error occurs while updating the vector.
# """
# pass

# @abstractmethod
# def get_by_name(self, name: str) -> list[float]:
# """
# Retrieve a vector from the vector document store by its ID.

# Parameters:
# - vector_id (str): The unique identifier of the vector to retrieve.

# Returns:
# - List[float]: The vector data.

# Raises:
# - RuntimeError: If an error occurs while retrieving the vector.
# """
# pass

# @abstractmethod
# def count(self) -> int:
# """
# Get the total count of vectors in the vector document store.

# Returns:
# - int: The total count of vectors.

# Raises:
# - RuntimeError: If an error occurs while counting the vectors.
# """
# pass

    @abstractmethod
    def create_collection(self, collection_config: dict):
        """
        Create a collection in the vector database from the given configuration.
        """

    @abstractmethod
    def close(self, **kwargs):
        """
        Close the connection to the vector database and release resources.
        """

    @abstractmethod
    def has_collection(self, name) -> bool:
        """
        Return True if a collection with the given name exists.
        """

    @abstractmethod
    def list_collections(self) -> list[str]:
        """
        Return the names of all collections in the vector database.
        """