Skip to content

Commit

Permalink
feat: Adding Vertex Vector Search Vector DB option for RAG corpuses t…
Browse files Browse the repository at this point in the history
…o SDK

PiperOrigin-RevId: 675710933
  • Loading branch information
speedstorm1 authored and copybara-github committed Sep 17, 2024
1 parent 07e471e commit f882657
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 6 deletions.
24 changes: 24 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
JiraSource,
JiraQuery,
Weaviate,
VertexVectorSearch,
VertexFeatureStore,
)
from google.cloud.aiplatform_v1beta1 import (
Expand Down Expand Up @@ -78,6 +79,12 @@
index_name=TEST_PINECONE_INDEX_NAME,
api_key=TEST_PINECONE_API_KEY_SECRET_VERSION,
)
TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT = "test-vector-search-index-endpoint"
TEST_VERTEX_VECTOR_SEARCH_INDEX = "test-vector-search-index"
TEST_VERTEX_VECTOR_SEARCH_CONFIG = VertexVectorSearch(
index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT,
index=TEST_VERTEX_VECTOR_SEARCH_INDEX,
)
TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME = "test-feature-view-resource-name"
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
Expand Down Expand Up @@ -115,6 +122,17 @@
),
),
)
TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
vertex_vector_search=RagVectorDbConfig.VertexVectorSearch(
index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT,
index=TEST_VERTEX_VECTOR_SEARCH_INDEX,
),
),
)
TEST_GAPIC_RAG_CORPUS_PINECONE = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
Expand Down Expand Up @@ -158,6 +176,12 @@
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_PINECONE_CONFIG,
)
TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG,
)
TEST_PAGE_TOKEN = "test-page-token"

# RagFiles
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,23 @@ def create_rag_corpus_mock_vertex_feature_store():
yield create_rag_corpus_mock_vertex_feature_store


@pytest.fixture
def create_rag_corpus_mock_vertex_vector_search():
with mock.patch.object(
VertexRagDataServiceClient,
"create_rag_corpus",
) as create_rag_corpus_mock_vertex_vector_search:
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
create_rag_corpus_lro_mock.done.return_value = True
create_rag_corpus_lro_mock.result.return_value = (
tc.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH
)
create_rag_corpus_mock_vertex_vector_search.return_value = (
create_rag_corpus_lro_mock
)
yield create_rag_corpus_mock_vertex_vector_search


@pytest.fixture
def create_rag_corpus_mock_pinecone():
with mock.patch.object(
Expand Down Expand Up @@ -257,6 +274,15 @@ def test_create_corpus_vertex_feature_store_success(self):

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE)

@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_vector_search")
def test_create_corpus_vertex_vector_search_success(self):
rag_corpus = rag.create_corpus(
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH)

@pytest.mark.usefixtures("create_rag_corpus_mock_pinecone")
def test_create_corpus_pinecone_success(self):
rag_corpus = rag.create_corpus(
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SlackChannel,
SlackChannelsSource,
VertexFeatureStore,
VertexVectorSearch,
Weaviate,
)

Expand All @@ -64,6 +65,7 @@
"SlackChannelsSource",
"VertexFeatureStore",
"VertexRagStore",
"VertexVectorSearch",
"Weaviate",
"create_corpus",
"delete_corpus",
Expand Down
5 changes: 4 additions & 1 deletion vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
RagFile,
SlackChannelsSource,
VertexFeatureStore,
VertexVectorSearch,
Weaviate,
)

Expand All @@ -58,7 +59,9 @@ def create_corpus(
display_name: Optional[str] = None,
description: Optional[str] = None,
embedding_model_config: Optional[EmbeddingModelConfig] = None,
vector_db: Optional[Union[Weaviate, VertexFeatureStore, Pinecone]] = None,
vector_db: Optional[
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone]
] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Expand Down
24 changes: 20 additions & 4 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
SlackChannelsSource,
JiraSource,
VertexFeatureStore,
VertexVectorSearch,
Weaviate,
)

Expand Down Expand Up @@ -99,8 +100,8 @@ def convert_gapic_to_embedding_model_config(

def convert_gapic_to_vector_db(
gapic_vector_db: RagVectorDbConfig,
) -> Union[Weaviate, VertexFeatureStore, Pinecone]:
"""Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, or Pinecone."""
) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone]:
"""Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, or Pinecone."""
if gapic_vector_db.__contains__("weaviate"):
return Weaviate(
weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
Expand All @@ -116,6 +117,11 @@ def convert_gapic_to_vector_db(
index_name=gapic_vector_db.pinecone.index_name,
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
)
elif gapic_vector_db.__contains__("vertex_vector_search"):
return VertexVectorSearch(
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
index=gapic_vector_db.vertex_vector_search.index,
)
else:
return None

Expand Down Expand Up @@ -418,7 +424,7 @@ def set_embedding_model_config(


def set_vector_db(
vector_db: Union[Weaviate, VertexFeatureStore, Pinecone],
vector_db: Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone],
rag_corpus: GapicRagCorpus,
) -> None:
"""Sets the vector db configuration for the rag corpus."""
Expand Down Expand Up @@ -446,6 +452,16 @@ def set_vector_db(
feature_view_resource_name=resource_name,
),
)
elif isinstance(vector_db, VertexVectorSearch):
index_endpoint = vector_db.index_endpoint
index = vector_db.index

rag_corpus.rag_vector_db_config = RagVectorDbConfig(
vertex_vector_search=RagVectorDbConfig.VertexVectorSearch(
index_endpoint=index_endpoint,
index=index,
),
)
elif isinstance(vector_db, Pinecone):
index_name = vector_db.index_name
api_key = vector_db.api_key
Expand All @@ -462,5 +478,5 @@ def set_vector_db(
)
else:
raise TypeError(
"vector_db must be a Weaviate, VertexFeatureStore, or Pinecone."
"vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, or Pinecone."
)
21 changes: 20 additions & 1 deletion vertexai/preview/rag/utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ class VertexFeatureStore:
resource_name: str


@dataclasses.dataclass
class VertexVectorSearch:
"""VertexVectorSearch.
Attributes:
index_endpoint (str):
The resource name of the Index Endpoint. Format:
``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}``
index (str):
The resource name of the Index. Format:
``projects/{project}/locations/{location}/indexes/{index}``
"""

index_endpoint: str
index: str


@dataclasses.dataclass
class Pinecone:
"""Pinecone.
Expand Down Expand Up @@ -129,7 +146,9 @@ class RagCorpus:
display_name: Optional[str] = None
description: Optional[str] = None
embedding_model_config: Optional[EmbeddingModelConfig] = None
vector_db: Optional[Union[Weaviate, VertexFeatureStore, Pinecone]] = None
vector_db: Optional[
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone]
] = None


@dataclasses.dataclass
Expand Down

0 comments on commit f882657

Please sign in to comment.