diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index 98eadb6e6a1..ac71251ecbd 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -80,7 +80,7 @@ def __init__(self): self.allowInsecureConnections = 'allowInsecureConnections' self.SSL_SECRET = 'sslSecret' self.SSL_Cname = 'sslCname' - + self.inferenceRouterServiceType = 'inferenceRouterServiceType' self.internalLoadBalancerProvider = 'internalLoadBalancerProvider' self.inferenceLoadBalancerHA = 'inferenceLoadBalancerHA' @@ -98,6 +98,8 @@ def __init__(self): 'cluster_name': ['clusterId', 'prometheus.prometheusSpec.externalLabels.cluster_name'], } + self.OPEN_SHIFT = 'openshift' + def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_type, extension_type, scope, auto_upgrade_minor_version, release_train, version, target_namespace, release_namespace, configuration_settings, configuration_protected_settings, @@ -124,6 +126,10 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t resource = resources.get_by_id( cluster_resource_id, parent_api_version) cluster_location = resource.location.lower() + if resource.properties['totalNodeCount'] == 1 or resource.properties['totalNodeCount'] == 2: + configuration_settings['clusterPurpose'] = 'DevTest' + if resource.properties['distribution'].lower() == "openshift": + configuration_settings[self.OPEN_SHIFT] = "true" except CloudError as ex: raise ex @@ -221,7 +227,7 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers if internalLoadBalancerProvider is not None: hasInternalLoadBalancerProvider = True messageBody = messageBody + "internalLoadBalancerProvider\n" - + sslCname = _get_value_from_config_protected_config(self.SSL_Cname, configuration_settings, configuration_protected_settings) if sslCname is not None: hasSslCname = True @@ -294,7 +300,11 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers if self.sslKeyPemFile in configuration_protected_settings and \ self.sslCertPemFile in configuration_protected_settings: logger.info(f"Both {self.sslKeyPemFile} and {self.sslCertPemFile} are set, update ssl key.") - self.__set_inference_ssl_from_file(configuration_protected_settings, self.sslCertPemFile, self.sslKeyPemFile) + fe_ssl_cert_file = configuration_protected_settings.get(self.sslCertPemFile) + fe_ssl_key_file = configuration_protected_settings.get(self.sslKeyPemFile) + + if fe_ssl_cert_file and fe_ssl_key_file: + self.__set_inference_ssl_from_file(configuration_protected_settings, fe_ssl_cert_file, fe_ssl_key_file) return PatchExtension(auto_upgrade_minor_version=auto_upgrade_minor_version, release_train=release_train, @@ -312,18 +322,18 @@ def __normalize_config(self, configuration_settings, configuration_protected_set configuration_settings['clusterPurpose'] = 'DevTest' else: configuration_settings['clusterPurpose'] = 'FastProd' - + inferenceRouterServiceType = _get_value_from_config_protected_config( self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) if inferenceRouterServiceType: if inferenceRouterServiceType.lower() != 'nodeport' and inferenceRouterServiceType.lower() != 'loadbalancer': raise InvalidArgumentValueError( - "inferenceRouterServiceType only supports nodePort or loadBalancer." - "Check https://aka.ms/arcmltsg for more information.") - + "inferenceRouterServiceType only supports nodePort or loadBalancer." + "Check https://aka.ms/arcmltsg for more information.") + feIsNodePort = str(inferenceRouterServiceType).lower() == 'nodeport' configuration_settings['scoringFe.serviceType.nodePort'] = feIsNodePort - + internalLoadBalancerProvider = _get_value_from_config_protected_config( self.internalLoadBalancerProvider, configuration_settings, configuration_protected_settings) if internalLoadBalancerProvider: @@ -395,23 +405,23 @@ def __validate_scoring_fe_settings(self, configuration_settings, configuration_p "either provide sslCertPemFile and sslKeyPemFile to --configuration-protected-settings, " f"or provide sslSecret(kubernetes secret name) in --configuration-settings containing both ssl cert and ssl key under {release_namespace} namespace. " "Otherwise, to enable HTTP endpoint, explicitly set allowInsecureConnections=true.") - + if sslEnabled: sslCname = _get_value_from_config_protected_config( self.SSL_Cname, configuration_settings, configuration_protected_settings) if not sslCname: raise InvalidArgumentValueError( - "To enable HTTPs endpoint, " - "please specify sslCname parameter in --configuration-settings. Check https://aka.ms/arcmltsg for more information.") - + "To enable HTTPs endpoint, " + "please specify sslCname parameter in --configuration-settings. Check https://aka.ms/arcmltsg for more information.") + inferenceRouterServiceType = _get_value_from_config_protected_config( self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) if not inferenceRouterServiceType or (inferenceRouterServiceType.lower() != 'nodeport' and inferenceRouterServiceType.lower() != 'loadbalancer'): raise InvalidArgumentValueError( "To use inference, " "please specify inferenceRouterServiceType=nodePort or inferenceRouterServiceType=loadBalancer in --configuration-settings and also set internalLoadBalancerProvider=azure if your aks only supports internal load balancer." - "Check https://aka.ms/arcmltsg for more information.") - + "Check https://aka.ms/arcmltsg for more information.") + feIsNodePort = str(inferenceRouterServiceType).lower() == 'nodeport' internalLoadBalancerProvider = _get_value_from_config_protected_config( self.internalLoadBalancerProvider, configuration_settings, configuration_protected_settings) diff --git a/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 b/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 index b1e4b3d1d39..8d50a93c455 100644 --- a/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 +++ b/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 @@ -13,7 +13,7 @@ Describe 'AzureML Kubernetes Testing' { It 'Creates the extension and checks that it onboards correctly with inference and SSL enabled' { $sslKeyPemFile = Join-Path (Join-Path (Join-Path (Split-Path $PSScriptRoot -Parent) "data") "azure_ml") "test_key.pem" $sslCertPemFile = Join-Path (Join-Path (Join-Path (Split-Path $PSScriptRoot -Parent) "data") "azure_ml") "test_cert.pem" - az $Env:K8sExtensionName create -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters --extension-type $extensionType -n $extensionName --release-train staging --config enableInference=true identity.proxy.remoteEnabled=True identity.proxy.remoteHost=https://master.experiments.azureml-test.net inferenceLoadBalancerHA=False --config-protected sslKeyPemFile=$sslKeyPemFile sslCertPemFile=$sslCertPemFile --no-wait + az $Env:K8sExtensionName create -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters --extension-type $extensionType -n $extensionName --release-train staging --config enableInference=true identity.proxy.remoteEnabled=True identity.proxy.remoteHost=https://master.experiments.azureml-test.net inferenceRouterServiceType=nodePort --config-protected sslKeyPemFile=$sslKeyPemFile sslCertPemFile=$sslCertPemFile --no-wait $? | Should -BeTrue $output = az $Env:K8sExtensionName show -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName