diff --git a/src/k8s-extension/HISTORY.rst b/src/k8s-extension/HISTORY.rst index 2fe8a591fb9..4a96664d562 100644 --- a/src/k8s-extension/HISTORY.rst +++ b/src/k8s-extension/HISTORY.rst @@ -3,6 +3,11 @@ Release History =============== +0.4.3 +++++++++++++++++++ +* Add SSL support for AzureML + + 0.4.2 ++++++++++++++++++ diff --git a/src/k8s-extension/azext_k8s_extension/custom.py b/src/k8s-extension/azext_k8s_extension/custom.py index d07c0015792..504539179d1 100644 --- a/src/k8s-extension/azext_k8s_extension/custom.py +++ b/src/k8s-extension/azext_k8s_extension/custom.py @@ -8,8 +8,6 @@ import json from knack.log import get_logger -from msrestazure.azure_exceptions import CloudError - from azure.cli.core.azclierror import ResourceNotFoundError, MutuallyExclusiveArgumentError, \ InvalidArgumentValueError, CommandNotFoundError, RequiredArgumentMissingError from azure.cli.core.commands.client_factory import get_subscription_id diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index 3a2544a4e3b..bd49a164a3b 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -4,6 +4,9 @@ # -------------------------------------------------------------------------------------------- # pylint: disable=unused-argument +# pylint: disable=line-too-long +# pylint: disable=too-many-locals + import copy from hashlib import md5 from typing import Any, Dict, List, Tuple @@ -17,9 +20,7 @@ import azure.mgmt.storage.models import azure.mgmt.loganalytics import azure.mgmt.loganalytics.models -from ..vendored_sdks.models import ( - ExtensionInstance, ExtensionInstanceUpdate, Scope, ScopeCluster) -from azure.cli.core.azclierror import InvalidArgumentValueError +from azure.cli.core.azclierror import InvalidArgumentValueError, MutuallyExclusiveArgumentError from azure.cli.core.commands.client_factory import get_mgmt_service_client, get_subscription_id from azure.mgmt.resource.locks.models import ManagementLockObject from knack.log import get_logger @@ -27,12 +28,19 @@ from .._client_factory import cf_resources from .PartnerExtensionModel import PartnerExtensionModel +from ..vendored_sdks.models import ( + ExtensionInstance, + ExtensionInstanceUpdate, + Scope, + ScopeCluster +) logger = get_logger(__name__) resource_tag = {'created_by': 'Azure Arc-enabled ML'} +# pylint: disable=too-many-instance-attributes class AzureMLKubernetes(PartnerExtensionModel): def __init__(self): # constants for configuration settings. @@ -66,6 +74,14 @@ def __init__(self): self.SERVICE_BUS_JOB_STATE_TOPIC = 'jobstate-updatedby-computeprovider' self.SERVICE_BUS_JOB_STATE_SUB = 'compute-scheduler-jobstate' + # constants for enabling SSL in inference + self.sslKeyPemFile = 'sslKeyPemFile' + self.sslCertPemFile = 'sslCertPemFile' + self.allowInsecureConnections = 'allowInsecureConnections' + self.privateEndpointILB = 'privateEndpointILB' + self.privateEndpointNodeport = 'privateEndpointNodeport' + self.inferenceLoadBalancerHA = 'inferenceLoadBalancerHA' + # reference mapping self.reference_mapping = { self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY, self.RELAY_CONNECTION_STRING_DEPRECATED_KEY], @@ -151,7 +167,7 @@ def __validate_config(self, configuration_settings, configuration_protected_sett config_keys = configuration_settings.keys() config_protected_keys = configuration_protected_settings.keys() dup_keys = set(config_keys) & set(config_protected_keys) - if len(dup_keys) > 0: + if dup_keys: for key in dup_keys: logger.warning( 'Duplicate keys found in both configuration settings and configuration protected setttings: %s', key) @@ -168,6 +184,7 @@ def __validate_config(self, configuration_settings, configuration_protected_sett if enable_inference: logger.warning("The installed AzureML extension for AML inference is experimental and not covered by customer support. Please use with discretion.") self.__validate_scoring_fe_settings(configuration_settings, configuration_protected_settings) + self.__set_up_inference_ssl(configuration_settings, configuration_protected_settings) elif not (enable_training or enable_inference): raise InvalidArgumentValueError( "Please create Microsoft.AzureML.Kubernetes extension instance either " @@ -181,33 +198,63 @@ def __validate_config(self, configuration_settings, configuration_protected_sett configuration_protected_settings.pop(self.ENABLE_INFERENCE, None) def __validate_scoring_fe_settings(self, configuration_settings, configuration_protected_settings): - clusterPurpose = _get_value_from_config_protected_config( - 'clusterPurpose', configuration_settings, configuration_protected_settings) - if clusterPurpose and clusterPurpose not in ["DevTest", "FastProd"]: - raise InvalidArgumentValueError( - "Accepted values for '--configuration-settings clusterPurpose' " - "are 'DevTest' and 'FastProd'") - - feSslCert = _get_value_from_config_protected_config( - 'scoringFe.sslCert', configuration_settings, configuration_protected_settings) - sslKey = _get_value_from_config_protected_config( - 'scoringFe.sslKey', configuration_settings, configuration_protected_settings) + isTestCluster = _get_value_from_config_protected_config( + self.inferenceLoadBalancerHA, configuration_settings, configuration_protected_settings) + isTestCluster = str(isTestCluster).lower() == 'false' + if isTestCluster: + configuration_settings['clusterPurpose'] = 'DevTest' + else: + configuration_settings['clusterPurpose'] = 'FastProd' + feSslCertFile = configuration_protected_settings.get(self.sslCertPemFile) + feSslKeyFile = configuration_protected_settings.get(self.sslKeyPemFile) allowInsecureConnections = _get_value_from_config_protected_config( - 'allowInsecureConnections', configuration_settings, configuration_protected_settings) + self.allowInsecureConnections, configuration_settings, configuration_protected_settings) allowInsecureConnections = str(allowInsecureConnections).lower() == 'true' - if (not feSslCert or not sslKey) and not allowInsecureConnections: + if (not feSslCertFile or not feSslKeyFile) and not allowInsecureConnections: raise InvalidArgumentValueError( "Provide ssl certificate and key. " "Otherwise explicitly allow insecure connection by specifying " "'--configuration-settings allowInsecureConnections=true'") + feIsNodePort = _get_value_from_config_protected_config( + self.privateEndpointNodeport, configuration_settings, configuration_protected_settings) + feIsNodePort = str(feIsNodePort).lower() == 'true' feIsInternalLoadBalancer = _get_value_from_config_protected_config( - 'scoringFe.serviceType.internalLoadBalancer', configuration_settings, configuration_protected_settings) + self.privateEndpointILB, configuration_settings, configuration_protected_settings) feIsInternalLoadBalancer = str(feIsInternalLoadBalancer).lower() == 'true' - if feIsInternalLoadBalancer: + + if feIsNodePort and feIsInternalLoadBalancer: + raise MutuallyExclusiveArgumentError( + "Specify either privateEndpointNodeport=true or privateEndpointILB=true, but not both.") + elif feIsNodePort: + configuration_settings['scoringFe.serviceType.nodePort'] = feIsNodePort + elif feIsInternalLoadBalancer: + configuration_settings['scoringFe.serviceType.internalLoadBalancer'] = feIsInternalLoadBalancer logger.warning( 'Internal load balancer only supported on AKS and AKS Engine Clusters.') + def __set_up_inference_ssl(self, configuration_settings, configuration_protected_settings): + allowInsecureConnections = _get_value_from_config_protected_config( + self.allowInsecureConnections, configuration_settings, configuration_protected_settings) + allowInsecureConnections = str(allowInsecureConnections).lower() == 'true' + if not allowInsecureConnections: + import base64 + feSslCertFile = configuration_protected_settings.get(self.sslCertPemFile) + feSslKeyFile = configuration_protected_settings.get(self.sslKeyPemFile) + with open(feSslCertFile) as f: + cert_data = f.read() + cert_data_bytes = cert_data.encode("ascii") + ssl_cert = base64.b64encode(cert_data_bytes).decode() + configuration_protected_settings['scoringFe.sslCert'] = ssl_cert + with open(feSslKeyFile) as f: + key_data = f.read() + key_data_bytes = key_data.encode("ascii") + ssl_key = base64.b64encode(key_data_bytes).decode() + configuration_protected_settings['scoringFe.sslKey'] = ssl_key + else: + logger.warning( + 'SSL is not enabled. Allowing insecure connections to the deployed services.') + def __create_required_resource( self, cmd, configuration_settings, configuration_protected_settings, subscription_id, resource_group_name, cluster_name, cluster_location): @@ -222,9 +269,8 @@ def __create_required_resource( configuration_settings[self.AZURE_LOG_ANALYTICS_CUSTOMER_ID_KEY] = ws_costumer_id configuration_protected_settings[self.AZURE_LOG_ANALYTICS_CONNECTION_STRING] = shared_key - if not configuration_settings.get( - self.RELAY_SERVER_CONNECTION_STRING) and not configuration_protected_settings.get( - self.RELAY_SERVER_CONNECTION_STRING): + if not configuration_settings.get(self.RELAY_SERVER_CONNECTION_STRING) and \ + not configuration_protected_settings.get(self.RELAY_SERVER_CONNECTION_STRING): logger.info('==== BEGIN RELAY CREATION ====') relay_connection_string, hc_resource_id, hc_name = _get_relay_connection_str( cmd, subscription_id, resource_group_name, cluster_name, cluster_location, self.RELAY_HC_AUTH_NAME) @@ -233,9 +279,8 @@ def __create_required_resource( configuration_settings[self.HC_RESOURCE_ID_KEY] = hc_resource_id configuration_settings[self.RELAY_HC_NAME_KEY] = hc_name - if not configuration_settings.get( - self.SERVICE_BUS_CONNECTION_STRING) and not configuration_protected_settings.get( - self.SERVICE_BUS_CONNECTION_STRING): + if not configuration_settings.get(self.SERVICE_BUS_CONNECTION_STRING) and \ + not configuration_protected_settings.get(self.SERVICE_BUS_CONNECTION_STRING): logger.info('==== BEGIN SERVICE BUS CREATION ====') topic_sub_mapping = { self.SERVICE_BUS_COMPUTE_STATE_TOPIC: self.SERVICE_BUS_COMPUTE_STATE_SUB, @@ -252,7 +297,7 @@ def __create_required_resource( def _get_valid_name(input_name: str, suffix_len: int, max_len: int) -> str: normalized_str = ''.join(filter(str.isalnum, input_name)) - assert len(normalized_str) > 0, "normalized name empty" + assert normalized_str, "normalized name empty" if len(normalized_str) <= max_len: return normalized_str @@ -267,6 +312,7 @@ def _get_valid_name(input_name: str, suffix_len: int, max_len: int) -> str: return new_name +# pylint: disable=broad-except def _lock_resource(cmd, lock_scope, lock_level='CanNotDelete'): lock_client: azure.mgmt.resource.locks.ManagementLockClient = get_mgmt_service_client( cmd.cli_ctx, azure.mgmt.resource.locks.ManagementLockClient) @@ -275,14 +321,13 @@ def _lock_resource(cmd, lock_scope, lock_level='CanNotDelete'): try: lock_client.management_locks.create_or_update_by_scope( scope=lock_scope, lock_name='amlarc-resource-lock', parameters=lock_object) - except: + except Exception: # try to lock the resource if user has the owner privilege pass def _get_relay_connection_str( - cmd, subscription_id, resource_group_name, cluster_name, cluster_location, auth_rule_name) -> Tuple[ - str, str, str]: + cmd, subscription_id, resource_group_name, cluster_name, cluster_location, auth_rule_name) -> Tuple[str, str, str]: relay_client: azure.mgmt.relay.RelayManagementClient = get_mgmt_service_client( cmd.cli_ctx, azure.mgmt.relay.RelayManagementClient) @@ -370,8 +415,7 @@ def _get_service_bus_connection_string(cmd, subscription_id, resource_group_name def _get_log_analytics_ws_connection_string( - cmd, subscription_id, resource_group_name, cluster_name, cluster_location) -> Tuple[ - str, str]: + cmd, subscription_id, resource_group_name, cluster_name, cluster_location) -> Tuple[str, str]: log_analytics_ws_client: azure.mgmt.loganalytics.LogAnalyticsManagementClient = get_mgmt_service_client( cmd.cli_ctx, azure.mgmt.loganalytics.LogAnalyticsManagementClient) diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/ContainerInsights.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/ContainerInsights.py index 3514122e391..e42f22199da 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/ContainerInsights.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/ContainerInsights.py @@ -14,7 +14,6 @@ from azure.cli.core.commands import LongRunningOperation from azure.cli.core.commands.client_factory import get_mgmt_service_client, get_subscription_id from azure.cli.core.util import sdk_no_wait -from msrestazure.azure_exceptions import CloudError from msrestazure.tools import parse_resource_id, is_valid_resource_id from ..vendored_sdks.models import ExtensionInstance @@ -104,8 +103,7 @@ def _invoke_deployment(cmd, resource_group_name, deployment_name, template, para if cmd.supported_api_version(min_api='2019-10-01', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES): validation_poller = smc.begin_validate(resource_group_name, deployment_name, deployment) return LongRunningOperation(cmd.cli_ctx)(validation_poller) - else: - return smc.validate(resource_group_name, deployment_name, deployment) + return smc.validate(resource_group_name, deployment_name, deployment) return sdk_no_wait(no_wait, smc.begin_create_or_update, resource_group_name, deployment_name, deployment) diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/cert_and_key_encoded.txt b/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/cert_and_key_encoded.txt new file mode 100644 index 00000000000..4c2cb46c832 --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/cert_and_key_encoded.txt @@ -0,0 +1,2 @@ +dGVzdGNlcnQ= +dGVzdGtleQ== \ No newline at end of file diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/test_cert.pem b/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/test_cert.pem new file mode 100644 index 00000000000..e7529e3fdea --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/test_cert.pem @@ -0,0 +1 @@ +testcert \ No newline at end of file diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/test_key.pem b/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/test_key.pem new file mode 100644 index 00000000000..7ef00201c75 --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/data/azure_ml/test_key.pem @@ -0,0 +1 @@ +testkey \ No newline at end of file diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/test_azureml_extension.py b/src/k8s-extension/azext_k8s_extension/tests/latest/test_azureml_extension.py new file mode 100644 index 00000000000..8ddf4dfaef2 --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/test_azureml_extension.py @@ -0,0 +1,34 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=protected-access + +import os +import unittest + +from azext_k8s_extension.partner_extensions.AzureMLKubernetes import AzureMLKubernetes + + +TEST_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), '..')) + + +class TestAzureMlExtension(unittest.TestCase): + + def test_set_up_inference_ssl(self): + azremlk8sInstance = AzureMLKubernetes() + config = {'allowInsecureConnections': 'false'} + # read and encode dummy cert and key + sslKeyPemFile = os.path.join(TEST_DIR, 'data', 'azure_ml', 'test_key.pem') + sslCertPemFile = os.path.join(TEST_DIR, 'data', 'azure_ml', 'test_cert.pem') + protected_config = {'sslKeyPemFile': sslKeyPemFile, 'sslCertPemFile': sslCertPemFile} + azremlk8sInstance._AzureMLKubernetes__set_up_inference_ssl(config, protected_config) + self.assertTrue('scoringFe.sslCert' in protected_config) + self.assertTrue('scoringFe.sslKey' in protected_config) + encoded_cert_and_key_file = os.path.join(TEST_DIR, 'data', 'azure_ml', 'cert_and_key_encoded.txt') + with open(encoded_cert_and_key_file, "r") as text_file: + cert = text_file.readline().rstrip() + assert cert == protected_config['scoringFe.sslCert'] + key = text_file.readline() + assert key == protected_config['scoringFe.sslKey'] diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/test_k8s_extension_scenario.py b/src/k8s-extension/azext_k8s_extension/tests/latest/test_k8s_extension_scenario.py index 0e53c9e6691..010df2e3077 100644 --- a/src/k8s-extension/azext_k8s_extension/tests/latest/test_k8s_extension_scenario.py +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/test_k8s_extension_scenario.py @@ -3,9 +3,9 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import os -import unittest +# pylint: disable=line-too-long +import os from azure.cli.testsdk import (ScenarioTest, ResourceGroupPreparer, record_only) @@ -27,13 +27,16 @@ def test_k8s_extension(self): 'version': '0.1.0' }) - self.cmd('k8s-extension create -g {rg} -n {name} -c {cluster_name} --cluster-type {cluster_type} --extension-type {extension_type} --release-train {release_train} --version {version}', checks=[ - self.check('name', '{name}'), - self.check('releaseTrain', '{release_train}'), - self.check('version', '{version}'), - self.check('resourceGroup', '{rg}'), - self.check('extensionType', '{extension_type}') - ]) + self.cmd('k8s-extension create -g {rg} -n {name} -c {cluster_name} --cluster-type {cluster_type} ' + '--extension-type {extension_type} --release-train {release_train} --version {version}', + checks=[ + self.check('name', '{name}'), + self.check('releaseTrain', '{release_train}'), + self.check('version', '{version}'), + self.check('resourceGroup', '{rg}'), + self.check('extensionType', '{extension_type}') + ] + ) # Update is disabled for now # self.cmd('k8s-extension update -g {rg} -n {name} --tags foo=boo', checks=[ diff --git a/src/k8s-extension/setup.py b/src/k8s-extension/setup.py index 908ecc53009..7937b085006 100644 --- a/src/k8s-extension/setup.py +++ b/src/k8s-extension/setup.py @@ -32,7 +32,7 @@ # TODO: Add any additional SDK dependencies here DEPENDENCIES = [] -VERSION = "0.4.2" +VERSION = "0.4.3" with open('README.rst', 'r', encoding='utf-8') as f: README = f.read()