diff --git a/src/k8s-extension/HISTORY.rst b/src/k8s-extension/HISTORY.rst index f9012d4007f..9bc3cc21746 100644 --- a/src/k8s-extension/HISTORY.rst +++ b/src/k8s-extension/HISTORY.rst @@ -2,6 +2,9 @@ Release History =============== +1.0.1 +++++++++++++++++++ +* microsoft.azureml.kubernetes: Retrieve relay and service bus connection string when update the configuration protected settings of the extension. 1.0.0 ++++++++++++++++++ diff --git a/src/k8s-extension/azext_k8s_extension/custom.py b/src/k8s-extension/azext_k8s_extension/custom.py index f51b1a26c89..d3fd501a0f2 100644 --- a/src/k8s-extension/azext_k8s_extension/custom.py +++ b/src/k8s-extension/azext_k8s_extension/custom.py @@ -203,7 +203,7 @@ def update_k8s_extension(cmd, client, resource_group_name, cluster_name, name, c # Get the extension class based on the extension type extension_class = ExtensionFactory(extension_type_lower) - upd_extension = extension_class.Update(auto_upgrade_minor_version, release_train, version, + upd_extension = extension_class.Update(cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, config_settings, config_protected_settings) return sdk_no_wait(no_wait, client.begin_update, resource_group_name, cluster_rp, cluster_type, diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index d6de36bb246..706973d9a13 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -20,7 +20,7 @@ import azure.mgmt.storage.models import azure.mgmt.loganalytics import azure.mgmt.loganalytics.models -from azure.cli.core.azclierror import InvalidArgumentValueError, MutuallyExclusiveArgumentError +from azure.cli.core.azclierror import AzureResponseError, InvalidArgumentValueError, MutuallyExclusiveArgumentError, ResourceNotFoundError from azure.cli.core.commands.client_factory import get_mgmt_service_client, get_subscription_id from azure.mgmt.resource.locks.models import ManagementLockObject from knack.log import get_logger @@ -31,7 +31,8 @@ from ..vendored_sdks.models import ( Extension, Scope, - ScopeCluster + ScopeCluster, + PatchExtension ) logger = get_logger(__name__) @@ -45,7 +46,6 @@ def __init__(self): # constants for configuration settings. self.DEFAULT_RELEASE_NAMESPACE = 'azureml' self.RELAY_CONNECTION_STRING_KEY = 'relayserver.relayConnectionString' - self.RELAY_CONNECTION_STRING_DEPRECATED_KEY = 'RelayConnectionString' # for 3rd party deployment, will be deprecated self.HC_RESOURCE_ID_KEY = 'relayserver.hybridConnectionResourceID' self.RELAY_HC_NAME_KEY = 'relayserver.hybridConnectionName' self.SERVICE_BUS_CONNECTION_STRING_KEY = 'servicebus.connectionString' @@ -86,7 +86,7 @@ def __init__(self): # reference mapping self.reference_mapping = { - self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY, self.RELAY_CONNECTION_STRING_DEPRECATED_KEY], + self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY], self.SERVICE_BUS_CONNECTION_STRING: [self.SERVICE_BUS_CONNECTION_STRING_KEY], 'cluster_name': ['clusterId', 'prometheus.prometheusSpec.externalLabels.cluster_name'], } @@ -164,6 +164,82 @@ def Delete(self, cmd, client, resource_group_name, cluster_name, name, cluster_t "Please try to reinstall device plugins to fix this issue.") user_confirmation_factory(cmd, yes) + def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings, + configuration_protected_settings): + self.__normalize_config(configuration_settings, configuration_protected_settings) + + if len(configuration_protected_settings) > 0: + subscription_id = get_subscription_id(cmd.cli_ctx) + + if self.AZURE_LOG_ANALYTICS_CONNECTION_STRING not in configuration_protected_settings: + try: + _, shared_key = _get_log_analytics_ws_connection_string( + cmd, subscription_id, resource_group_name, cluster_name, '', True) + configuration_protected_settings[self.AZURE_LOG_ANALYTICS_CONNECTION_STRING] = shared_key + logger.info("Get log analytics connection string succeeded.") + except azure.core.exceptions.HttpResponseError: + logger.info("Failed to get log analytics connection string.") + + if self.RELAY_SERVER_CONNECTION_STRING not in configuration_protected_settings: + try: + relay_connection_string, _, _ = _get_relay_connection_str( + cmd, subscription_id, resource_group_name, cluster_name, '', self.RELAY_HC_AUTH_NAME, True) + configuration_protected_settings[self.RELAY_SERVER_CONNECTION_STRING] = relay_connection_string + logger.info("Get relay connection string succeeded.") + except azure.mgmt.relay.models.ErrorResponseException as ex: + if ex.response.status_code == 404: + raise ResourceNotFoundError("Relay server not found.") from ex + raise AzureResponseError("Failed to get relay connection string.") from ex + + if self.SERVICE_BUS_CONNECTION_STRING not in configuration_protected_settings: + try: + service_bus_connection_string, _ = _get_service_bus_connection_string( + cmd, subscription_id, resource_group_name, cluster_name, '', {}, True) + configuration_protected_settings[self.SERVICE_BUS_CONNECTION_STRING] = service_bus_connection_string + logger.info("Get service bus connection string succeeded.") + except azure.core.exceptions.HttpResponseError as ex: + if ex.response.status_code == 404: + raise ResourceNotFoundError("Service bus not found.") from ex + raise AzureResponseError("Failed to get service bus connection string.") from ex + + configuration_protected_settings = _dereference(self.reference_mapping, configuration_protected_settings) + + if self.sslKeyPemFile in configuration_protected_settings and \ + self.sslCertPemFile in configuration_protected_settings: + logger.info(f"Both {self.sslKeyPemFile} and {self.sslCertPemFile} are set, update ssl key.") + self.__set_inference_ssl_from_file(configuration_protected_settings) + + return PatchExtension(auto_upgrade_minor_version=auto_upgrade_minor_version, + release_train=release_train, + version=version, + configuration_settings=configuration_settings, + configuration_protected_settings=configuration_protected_settings) + + def __normalize_config(self, configuration_settings, configuration_protected_settings): + # inference + isTestCluster = _get_value_from_config_protected_config( + self.inferenceLoadBalancerHA, configuration_settings, configuration_protected_settings) + if isTestCluster is not None: + isTestCluster = str(isTestCluster).lower() == 'false' + if isTestCluster: + configuration_settings['clusterPurpose'] = 'DevTest' + else: + configuration_settings['clusterPurpose'] = 'FastProd' + + feIsNodePort = _get_value_from_config_protected_config( + self.privateEndpointNodeport, configuration_settings, configuration_protected_settings) + if feIsNodePort is not None: + feIsNodePort = str(feIsNodePort).lower() == 'true' + configuration_settings['scoringFe.serviceType.nodePort'] = feIsNodePort + + feIsInternalLoadBalancer = _get_value_from_config_protected_config( + self.privateEndpointILB, configuration_settings, configuration_protected_settings) + if feIsInternalLoadBalancer is not None: + feIsInternalLoadBalancer = str(feIsInternalLoadBalancer).lower() == 'true' + configuration_settings['scoringFe.serviceType.internalLoadBalancer'] = feIsInternalLoadBalancer + logger.warning( + 'Internal load balancer only supported on AKS and AKS Engine Clusters.') + def __validate_config(self, configuration_settings, configuration_protected_settings): # perform basic validation of the input config config_keys = configuration_settings.keys() @@ -241,24 +317,27 @@ def __validate_scoring_fe_settings(self, configuration_settings, configuration_p logger.warning( 'Internal load balancer only supported on AKS and AKS Engine Clusters.') + def __set_inference_ssl_from_file(self, configuration_protected_settings): + import base64 + feSslCertFile = configuration_protected_settings.get(self.sslCertPemFile) + feSslKeyFile = configuration_protected_settings.get(self.sslKeyPemFile) + with open(feSslCertFile) as f: + cert_data = f.read() + cert_data_bytes = cert_data.encode("ascii") + ssl_cert = base64.b64encode(cert_data_bytes).decode() + configuration_protected_settings['scoringFe.sslCert'] = ssl_cert + with open(feSslKeyFile) as f: + key_data = f.read() + key_data_bytes = key_data.encode("ascii") + ssl_key = base64.b64encode(key_data_bytes).decode() + configuration_protected_settings['scoringFe.sslKey'] = ssl_key + def __set_up_inference_ssl(self, configuration_settings, configuration_protected_settings): allowInsecureConnections = _get_value_from_config_protected_config( self.allowInsecureConnections, configuration_settings, configuration_protected_settings) allowInsecureConnections = str(allowInsecureConnections).lower() == 'true' if not allowInsecureConnections: - import base64 - feSslCertFile = configuration_protected_settings.get(self.sslCertPemFile) - feSslKeyFile = configuration_protected_settings.get(self.sslKeyPemFile) - with open(feSslCertFile) as f: - cert_data = f.read() - cert_data_bytes = cert_data.encode("ascii") - ssl_cert = base64.b64encode(cert_data_bytes).decode() - configuration_protected_settings['scoringFe.sslCert'] = ssl_cert - with open(feSslKeyFile) as f: - key_data = f.read() - key_data_bytes = key_data.encode("ascii") - ssl_key = base64.b64encode(key_data_bytes).decode() - configuration_protected_settings['scoringFe.sslKey'] = ssl_key + self.__set_inference_ssl_from_file(configuration_protected_settings) else: logger.warning( 'SSL is not enabled. Allowing insecure connections to the deployed services.') @@ -335,47 +414,46 @@ def _lock_resource(cmd, lock_scope, lock_level='CanNotDelete'): def _get_relay_connection_str( - cmd, subscription_id, resource_group_name, cluster_name, cluster_location, auth_rule_name) -> Tuple[str, str, str]: + cmd, subscription_id, resource_group_name, cluster_name, cluster_location, auth_rule_name, get_key_only=False) -> Tuple[str, str, str]: relay_client: azure.mgmt.relay.RelayManagementClient = get_mgmt_service_client( cmd.cli_ctx, azure.mgmt.relay.RelayManagementClient) cluster_id = '{}-{}-{}-relay'.format(cluster_name, subscription_id, resource_group_name) - # create namespace relay_namespace_name = _get_valid_name( cluster_id, suffix_len=6, max_len=50) - relay_namespace_params = azure.mgmt.relay.models.RelayNamespace( - location=cluster_location, tags=resource_tag) - - async_poller = relay_client.namespaces.create_or_update( - resource_group_name, relay_namespace_name, relay_namespace_params) - while True: - async_poller.result(15) - if async_poller.done(): - break - - # create hybrid connection hybrid_connection_name = cluster_name - hybrid_connection_object = relay_client.hybrid_connections.create_or_update( - resource_group_name, relay_namespace_name, hybrid_connection_name, requires_client_authorization=True) - - # relay_namespace_ojbect = relay_client.namespaces.get(resource_group_name, relay_namespace_name) - # relay_namespace_resource_id = relay_namespace_ojbect.id - # _lock_resource(cmd, lock_scope=relay_namespace_resource_id) - - # create authorization rule - auth_rule_rights = [azure.mgmt.relay.models.AccessRights.manage, - azure.mgmt.relay.models.AccessRights.send, azure.mgmt.relay.models.AccessRights.listen] - relay_client.hybrid_connections.create_or_update_authorization_rule( - resource_group_name, relay_namespace_name, hybrid_connection_name, auth_rule_name, rights=auth_rule_rights) + hc_resource_id = '' + if not get_key_only: + # create namespace + relay_namespace_params = azure.mgmt.relay.models.RelayNamespace( + location=cluster_location, tags=resource_tag) + + async_poller = relay_client.namespaces.create_or_update( + resource_group_name, relay_namespace_name, relay_namespace_params) + while True: + async_poller.result(15) + if async_poller.done(): + break + + # create hybrid connection + hybrid_connection_object = relay_client.hybrid_connections.create_or_update( + resource_group_name, relay_namespace_name, hybrid_connection_name, requires_client_authorization=True) + hc_resource_id = hybrid_connection_object.id + + # create authorization rule + auth_rule_rights = [azure.mgmt.relay.models.AccessRights.manage, + azure.mgmt.relay.models.AccessRights.send, azure.mgmt.relay.models.AccessRights.listen] + relay_client.hybrid_connections.create_or_update_authorization_rule( + resource_group_name, relay_namespace_name, hybrid_connection_name, auth_rule_name, rights=auth_rule_rights) # get connection string key: azure.mgmt.relay.models.AccessKeys = relay_client.hybrid_connections.list_keys( resource_group_name, relay_namespace_name, hybrid_connection_name, auth_rule_name) - return f'{key.primary_connection_string}', hybrid_connection_object.id, hybrid_connection_name + return f'{key.primary_connection_string}', hc_resource_id, hybrid_connection_name def _get_service_bus_connection_string(cmd, subscription_id, resource_group_name, cluster_name, cluster_location, - topic_sub_mapping: Dict[str, str]) -> Tuple[str, str]: + topic_sub_mapping: Dict[str, str], get_key_only=False) -> Tuple[str, str]: service_bus_client: azure.mgmt.servicebus.ServiceBusManagementClient = get_mgmt_service_client( cmd.cli_ctx, azure.mgmt.servicebus.ServiceBusManagementClient) cluster_id = '{}-{}-{}-service-bus'.format(cluster_name, @@ -383,35 +461,35 @@ def _get_service_bus_connection_string(cmd, subscription_id, resource_group_name service_bus_namespace_name = _get_valid_name( cluster_id, suffix_len=6, max_len=50) - # create namespace - service_bus_sku = azure.mgmt.servicebus.models.SBSku( - name=azure.mgmt.servicebus.models.SkuName.standard.name) - service_bus_namespace = azure.mgmt.servicebus.models.SBNamespace( - location=cluster_location, - sku=service_bus_sku, - tags=resource_tag) - async_poller = service_bus_client.namespaces.begin_create_or_update( - resource_group_name, service_bus_namespace_name, service_bus_namespace) - while True: - async_poller.result(15) - if async_poller.done(): - break - - for topic_name, service_bus_subscription_name in topic_sub_mapping.items(): - # create topic - topic = azure.mgmt.servicebus.models.SBTopic(max_size_in_megabytes=5120, default_message_time_to_live='P60D') - service_bus_client.topics.create_or_update( - resource_group_name, service_bus_namespace_name, topic_name, topic) - - # create subscription - sub = azure.mgmt.servicebus.models.SBSubscription( - max_delivery_count=1, default_message_time_to_live='P14D', lock_duration='PT30S') - service_bus_client.subscriptions.create_or_update( - resource_group_name, service_bus_namespace_name, topic_name, service_bus_subscription_name, sub) + if not get_key_only: + # create namespace + service_bus_sku = azure.mgmt.servicebus.models.SBSku( + name=azure.mgmt.servicebus.models.SkuName.standard.name) + service_bus_namespace = azure.mgmt.servicebus.models.SBNamespace( + location=cluster_location, + sku=service_bus_sku, + tags=resource_tag) + async_poller = service_bus_client.namespaces.begin_create_or_update( + resource_group_name, service_bus_namespace_name, service_bus_namespace) + while True: + async_poller.result(15) + if async_poller.done(): + break + + for topic_name, service_bus_subscription_name in topic_sub_mapping.items(): + # create topic + topic = azure.mgmt.servicebus.models.SBTopic(max_size_in_megabytes=5120, default_message_time_to_live='P60D') + service_bus_client.topics.create_or_update( + resource_group_name, service_bus_namespace_name, topic_name, topic) + + # create subscription + sub = azure.mgmt.servicebus.models.SBSubscription( + max_delivery_count=1, default_message_time_to_live='P14D', lock_duration='PT30S') + service_bus_client.subscriptions.create_or_update( + resource_group_name, service_bus_namespace_name, topic_name, service_bus_subscription_name, sub) service_bus_object = service_bus_client.namespaces.get(resource_group_name, service_bus_namespace_name) service_bus_resource_id = service_bus_object.id - # _lock_resource(cmd, service_bus_resource_id) # get connection string auth_rules = service_bus_client.namespaces.list_authorization_rules( @@ -423,26 +501,23 @@ def _get_service_bus_connection_string(cmd, subscription_id, resource_group_name def _get_log_analytics_ws_connection_string( - cmd, subscription_id, resource_group_name, cluster_name, cluster_location) -> Tuple[str, str]: + cmd, subscription_id, resource_group_name, cluster_name, cluster_location, get_key_only=False) -> Tuple[str, str]: log_analytics_ws_client: azure.mgmt.loganalytics.LogAnalyticsManagementClient = get_mgmt_service_client( cmd.cli_ctx, azure.mgmt.loganalytics.LogAnalyticsManagementClient) # create workspace cluster_id = '{}-{}-{}'.format(cluster_name, subscription_id, resource_group_name) log_analytics_ws_name = _get_valid_name(cluster_id, suffix_len=6, max_len=63) - log_analytics_ws = azure.mgmt.loganalytics.models.Workspace(location=cluster_location, tags=resource_tag) - async_poller = log_analytics_ws_client.workspaces.begin_create_or_update( - resource_group_name, log_analytics_ws_name, log_analytics_ws) customer_id = '' - # log_analytics_ws_resource_id = '' - while True: - log_analytics_ws_object = async_poller.result(15) - if async_poller.done(): - customer_id = log_analytics_ws_object.customer_id - # log_analytics_ws_resource_id = log_analytics_ws_object.id - break - - # _lock_resource(cmd, log_analytics_ws_resource_id) + if not get_key_only: + log_analytics_ws = azure.mgmt.loganalytics.models.Workspace(location=cluster_location, tags=resource_tag) + async_poller = log_analytics_ws_client.workspaces.begin_create_or_update( + resource_group_name, log_analytics_ws_name, log_analytics_ws) + while True: + log_analytics_ws_object = async_poller.result(15) + if async_poller.done(): + customer_id = log_analytics_ws_object.customer_id + break # get workspace shared keys shared_key = log_analytics_ws_client.shared_keys.get_shared_keys( diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py index e38f0d6e37c..5b76e500635 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py @@ -45,7 +45,7 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t ) return extension, name, create_identity - def Update(self, auto_upgrade_minor_version, release_train, version, configuration_settings, + def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings, configuration_protected_settings): """Default validations & defaults for Update Must create and return a valid 'PatchExtension' object. diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py index 0e56c203a91..33c8f683591 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py @@ -18,7 +18,7 @@ def Create(self, cmd, client, resource_group_name: str, cluster_name: str, name: pass @abstractmethod - def Update(self, auto_upgrade_minor_version: bool, release_train: str, version: str, + def Update(self, cmd, resource_group_name: str, cluster_name: str, auto_upgrade_minor_version: bool, release_train: str, version: str, configuration_settings: dict, configuration_protected_settings: dict) -> PatchExtension: pass diff --git a/src/k8s-extension/setup.py b/src/k8s-extension/setup.py index 99ef128f8ab..40946af2ee3 100644 --- a/src/k8s-extension/setup.py +++ b/src/k8s-extension/setup.py @@ -32,7 +32,7 @@ # TODO: Add any additional SDK dependencies here DEPENDENCIES = [] -VERSION = "1.0.0" +VERSION = "1.0.1" with open('README.rst', 'r', encoding='utf-8') as f: README = f.read()