Skip to content

Commit

Permalink
feat: Add FeatureNormType to MatchingEngineIndexConfig.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626348332
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Apr 19, 2024
1 parent c21b7eb commit c0e7acc
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ def create_tree_ah_index(
encryption_spec_key_name: Optional[str] = None,
create_request_timeout: Optional[float] = None,
shard_size: Optional[str] = None,
feature_norm_type: Optional[
matching_engine_index_config.FeatureNormType
] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the tree-AH algorithm.
Expand Down Expand Up @@ -477,6 +480,8 @@ def create_tree_ah_index(
range 1-100, inclusive. The default value is 10 (means 10%) if not set.
distance_measure_type (matching_engine_index_config.DistanceMeasureType):
Optional. The distance measure used in nearest neighbor search.
feature_norm_type (matching_engine_index_config.FeatureNormType):
Optional. The feature norm type used in nearest neighbor search.
description (str):
Optional. The description of the Index.
labels (Dict[str, str]):
Expand Down Expand Up @@ -552,6 +557,7 @@ def create_tree_ah_index(
algorithm_config=algorithm_config,
approximate_neighbors_count=approximate_neighbors_count,
distance_measure_type=distance_measure_type,
feature_norm_type=feature_norm_type,
shard_size=shard_size,
)

Expand Down Expand Up @@ -580,6 +586,9 @@ def create_brute_force_index(
distance_measure_type: Optional[
matching_engine_index_config.DistanceMeasureType
] = None,
feature_norm_type: Optional[
matching_engine_index_config.FeatureNormType
] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
Expand Down Expand Up @@ -623,6 +632,8 @@ def create_brute_force_index(
Required. The number of dimensions of the input vectors.
distance_measure_type (matching_engine_index_config.DistanceMeasureType):
Optional. The distance measure used in nearest neighbor search.
feature_norm_type (matching_engine_index_config.FeatureNormType):
Optional. The feature norm type used in nearest neighbor search.
description (str):
Optional. The description of the Index.
labels (Dict[str, str]):
Expand Down Expand Up @@ -695,6 +706,7 @@ def create_brute_force_index(
dimensions=dimensions,
algorithm_config=algorithm_config,
distance_measure_type=distance_measure_type,
feature_norm_type=feature_norm_type,
shard_size=shard_size,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,15 @@ class MatchingEngineIndexConfig:
independently.
distance_measure_type (DistanceMeasureType):
Optional. The distance measure used in nearest neighbor search.
feature_norm_type (FeatureNormType):
Optional. The feature norm type used in nearest neighbor search.
"""

dimensions: int
algorithm_config: AlgorithmConfig
approximate_neighbors_count: Optional[int] = None
distance_measure_type: Optional[DistanceMeasureType] = None
feature_norm_type: Optional[FeatureNormType] = None
shard_size: Optional[str] = None

def as_dict(self) -> Dict[str, Any]:
Expand All @@ -144,6 +147,7 @@ def as_dict(self) -> Dict[str, Any]:
"algorithmConfig": self.algorithm_config.as_dict(),
"approximateNeighborsCount": self.approximate_neighbors_count,
"distanceMeasureType": self.distance_measure_type,
"featureNormType": self.feature_norm_type,
"shardSize": self.shard_size,
}
return res
13 changes: 13 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
_TEST_CONTENTS_DELTA_URI = "gs://contents"
_TEST_INDEX_DISTANCE_MEASURE_TYPE = "SQUARED_L2_DISTANCE"
_TEST_INDEX_FEATURE_NORM_TYPE = "UNIT_L2_NORM"

_TEST_CONTENTS_DELTA_URI_UPDATE = "gs://contents_update"
_TEST_IS_COMPLETE_OVERWRITE_UPDATE = True
Expand Down Expand Up @@ -374,6 +375,7 @@ def test_create_tree_ah_index(
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE,
leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
description=_TEST_INDEX_DESCRIPTION,
Expand Down Expand Up @@ -403,6 +405,7 @@ def test_create_tree_ah_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
"shardSize": shard_size,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
Expand Down Expand Up @@ -447,6 +450,7 @@ def test_create_tree_ah_index_with_empty_index(
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE,
leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
description=_TEST_INDEX_DESCRIPTION,
Expand Down Expand Up @@ -476,6 +480,7 @@ def test_create_tree_ah_index_with_empty_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
"shardSize": shard_size,
},
},
Expand Down Expand Up @@ -506,6 +511,7 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE,
leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
description=_TEST_INDEX_DESCRIPTION,
Expand All @@ -527,6 +533,7 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
"shardSize": None,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
Expand Down Expand Up @@ -564,6 +571,7 @@ def test_create_brute_force_index(
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE,
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
sync=sync,
Expand All @@ -586,6 +594,7 @@ def test_create_brute_force_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
"shardSize": shard_size,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
Expand Down Expand Up @@ -627,6 +636,7 @@ def test_create_brute_force_index_with_empty_index(
display_name=_TEST_INDEX_DISPLAY_NAME,
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE,
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
sync=sync,
Expand All @@ -648,6 +658,7 @@ def test_create_brute_force_index_with_empty_index(
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
"shardSize": None,
},
},
Expand Down Expand Up @@ -677,6 +688,7 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE,
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
)
Expand All @@ -691,6 +703,7 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
"featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE,
"shardSize": None,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
Expand Down

0 comments on commit c0e7acc

Please sign in to comment.