Skip to content

Commit ffe9cde

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Add matching engine deployment tier parameter and new shard size
PiperOrigin-RevId: 817012629
1 parent 4ca9fcc commit ffe9cde

File tree

4 files changed

+62
-1
lines changed

4 files changed

+62
-1
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -588,6 +588,7 @@ def create_tree_ah_index(
588588
SHARD_SIZE_SMALL
589589
SHARD_SIZE_MEDIUM
590590
SHARD_SIZE_LARGE
591+
SHARD_SIZE_SO_DYNAMIC
591592
592593
593594
Returns:
@@ -740,6 +741,7 @@ def create_brute_force_index(
740741
SHARD_SIZE_SMALL
741742
SHARD_SIZE_MEDIUM
742743
SHARD_SIZE_LARGE
744+
SHARD_SIZE_SO_DYNAMIC
743745
744746
Returns:
745747
MatchingEngineIndex - Index resource object

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -944,6 +944,7 @@ def _build_deployed_index(
944944
auth_config_audiences: Optional[Sequence[str]] = None,
945945
auth_config_allowed_issuers: Optional[Sequence[str]] = None,
946946
psc_automation_configs: Optional[Sequence[Tuple[str, str]]] = None,
947+
deployment_tier: Optional[str] = None,
947948
) -> gca_matching_engine_index_endpoint.DeployedIndex:
948949
"""Builds a DeployedIndex.
949950
@@ -1046,6 +1047,8 @@ def _build_deployed_index(
10461047
projects/{project}/global/networks/{network}, where
10471048
{project} is a project number, as in '12345', and {network}
10481049
is network name.
1050+
deployment_tier (str):
1051+
Optional. The deployment tier that the index is deployed to.
10491052
10501053
"""
10511054

@@ -1056,6 +1059,7 @@ def _build_deployed_index(
10561059
enable_access_logging=enable_access_logging,
10571060
reserved_ip_ranges=reserved_ip_ranges,
10581061
deployment_group=deployment_group,
1062+
deployment_tier=deployment_tier,
10591063
)
10601064

10611065
if auth_config_audiences and auth_config_allowed_issuers:
@@ -1115,6 +1119,7 @@ def deploy_index(
11151119
sync: bool = True,
11161120
deploy_request_timeout: Optional[float] = None,
11171121
psc_automation_configs: Optional[Sequence[Tuple[str, str]]] = None,
1122+
deployment_tier: Optional[str] = None,
11181123
) -> "MatchingEngineIndexEndpoint":
11191124
"""Deploys an existing index resource to this endpoint resource.
11201125
@@ -1231,6 +1236,8 @@ def deploy_index(
12311236
[(project_id_1, network_1), (project_id_1, network_2)] will enable
12321237
PSC automation for the index to be deployed to project_id_1's network_1
12331238
and network_2 and can be queried within these networks.
1239+
deployment_tier (str):
1240+
Optional. The deployment tier that the index is deployed to.
12341241
Returns:
12351242
MatchingEngineIndexEndpoint - IndexEndpoint resource object
12361243
"""
@@ -1250,6 +1257,7 @@ def deploy_index(
12501257
sync=sync,
12511258
deploy_request_timeout=deploy_request_timeout,
12521259
psc_automation_configs=psc_automation_configs,
1260+
deployment_tier=deployment_tier,
12531261
)
12541262

12551263
@base.optional_sync(return_input_arg="self")
@@ -1270,6 +1278,7 @@ def _deploy_index(
12701278
sync: bool = True,
12711279
deploy_request_timeout: Optional[float] = None,
12721280
psc_automation_configs: Optional[Sequence[Tuple[str, str]]] = None,
1281+
deployment_tier: Optional[str] = None,
12731282
) -> "MatchingEngineIndexEndpoint":
12741283
"""Helper method to deploy an existing index resource to this endpoint resource.
12751284
@@ -1386,6 +1395,8 @@ def _deploy_index(
13861395
[(project_id_1, network_1), (project_id_1, network_2)] will enable
13871396
PSC automation for the index to be deployed to project_id_1's network_1
13881397
and network_2 and can be queried within these networks.
1398+
deployment_tier (str):
1399+
Optional. The deployment tier that the index is deployed to.
13891400
Returns:
13901401
MatchingEngineIndexEndpoint - IndexEndpoint resource object
13911402
"""
@@ -1411,6 +1422,7 @@ def _deploy_index(
14111422
auth_config_audiences=auth_config_audiences,
14121423
auth_config_allowed_issuers=auth_config_allowed_issuers,
14131424
psc_automation_configs=psc_automation_configs,
1425+
deployment_tier=deployment_tier,
14141426
)
14151427

14161428
deploy_lro = self.api_client.deploy_index(

tests/unit/aiplatform/test_matching_engine_index.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,12 @@
7777
_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT = 150
7878
_TEST_LEAF_NODE_EMBEDDING_COUNT = 123
7979
_TEST_LEAF_NODES_TO_SEARCH_PERCENT = 50
80-
_TEST_SHARD_SIZES = ["SHARD_SIZE_SMALL", "SHARD_SIZE_LARGE", "SHARD_SIZE_MEDIUM"]
80+
_TEST_SHARD_SIZES = [
81+
"SHARD_SIZE_SMALL",
82+
"SHARD_SIZE_LARGE",
83+
"SHARD_SIZE_MEDIUM",
84+
"SHARD_SIZE_SO_DYNAMIC",
85+
]
8186

8287
_TEST_INDEX_DESCRIPTION = test_constants.MatchingEngineConstants._TEST_INDEX_DESCRIPTION
8388

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@
103103
"service-account-name-1@project-id.iam.gserviceaccount.com",
104104
"service-account-name-2@project-id.iam.gserviceaccount.com",
105105
]
106+
_TEST_DEPLOYMENT_TIER = "STORAGE"
106107
_TEST_SIGNED_JWT = "signed_jwt"
107108
_TEST_AUTHORIZATION_METADATA = (("authorization", f"Bearer: {_TEST_SIGNED_JWT}"),)
108109

@@ -1324,6 +1325,47 @@ def test_deploy_index_psc_automation_configs(self, deploy_index_mock, sync):
13241325
timeout=_TEST_TIMEOUT,
13251326
)
13261327

1328+
@pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock")
1329+
@pytest.mark.parametrize("sync", [True, False])
1330+
def test_deploy_index_deployment_tier(self, deploy_index_mock, sync):
1331+
aiplatform.init(project=_TEST_PROJECT)
1332+
1333+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1334+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1335+
)
1336+
1337+
# Get index
1338+
my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_NAME)
1339+
1340+
my_index_endpoint = my_index_endpoint.deploy_index(
1341+
index=my_index,
1342+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1343+
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
1344+
deployment_tier=_TEST_DEPLOYMENT_TIER,
1345+
request_metadata=_TEST_REQUEST_METADATA,
1346+
sync=sync,
1347+
deploy_request_timeout=_TEST_TIMEOUT,
1348+
)
1349+
1350+
if not sync:
1351+
my_index_endpoint.wait()
1352+
1353+
deploy_index_mock.assert_called_once_with(
1354+
index_endpoint=my_index_endpoint.resource_name,
1355+
deployed_index=gca_index_endpoint.DeployedIndex(
1356+
id=_TEST_DEPLOYED_INDEX_ID,
1357+
index=my_index.resource_name,
1358+
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
1359+
automatic_resources={
1360+
"min_replica_count": None,
1361+
"max_replica_count": None,
1362+
},
1363+
deployment_tier=_TEST_DEPLOYMENT_TIER,
1364+
),
1365+
metadata=_TEST_REQUEST_METADATA,
1366+
timeout=_TEST_TIMEOUT,
1367+
)
1368+
13271369
@pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock")
13281370
def test_mutate_deployed_index(self, mutate_deployed_index_mock):
13291371
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)