Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Vector Database Service to allow stages to read and write to VDBs #1225

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5d7b3cb
Added milvus vdb prototype impl
bsuryadevara Sep 26, 2023
4807f3d
Added milvus vdb prototype impl
bsuryadevara Sep 26, 2023
b1f94fb
Added llamaindex and langchain prototypes
bsuryadevara Sep 27, 2023
d912645
doc updates
bsuryadevara Sep 27, 2023
4ecd37f
updates to milvus vd service
bsuryadevara Sep 30, 2023
c18125a
updated search and upsert functions
bsuryadevara Oct 2, 2023
a6ef60e
Added write_to_vector_db stage
bsuryadevara Oct 3, 2023
7389542
Added tests to get started
bsuryadevara Oct 3, 2023
3a31cee
Added tests to get started
bsuryadevara Oct 3, 2023
4cfba55
Added MilvusClient extension class to support missing functions
bsuryadevara Oct 4, 2023
b83f517
Added tests for Milvus vector database service
bsuryadevara Oct 4, 2023
b7fee57
Added tests for Milvus vector database service
bsuryadevara Oct 4, 2023
cde18b2
Added tests for Milvus vector database service
bsuryadevara Oct 4, 2023
c9316c0
Added milvus lite to pipeline tests
bsuryadevara Oct 9, 2023
36f1f18
Added tests with milvus lite
bsuryadevara Oct 11, 2023
2f24cc2
Updated Milvus VDB tests
bsuryadevara Oct 11, 2023
9670c97
Merge remote-tracking branch 'upstream/branch-23.11' into 1177-fea-ad…
bsuryadevara Oct 11, 2023
e4b8a02
Updated Milvus VDB tests
bsuryadevara Oct 11, 2023
a5e742e
Added tests with milvus lite
bsuryadevara Oct 11, 2023
3d0e01b
Renamed a file
bsuryadevara Oct 11, 2023
cd52a5f
Feedback changes
bsuryadevara Oct 12, 2023
5ce3402
Feedback changes
bsuryadevara Oct 12, 2023
9e6989a
Removed register stage decorator
bsuryadevara Oct 12, 2023
cf327b5
Ignore pymilvus in the docs
bsuryadevara Oct 13, 2023
a6a6f43
Update variable names
bsuryadevara Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added tests for Milvus vector database service
  • Loading branch information
bsuryadevara committed Oct 4, 2023
commit b83f517d430975ede0df01bdd47c9ba0378af657
25 changes: 14 additions & 11 deletions morpheus/service/milvus_vector_db_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,10 @@ def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing.

index_param = None

if overwrite:
if self.has_store_object(name):
if not self.has_store_object(name) or overwrite:
if overwrite and self.has_store_object(name):
self.drop(name)

if not self.has_store_object(name):

if len(schema_fields_conf) == 0:
raise ValueError("Cannot create collection as provided empty schema_fields configuration")

Expand Down Expand Up @@ -164,10 +162,12 @@ def _collection_insert(self,
name: str,
data: typing.Union[list[list], list[dict], dict],
**kwargs: dict[str, typing.Any]) -> None:

if not self.has_store_object(name):
raise RuntimeError(f"Collection {name} doesn't exist.")

collection = None
try:
if not self.has_store_object(name):
raise RuntimeError(f"Collection {name} doesn't exist.")
collection_conf = kwargs.get("collection_conf", {})
partition_name = collection_conf.get("partition_name", "_default")

