diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index f299de6537..cd11d77fda 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -70,6 +70,8 @@ model_service as model_service_v1beta1, model_monitoring as model_monitoring_v1beta1, operation as operation_v1beta1, + persistent_resource as persistent_resource_v1beta1, + persistent_resource_service as persistent_resource_service_v1beta1, pipeline_failure_policy as pipeline_failure_policy_v1beta1, pipeline_job as pipeline_job_v1beta1, pipeline_service as pipeline_service_v1beta1, @@ -216,6 +218,7 @@ model_service_v1, model_monitoring_v1, operation_v1, + persistent_resource_v1beta1, pipeline_failure_policy_v1, pipeline_job_v1, pipeline_service_v1, diff --git a/google/cloud/aiplatform/preview/persistent_resource.py b/google/cloud/aiplatform/preview/persistent_resource.py new file mode 100644 index 0000000000..0823af0db4 --- /dev/null +++ b/google/cloud/aiplatform/preview/persistent_resource.py @@ -0,0 +1,424 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, List, Optional, Union + +from google.api_core import operation +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.compat.services import ( + persistent_resource_service_client_v1beta1 as persistent_resource_service_client_compat, +) +from google.cloud.aiplatform.compat.types import ( + persistent_resource_v1beta1 as gca_persistent_resource_compat, +) +from google.cloud.aiplatform_v1beta1.types import ( + encryption_spec as gca_encryption_spec, +) +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore + +_LOGGER = base.Logger(__name__) + + +class PersistentResource(base.VertexAiResourceNounWithFutureManager): + """Managed PersistentResource feature for Vertex AI.""" + + client_class = utils.PersistentResourceClientWithOverride + _resource_noun = "persistentResource" + _getter_method = "get_persistent_resource" + _list_method = "list_persistent_resources" + _delete_method = "delete_persistent_resource" + _parse_resource_name_method = "parse_persistent_resource_path" + _format_resource_name_method = "persistent_resource_path" + + def __init__( + self, + persistent_resource_id: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves the PersistentResource and instantiates its representation. + + Args: + persistent_resource_id (str): + Required. + project (str): + Project this PersistentResource is in. Overrides + project set in aiplatform.init. + location (str): + Location this PersistentResource is in. Overrides + location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to manage this PersistentResource. + Overrides credentials set in aiplatform.init. + """ + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=persistent_resource_id, + ) + + self._gca_resource = self._get_gca_resource( + resource_name=persistent_resource_id + ) + + @property + def display_name(self) -> Optional[str]: + """The display name of the PersistentResource.""" + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "display_name", None) + + @property + def state(self) -> gca_persistent_resource_compat.PersistentResource.State: + """The state of the PersistentResource. + + Values: + STATE_UNSPECIFIED (0): + Not set. + PROVISIONING (1): + The PROVISIONING state indicates the + persistent resources is being created. + RUNNING (3): + The RUNNING state indicates the persistent + resources is healthy and fully usable. + STOPPING (4): + The STOPPING state indicates the persistent + resources is being deleted. + ERROR (5): + The ERROR state indicates the persistent resources may be + unusable. Details can be found in the ``error`` field. + """ + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "state", None) + + @property + def error(self) -> Optional[status_pb2.Status]: + """The error status of the PersistentResource. + + Only populated when the resource's state is ``STOPPING`` or ``ERROR``. + + """ + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "error", None) + + @property + def create_time(self) -> Optional[timestamp_pb2.Timestamp]: + """Time when the PersistentResource was created.""" + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "create_time", None) + + @property + def start_time(self) -> Optional[timestamp_pb2.Timestamp]: + """Time when the PersistentResource first entered the ``RUNNING`` state.""" + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "start_time", None) + + @property + def update_time(self) -> Optional[timestamp_pb2.Timestamp]: + """Time when the PersistentResource was most recently updated.""" + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "update_time", None) + + @property + def network(self) -> Optional[str]: + """The network peered with the PersistentResource. + + The full name of the Compute Engine + `network `__ to peered + with Vertex AI to host the persistent resources. + + For example, ``projects/12345/global/networks/myVPC``. + `Format `__ is of the + form ``projects/{project}/global/networks/{network}``. Where {project} + is a project number, as in ``12345``, and {network} is a network name. + + To specify this field, you must have already `configured VPC Network + Peering for Vertex + AI `__. + + If this field is left unspecified, the resources aren't peered with any + network. + """ + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "network", None) + + @classmethod + @base.optional_sync() + def create( + cls, + persistent_resource_id: str, + resource_pools: Union[ + List[Dict], List[gca_persistent_resource_compat.ResourcePool] + ], + display_name: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + network: Optional[str] = None, + kms_key_name: Optional[str] = None, + service_account: Optional[str] = None, + reserved_ip_ranges: List[str] = None, + sync: Optional[bool] = True, # pylint: disable=unused-argument + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "PersistentResource": + r"""Creates a PersistentResource. + + Args: + persistent_resource_id (str): + Required. The ID to use for the PersistentResource, + which become the final component of the + PersistentResource's resource name. + + The maximum length is 63 characters, and valid + characters are ``/^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$/``. + + This corresponds to the ``persistent_resource_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + resource_pools (MutableSequence[google.cloud.aiplatform_v1beta1.types.ResourcePool]): + Required. The list of resource pools to create for the + PersistentResource. + display_name (str): + Optional. The display name of the + PersistentResource. The name can be up to 128 + characters long and can consist of any UTF-8 + characters. + labels (MutableMapping[str, str]): + Optional. The labels with user-defined + metadata to organize PersistentResource. + + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + + See https://goo.gl/xmQnxf for more information + and examples of labels. + network (str): + Optional. The full name of the Compute Engine + `network `__ + to peered with Vertex AI to host the persistent resources. + For example, ``projects/12345/global/networks/myVPC``. + `Format `__ + is of the form + ``projects/{project}/global/networks/{network}``. Where + {project} is a project number, as in ``12345``, and + {network} is a network name. + + To specify this field, you must have already `configured VPC + Network Peering for Vertex + AI `__. + + If this field is left unspecified, the resources aren't + peered with any network. + kms_key_name (str): + Optional. Customer-managed encryption key for the + PersistentResource. If set, this PersistentResource and all + sub-resources of this PersistentResource will be secured by + this key. + service_account (str): + Optional. Default service account that this + PersistentResource's workloads run as. The workloads + including + + - Any runtime specified via ``ResourceRuntimeSpec`` on + creation time, for example, Ray. + - Jobs submitted to PersistentResource, if no other service + account specified in the job specs. + + Only works when custom service account is enabled and users + have the ``iam.serviceAccounts.actAs`` permission on this + service account. + reserved_ip_ranges (MutableSequence[str]): + Optional. A list of names for the reserved IP ranges under + the VPC network that can be used for this persistent + resource. + + If set, we will deploy the persistent resource within the + provided IP ranges. Otherwise, the persistent resource is + deployed to any IP ranges under the provided VPC network. + + Example ['vertex-ai-ip-range']. + sync (bool): + Whether to execute this method synchonously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + project (str): + Project to create this PersistentResource in. Overrides project + set in aiplatform.init. + location (str): + Location to create this PersistentResource in. Overrides + location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to create this PersistentResource. + Overrides credentials set in aiplatform.init. + + Returns: + persistent_resource (PersistentResource): + The object representation of the newly created + PersistentResource. + """ + + if labels: + utils.validate_labels(labels) + + gca_persistent_resource = gca_persistent_resource_compat.PersistentResource( + name=persistent_resource_id, + display_name=display_name, + resource_pools=resource_pools, + labels=labels, + network=network, + reserved_ip_ranges=reserved_ip_ranges, + ) + + if kms_key_name: + gca_persistent_resource.encryption_spec = ( + gca_encryption_spec.EncryptionSpec(kms_key_name=kms_key_name) + ) + + if service_account: + service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec( + enable_custom_service_account=True, service_account=service_account + ) + gca_persistent_resource.resource_runtime_spec = ( + gca_persistent_resource_compat.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + + api_client = cls._instantiate_client(location, credentials) + create_lro = cls._create( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + persistent_resource=gca_persistent_resource, + persistent_resource_id=persistent_resource_id, + ) + + _LOGGER.log_create_with_lro(cls, create_lro) + + create_lro.result(timeout=None) + persistent_resource_result = cls( + persistent_resource_id=persistent_resource_id, + project=project, + location=location, + credentials=credentials, + ) + + _LOGGER.log_create_complete( + cls, persistent_resource_result._gca_resource, "persistent resource" + ) + + return persistent_resource_result + + @classmethod + def _create( + cls, + api_client: ( + persistent_resource_service_client_compat.PersistentResourceServiceClient + ), + parent: str, + persistent_resource: gca_persistent_resource_compat.PersistentResource, + persistent_resource_id: str, + create_request_timeout: Optional[float] = None, + ) -> operation.Operation: + """Creates a PersistentResource directly calling the API client. + + Args: + api_client (PersistentResourceServiceClient): + An instance of PersistentResourceServiceClient with the correct + api_endpoint already set based on user's preferences. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + IE "projects/my-project/locations/us-central1" + persistent_resource (gca_persistent_resource_compat.PersistentResource): + Required. The PersistentResource object to use for the create request. + persistent_resource_id (str): + Required. The ID to use for the PersistentResource, + which become the final component of the + PersistentResource's resource name. + + The maximum length is 63 characters, and valid + characters are ``/^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$/``. + + This corresponds to the ``persistent_resource_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + create_request_timeout (float): + Optional. The timeout for the create request in seconds. + + Returns: + operation (Operation): + The long-running operation returned by the Persistent Resource + create call. + """ + return api_client.create_persistent_resource( + parent=parent, + persistent_resource_id=persistent_resource_id, + persistent_resource=persistent_resource, + timeout=create_request_timeout, + ) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["PersistentResource"]: + """Lists a Persistent Resources on the provided project and region. + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[PersistentResource] + A list of PersistentResource objects. + """ + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py index 80faded765..7b7befa558 100644 --- a/tests/unit/aiplatform/constants.py +++ b/tests/unit/aiplatform/constants.py @@ -103,13 +103,14 @@ class TrainingJobConstants: _TEST_RUN_ARGS = ["-v", "0.1", "--test=arg"] + _TEST_MACHINE_SPEC = { + "machine_type": "n1-standard-4", + "accelerator_type": "NVIDIA_TESLA_K80", + "accelerator_count": 1, + } _TEST_WORKER_POOL_SPEC = [ { - "machine_spec": { - "machine_type": "n1-standard-4", - "accelerator_type": "NVIDIA_TESLA_K80", - "accelerator_count": 1, - }, + "machine_spec": _TEST_MACHINE_SPEC, "replica_count": 1, "disk_spec": {"boot_disk_type": "pd-ssd", "boot_disk_size_gb": 100}, "container_spec": { @@ -123,6 +124,7 @@ class TrainingJobConstants: _TEST_NETWORK = ( f"projects/{ProjectConstants._TEST_PROJECT}/global/networks/{_TEST_ID}" ) + _TEST_RESERVED_IP_RANGES = ["example_ip_range"] _TEST_TIMEOUT = 8000 _TEST_RESTART_JOB_ON_WORKER_RESTART = True _TEST_DISABLE_RETRIES = True @@ -369,3 +371,15 @@ class MatchingEngineConstants: _TEST_DISPLAY_NAME_UPDATE = "my new display name" _TEST_DESCRIPTION_UPDATE = "my description update" _TEST_REQUEST_METADATA = () + + +@dataclasses.dataclass(frozen=True) +class PersistentResourceConstants: + """Defines constants used by tests that create PersistentResource resources.""" + + _TEST_PERSISTENT_RESOURCE_ID = "test_persistent_resource_id" + _TEST_PERSISTENT_RESOURCE_DISPLAY_NAME = "test_display_name" + _TEST_RESOURCE_POOL = { + "machine_spec": TrainingJobConstants._TEST_MACHINE_SPEC, + "replica_count": 1, + } diff --git a/tests/unit/aiplatform/test_persistent_resource.py b/tests/unit/aiplatform/test_persistent_resource.py new file mode 100644 index 0000000000..b3480a6c0a --- /dev/null +++ b/tests/unit/aiplatform/test_persistent_resource.py @@ -0,0 +1,363 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import importlib +from unittest import mock + +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +from google.cloud.aiplatform.compat.services import ( + persistent_resource_service_client_v1beta1, +) +from google.cloud.aiplatform.compat.types import encryption_spec_v1beta1 +from google.cloud.aiplatform.compat.types import ( + persistent_resource_service_v1beta1, +) +from google.cloud.aiplatform.compat.types import persistent_resource_v1beta1 +from google.cloud.aiplatform.preview import persistent_resource +import constants as test_constants +import pytest + + +_TEST_PROJECT = test_constants.ProjectConstants._TEST_PROJECT +_TEST_LOCATION = test_constants.ProjectConstants._TEST_LOCATION +_TEST_PARENT = test_constants.ProjectConstants._TEST_PARENT + +_TEST_PERSISTENT_RESOURCE_ID = ( + test_constants.PersistentResourceConstants._TEST_PERSISTENT_RESOURCE_ID +) +_TEST_PERSISTENT_RESOURCE_DISPLAY_NAME = ( + test_constants.PersistentResourceConstants._TEST_PERSISTENT_RESOURCE_DISPLAY_NAME +) +_TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS +_TEST_NETWORK = test_constants.TrainingJobConstants._TEST_NETWORK +_TEST_RESERVED_IP_RANGES = test_constants.TrainingJobConstants._TEST_RESERVED_IP_RANGES +_TEST_KEY_NAME = test_constants.TrainingJobConstants._TEST_DEFAULT_ENCRYPTION_KEY_NAME +_TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT + +_TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_v1beta1.PersistentResource( + name=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], +) + + +def _get_persistent_resource_proto( + state=None, name=None, error=None +) -> persistent_resource_v1beta1.PersistentResource: + persistent_resource_proto = copy.deepcopy(_TEST_PERSISTENT_RESOURCE_PROTO) + persistent_resource_proto.name = name + persistent_resource_proto.state = state + persistent_resource_proto.error = error + return persistent_resource_proto + + +def _get_resource_name(name=None, project=_TEST_PROJECT, location=_TEST_LOCATION): + return "projects/{}/locations/{}/persistentResources/{}".format( + project, location, name + ) + + +@pytest.fixture +def create_preview_persistent_resource_mock(): + with mock.patch.object( + (persistent_resource_service_client_v1beta1.PersistentResourceServiceClient), + "create_persistent_resource", + ) as create_preview_persistent_resource_mock: + create_lro = mock.Mock(ga_operation.Operation) + create_lro.result.return_value = None + + create_preview_persistent_resource_mock.return_value = create_lro + yield create_preview_persistent_resource_mock + + +@pytest.fixture +def get_preview_persistent_resource_mock(): + with mock.patch.object( + (persistent_resource_service_client_v1beta1.PersistentResourceServiceClient), + "get_persistent_resource", + ) as get_preview_persistent_resource_mock: + get_preview_persistent_resource_mock.side_effect = [ + _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + state=(persistent_resource_v1beta1.PersistentResource.State.RUNNING), + ), + ] + + yield get_preview_persistent_resource_mock + + +_TEST_LIST_RESOURCE_1 = _get_persistent_resource_proto( + name="resource_1", + state=(persistent_resource_v1beta1.PersistentResource.State.RUNNING), +) +_TEST_LIST_RESOURCE_2 = _get_persistent_resource_proto( + name="resource_2", + state=(persistent_resource_v1beta1.PersistentResource.State.PROVISIONING), +) +_TEST_LIST_RESOURCE_3 = _get_persistent_resource_proto( + name="resource_3", + state=(persistent_resource_v1beta1.PersistentResource.State.STOPPING), +) +_TEST_LIST_RESOURCE_4 = _get_persistent_resource_proto( + name="resource_4", + state=(persistent_resource_v1beta1.PersistentResource.State.ERROR), +) + +_TEST_PERSISTENT_RESOURCE_LIST = [ + _TEST_LIST_RESOURCE_1, + _TEST_LIST_RESOURCE_2, + _TEST_LIST_RESOURCE_3, + _TEST_LIST_RESOURCE_4, +] + + +@pytest.fixture +def list_preview_persistent_resources_mock(): + with mock.patch.object( + (persistent_resource_service_client_v1beta1.PersistentResourceServiceClient), + "list_persistent_resources", + ) as list_preview_persistent_resources_mock: + list_preview_persistent_resources_mock.return_value = ( + _TEST_PERSISTENT_RESOURCE_LIST + ) + + yield list_preview_persistent_resources_mock + + +@pytest.fixture +def delete_preview_persistent_resource_mock(): + with mock.patch.object( + (persistent_resource_service_client_v1beta1.PersistentResourceServiceClient), + "delete_persistent_resource", + ) as delete_preview_persistent_resource_mock: + delete_lro = mock.Mock(ga_operation.Operation) + delete_lro.result.return_value = ( + persistent_resource_service_v1beta1.DeletePersistentResourceRequest() + ) + delete_preview_persistent_resource_mock.return_value = delete_lro + yield delete_preview_persistent_resource_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestPersistentResource: + def setup_method(self): + importlib.reload(aiplatform.initializer) + importlib.reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource( + self, + create_preview_persistent_resource_mock, + get_preview_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + display_name=_TEST_PERSISTENT_RESOURCE_DISPLAY_NAME, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + labels=_TEST_LABELS, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + + expected_persistent_resource_arg.display_name = ( + _TEST_PERSISTENT_RESOURCE_DISPLAY_NAME + ) + expected_persistent_resource_arg.labels = _TEST_LABELS + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + + get_preview_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_preview_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_with_network( + self, + create_preview_persistent_resource_mock, + get_preview_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + network=_TEST_NETWORK, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + + expected_persistent_resource_arg.network = _TEST_NETWORK + expected_persistent_resource_arg.reserved_ip_ranges = _TEST_RESERVED_IP_RANGES + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_preview_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_preview_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_with_kms_key( + self, + create_preview_persistent_resource_mock, + get_preview_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + kms_key_name=_TEST_KEY_NAME, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + + expected_persistent_resource_arg.encryption_spec = ( + encryption_spec_v1beta1.EncryptionSpec(kms_key_name=_TEST_KEY_NAME) + ) + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_preview_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_preview_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_with_service_account( + self, + create_preview_persistent_resource_mock, + get_preview_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + + service_account_spec = persistent_resource_v1beta1.ServiceAccountSpec( + enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT + ) + expected_persistent_resource_arg.resource_runtime_spec = ( + persistent_resource_v1beta1.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_preview_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_preview_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + def test_list_persistent_resources(self, list_preview_persistent_resources_mock): + resource_list = persistent_resource.PersistentResource.list() + + list_preview_persistent_resources_mock.assert_called_once() + assert len(resource_list) == len(_TEST_PERSISTENT_RESOURCE_LIST) + + for i in range(len(resource_list)): + actual_resource = resource_list[i] + expected_resource = _TEST_PERSISTENT_RESOURCE_LIST[i] + + assert actual_resource.name == expected_resource.name + assert actual_resource.state == expected_resource.state + + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_persistent_resource( + self, + get_preview_persistent_resource_mock, + delete_preview_persistent_resource_mock, + sync, + ): + test_resource = persistent_resource.PersistentResource( + _TEST_PERSISTENT_RESOURCE_ID + ) + test_resource.delete(sync=sync) + + if not sync: + test_resource.wait() + + get_preview_persistent_resource_mock.assert_called_once() + delete_preview_persistent_resource_mock.assert_called_once_with( + name=_TEST_PERSISTENT_RESOURCE_ID, + )