Expand Down Expand Up @@ -237,8 +237,8 @@ def search(self, name: str, query: typing.Union[str, dict] = None, **kwargs: dic
------
RuntimeError
If an error occurs during the search operation.
ValueError
If query argument is `None` and data keyword argument doesn't exist.
If query argument is `None` and `data` keyword argument doesn't exist.
If `data` keyword arguement is `None`.
"""

try:
Expand All @@ -247,12 +247,15 @@ def search(self, name: str, query: typing.Union[str, dict] = None, **kwargs: dic
result = self._client.query(collection_name=name, filter=query, **kwargs)
else:
if "data" not in kwargs:
raise ValueError("The search operation requires that search vectors be " +
"provided as a keyword argument 'data'.")
raise RuntimeError("The search operation requires that search vectors be " +
"provided as a keyword argument 'data'")
if kwargs["data"] is None:
raise RuntimeError("Argument 'data' cannot be None")

data = kwargs.pop("data")
result = self._client.search(collection_name=name, data=data, **kwargs)
self._client.release_collection(collection_name=name)
return result

except MilvusException as exec_info:
raise RuntimeError(f"Unable to perform serach: {exec_info}") from exec_info

Expand Down
238 changes: 238 additions & 0 deletions tests/test_milvus_vector_db_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#!/usr/bin/env python
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import typing
from os import path
from unittest.mock import Mock
from unittest.mock import patch

import pytest

import cudf

from _utils import TEST_DIRS
from morpheus.service.milvus_client import MilvusClient
from morpheus.service.milvus_vector_db_service import MilvusVectorDBService


@pytest.fixture(scope="function", name="mock_milvus_client_fixture")
def mock_milvus_client() -> MilvusClient:
    """
    Patch the `MilvusClient` name inside the service module and yield the mocked client instance.

    The patch remains active until the requesting test finishes.
    """
    patch_target = 'morpheus.service.milvus_vector_db_service.MilvusClient'
    with patch(patch_target) as patched_client_cls:
        # `return_value` is the object handed back whenever the patched class is called.
        yield patched_client_cls.return_value


@pytest.fixture(scope="function", name="milvus_service_fixture")
def milvus_service(mock_milvus_client_fixture) -> MilvusVectorDBService:
    """
    Create a `MilvusVectorDBService` while the Milvus client is mocked.

    The mocked client's `has_collection` reports True by default; individual tests override it as needed.
    """
    mock_milvus_client_fixture.has_collection.return_value = True
    return MilvusVectorDBService(uri="http://localhost:19530")


@pytest.mark.parametrize(
    "has_store_object_input, expected_result",
    [
        (True, True),  # Collection exists
        (False, False),  # Collection does not exist
    ])
def test_has_store_object(milvus_service_fixture: MilvusVectorDBService,
                          mock_milvus_client_fixture: MilvusClient,
                          has_store_object_input: bool,
                          expected_result: bool):
    """Check that `has_store_object` reports whatever the client's `has_collection` returns."""
    mock_milvus_client_fixture.has_collection.return_value = has_store_object_input

    exists = milvus_service_fixture.has_store_object("test_collection")

    assert exists == expected_result


@pytest.mark.parametrize(
    "list_collections_input, expected_result",
    [
        (["collection1", "collection2"], ["collection1", "collection2"]),  # Collections exist
        ([], []),  # No collections exist
    ])
def test_list_store_objects(milvus_service_fixture: MilvusVectorDBService,
                            mock_milvus_client_fixture: MilvusClient,
                            list_collections_input: list,
                            expected_result: list):
    """Check that `list_store_objects` returns the collection names supplied by the client."""
    mock_milvus_client_fixture.list_collections.return_value = list_collections_input

    listed = milvus_service_fixture.list_store_objects()

    assert listed == expected_result


@pytest.mark.parametrize("overwrite, has_collection", [(True, True), (False, False), (True, False)])
def test_create(milvus_service_fixture: MilvusVectorDBService,
                mock_milvus_client_fixture: MilvusClient,
                overwrite: bool,
                has_collection: bool):
    """
    Check that `create` drops an existing collection only when asked to overwrite, then creates it.

    Every parametrized case is expected to end in exactly one `create_collection_with_schema` call.
    The (overwrite=False, has_collection=True) combination is therefore not listed here.
    """
    filepath = path.join(TEST_DIRS.tests_data_dir, "service", "milvus_test_collection_conf.json")
    # JSON is UTF-8 by specification; do not depend on the platform's default encoding.
    with open(filepath, "r", encoding="utf-8") as file:
        collection_config = json.load(file)

    mock_milvus_client_fixture.has_collection.return_value = has_collection
    # `name` is passed explicitly, so remove it from the keyword arguments forwarded to `create`.
    name = collection_config.pop("name")
    milvus_service_fixture.create(name=name, overwrite=overwrite, **collection_config)

    # A drop is only expected when an existing collection is being overwritten.
    if overwrite and has_collection:
        mock_milvus_client_fixture.drop_collection.assert_called_once()
    else:
        mock_milvus_client_fixture.drop_collection.assert_not_called()

    mock_milvus_client_fixture.create_collection_with_schema.assert_called_once()


def test_insert(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that inserting a list of records fetches the target collection exactly once."""
    first_row = {"id": 1, "embedding": [0.1, 0.2, 0.3], "age": 30}
    second_row = {"id": 2, "embedding": [0.4, 0.5, 0.6], "age": 25}
    rows = [first_row, second_row]

    milvus_service_fixture.insert(name="test_collection", data=rows)

    mock_milvus_client_fixture.get_collection.assert_called_once()


@pytest.mark.parametrize(
    "insert_data, expected_exception",
    [
        ([], RuntimeError),  # Collection does not exist
    ])
def test_insert_error(milvus_service_fixture: MilvusVectorDBService,
                      mock_milvus_client_fixture: MilvusClient,
                      insert_data: list,
                      expected_exception: typing.Type[Exception]):
    """
    Check that inserting into a collection that does not exist raises the expected exception.

    The mocked client always reports the collection as missing, so only that scenario is exercised.
    A second, identical parametrized case added no coverage and was removed.
    """
    mock_milvus_client_fixture.has_collection.return_value = False

    with pytest.raises(expected_exception):
        milvus_service_fixture.insert("non_existent_collection", data=insert_data)


def test_insert_dataframe(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that inserting a cuDF DataFrame results in one `insert` call on the fetched collection."""
    # Attach a dedicated mock to the collection object so its call count can be inspected directly.
    insert_spy = Mock()
    mock_milvus_client_fixture.get_collection.return_value.insert = insert_spy

    columns = {
        "id": [1, 2],
        "embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        "age": [30, 25],
    }
    frame = cudf.DataFrame(columns)

    milvus_service_fixture.insert_dataframe(name="test_collection", df=frame)

    insert_spy.assert_called_once()


# Each case is (test_type, query_or_data, expected_result).
#   "query" / "data" : `test_type` is also the keyword used to pass `query_or_data` to `search`.
#   "error_*"        : `search` is expected to raise; `expected_result` is unused for these cases.
search_test_cases = [
    ("query", {
        "bool": {
            "must": [{
                "vector": {
                    "embedding": [0.1, 0.2, 0.3]
                }
            }]
        }
    }, "query_result"),
    ("data", [{
        "id": 1, "embedding": [0.1, 0.2, 0.3]
    }, {
        "id": 2, "embedding": [0.4, 0.5, 0.6]
    }], "data_result"),
    ("error_data_none", None, None),  # `data` keyword supplied, but set to None
    ("error_data_missing", None, None),  # neither `query` nor `data` supplied
]


@pytest.mark.parametrize("test_type, query_or_data, expected_result", search_test_cases)
def test_search(milvus_service_fixture: MilvusVectorDBService,
                mock_milvus_client_fixture: MilvusClient,
                test_type: str,
                query_or_data: typing.Union[dict, list, None],
                expected_result: typing.Any):
    """
    Check `search` for filter queries, vector searches and both invalid-argument paths.

    Successful cases must load the collection, call `query` or `search` on the client once, release
    the collection, and return the client's result unchanged. Error cases must raise `RuntimeError`
    when `query` is None and `data` is either None or omitted.
    """
    if test_type.startswith("error"):
        # Cover both failure modes: `data=None` passed explicitly, and `data` omitted entirely.
        error_kwargs = {"data": None} if test_type == "error_data_none" else {}
        with pytest.raises(RuntimeError):
            milvus_service_fixture.search(name="test_collection", query=query_or_data, **error_kwargs)
        return

    client_result = {"result": expected_result}
    if test_type == "query":
        mock_milvus_client_fixture.query.return_value = client_result
    elif test_type == "data":
        mock_milvus_client_fixture.search.return_value = client_result

    result = milvus_service_fixture.search(name="test_collection", **{test_type: query_or_data})

    assert result == {"result": expected_result}
    mock_milvus_client_fixture.load_collection.assert_called_once()
    if test_type == "query":
        mock_milvus_client_fixture.query.assert_called_once()
    elif test_type == "data":
        mock_milvus_client_fixture.search.assert_called_once()
    mock_milvus_client_fixture.release_collection.assert_called_once()


def test_update(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that `update` issues exactly one `upsert` call on the client."""
    first_row = {"id": 1, "embedding": [0.1, 0.2, 0.3], "age": 30}
    second_row = {"id": 2, "embedding": [0.4, 0.5, 0.6], "age": 25}
    rows = [first_row, second_row]

    milvus_service_fixture.update(name="test_collection", data=rows)

    mock_milvus_client_fixture.upsert.assert_called_once()


def test_delete_by_keys(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that `delete_by_keys` calls the client's `delete` once and returns its result."""
    primary_keys = [1, 2]
    mock_milvus_client_fixture.delete.return_value = primary_keys

    deleted = milvus_service_fixture.delete_by_keys(name="test_collection", keys=primary_keys)

    assert deleted == primary_keys
    mock_milvus_client_fixture.delete.assert_called_once()


def test_delete(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that deleting by expression calls the client's `delete_by_expr` exactly once."""
    milvus_service_fixture.delete(name="test_collection", expr="age < 30")

    mock_milvus_client_fixture.delete_by_expr.assert_called_once()


def test_retrieve_by_keys(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that `retrieve_by_keys` returns the rows produced by the client's `get` call."""
    primary_keys = [1]
    mock_milvus_client_fixture.get.return_value = [{"id": 1, "embedding": [0.1, 0.2, 0.3], "age": 30}]

    rows = milvus_service_fixture.retrieve_by_keys(name="test_collection", keys=primary_keys)

    assert rows == [{"id": 1, "embedding": [0.1, 0.2, 0.3], "age": 30}]
    mock_milvus_client_fixture.get.assert_called_once()


def test_count(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that `count` returns the value reported by the client's `num_entities`."""
    mock_milvus_client_fixture.num_entities.return_value = 5

    entity_total = milvus_service_fixture.count(name="test_collection")

    assert entity_total == 5
    mock_milvus_client_fixture.num_entities.assert_called_once()


def test_drop(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that `drop` calls the client's `drop_collection` exactly once."""
    collection_name = "test_collection"

    milvus_service_fixture.drop(name=collection_name)

    mock_milvus_client_fixture.drop_collection.assert_called_once()


def test_describe(milvus_service_fixture: MilvusVectorDBService, mock_milvus_client_fixture: MilvusClient):
    """Check that `describe` returns the client's `describe_collection` output unchanged."""
    mock_milvus_client_fixture.describe_collection.return_value = {"name": "test_collection"}

    details = milvus_service_fixture.describe(name="test_collection")

    assert details == {"name": "test_collection"}
    mock_milvus_client_fixture.describe_collection.assert_called_once()
Loading
Loading