Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yuyu3/fix upgrade public #85

Merged
merged 4 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/k8s-extension/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

Release History
===============

1.0.1
++++++++++++++++++
* Enable Microsoft.PolicyInsights extension type
* microsoft.azureml.kubernetes: Retrieve relay and service bus connection string when update the configuration protected settings of the extension.

1.0.0
++++++++++++++++++
Expand Down
2 changes: 1 addition & 1 deletion src/k8s-extension/azext_k8s_extension/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def update_k8s_extension(cmd, client, resource_group_name, cluster_name, name, c
# Get the extension class based on the extension type
extension_class = ExtensionFactory(extension_type_lower)

upd_extension = extension_class.Update(auto_upgrade_minor_version, release_train, version,
upd_extension = extension_class.Update(cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version,
config_settings, config_protected_settings)

return sdk_no_wait(no_wait, client.begin_update, resource_group_name, cluster_rp, cluster_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import azure.mgmt.storage.models
import azure.mgmt.loganalytics
import azure.mgmt.loganalytics.models
from azure.cli.core.azclierror import InvalidArgumentValueError, MutuallyExclusiveArgumentError
from azure.cli.core.azclierror import AzureResponseError, InvalidArgumentValueError, MutuallyExclusiveArgumentError, ResourceNotFoundError
from azure.cli.core.commands.client_factory import get_mgmt_service_client, get_subscription_id
from azure.mgmt.resource.locks.models import ManagementLockObject
from knack.log import get_logger
Expand All @@ -31,7 +31,8 @@
from ..vendored_sdks.models import (
Extension,
Scope,
ScopeCluster
ScopeCluster,
PatchExtension
)

logger = get_logger(__name__)
Expand All @@ -45,7 +46,6 @@ def __init__(self):
# constants for configuration settings.
self.DEFAULT_RELEASE_NAMESPACE = 'azureml'
self.RELAY_CONNECTION_STRING_KEY = 'relayserver.relayConnectionString'
self.RELAY_CONNECTION_STRING_DEPRECATED_KEY = 'RelayConnectionString' # for 3rd party deployment, will be deprecated
self.HC_RESOURCE_ID_KEY = 'relayserver.hybridConnectionResourceID'
self.RELAY_HC_NAME_KEY = 'relayserver.hybridConnectionName'
self.SERVICE_BUS_CONNECTION_STRING_KEY = 'servicebus.connectionString'
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self):

# reference mapping
self.reference_mapping = {
self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY, self.RELAY_CONNECTION_STRING_DEPRECATED_KEY],
self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY],
self.SERVICE_BUS_CONNECTION_STRING: [self.SERVICE_BUS_CONNECTION_STRING_KEY],
'cluster_name': ['clusterId', 'prometheus.prometheusSpec.externalLabels.cluster_name'],
}
Expand Down Expand Up @@ -164,6 +164,68 @@ def Delete(self, cmd, client, resource_group_name, cluster_name, name, cluster_t
"Please try to reinstall device plugins to fix this issue.")
user_confirmation_factory(cmd, yes)

def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings,
configuration_protected_settings):
self.__normalize_config(configuration_settings, configuration_protected_settings)

if len(configuration_protected_settings) > 0:
subscription_id = get_subscription_id(cmd.cli_ctx)

if self.AZURE_LOG_ANALYTICS_CONNECTION_STRING not in configuration_protected_settings:
try:
_, shared_key = _get_log_analytics_ws_connection_string(
cmd, subscription_id, resource_group_name, cluster_name, '', True)
configuration_protected_settings[self.AZURE_LOG_ANALYTICS_CONNECTION_STRING] = shared_key
logger.info("Get log analytics connection string succeeded.")
except azure.core.exceptions.HttpResponseError:
logger.info("Failed to get log analytics connection string.")

if self.RELAY_SERVER_CONNECTION_STRING not in configuration_protected_settings:
try:
relay_connection_string, _, _ = _get_relay_connection_str(
cmd, subscription_id, resource_group_name, cluster_name, '', self.RELAY_HC_AUTH_NAME, True)
configuration_protected_settings[self.RELAY_SERVER_CONNECTION_STRING] = relay_connection_string
logger.info("Get relay connection string succeeded.")
except azure.mgmt.relay.models.ErrorResponseException as ex:
if ex.response.status_code == 404:
raise ResourceNotFoundError("Relay server not found.") from ex
raise AzureResponseError("Failed to get relay connection string.") from ex

if self.SERVICE_BUS_CONNECTION_STRING not in configuration_protected_settings:
try:
service_bus_connection_string, _ = _get_service_bus_connection_string(
cmd, subscription_id, resource_group_name, cluster_name, '', {}, True)
configuration_protected_settings[self.SERVICE_BUS_CONNECTION_STRING] = service_bus_connection_string
logger.info("Get service bus connection string succeeded.")
except azure.core.exceptions.HttpResponseError as ex:
if ex.response.status_code == 404:
raise ResourceNotFoundError("Service bus not found.") from ex
raise AzureResponseError("Failed to get service bus connection string.") from ex

configuration_protected_settings = _dereference(self.reference_mapping, configuration_protected_settings)

if self.sslKeyPemFile in configuration_protected_settings and \
self.sslCertPemFile in configuration_protected_settings:
logger.info(f"Both {self.sslKeyPemFile} and {self.sslCertPemFile} are set, update ssl key.")
self.__set_inference_ssl_from_file(configuration_protected_settings)

return PatchExtension(auto_upgrade_minor_version=auto_upgrade_minor_version,
release_train=release_train,
version=version,
configuration_settings=configuration_settings,
configuration_protected_settings=configuration_protected_settings)

def __normalize_config(self, configuration_settings, configuration_protected_settings):
# inference
isTestCluster = _get_value_from_config_protected_config(
self.inferenceLoadBalancerHA, configuration_settings, configuration_protected_settings)
if isTestCluster is not None:
isTestCluster = str(isTestCluster).lower() == 'false'
if isTestCluster:
configuration_settings['clusterPurpose'] = 'DevTest'
else:
configuration_settings['clusterPurpose'] = 'FastProd'

def __validate_config(self, configuration_settings, configuration_protected_settings):
# perform basic validation of the input config
config_keys = configuration_settings.keys()
Expand Down Expand Up @@ -241,24 +303,27 @@ def __validate_scoring_fe_settings(self, configuration_settings, configuration_p
logger.warning(
'Internal load balancer only supported on AKS and AKS Engine Clusters.')

def __set_inference_ssl_from_file(self, configuration_protected_settings):
import base64
feSslCertFile = configuration_protected_settings.get(self.sslCertPemFile)
feSslKeyFile = configuration_protected_settings.get(self.sslKeyPemFile)
with open(feSslCertFile) as f:
cert_data = f.read()
cert_data_bytes = cert_data.encode("ascii")
ssl_cert = base64.b64encode(cert_data_bytes).decode()
configuration_protected_settings['scoringFe.sslCert'] = ssl_cert
with open(feSslKeyFile) as f:
key_data = f.read()
key_data_bytes = key_data.encode("ascii")
ssl_key = base64.b64encode(key_data_bytes).decode()
configuration_protected_settings['scoringFe.sslKey'] = ssl_key

def __set_up_inference_ssl(self, configuration_settings, configuration_protected_settings):
allowInsecureConnections = _get_value_from_config_protected_config(
self.allowInsecureConnections, configuration_settings, configuration_protected_settings)
allowInsecureConnections = str(allowInsecureConnections).lower() == 'true'
if not allowInsecureConnections:
import base64
feSslCertFile = configuration_protected_settings.get(self.sslCertPemFile)
feSslKeyFile = configuration_protected_settings.get(self.sslKeyPemFile)
with open(feSslCertFile) as f:
cert_data = f.read()
cert_data_bytes = cert_data.encode("ascii")
ssl_cert = base64.b64encode(cert_data_bytes).decode()
configuration_protected_settings['scoringFe.sslCert'] = ssl_cert
with open(feSslKeyFile) as f:
key_data = f.read()
key_data_bytes = key_data.encode("ascii")
ssl_key = base64.b64encode(key_data_bytes).decode()
configuration_protected_settings['scoringFe.sslKey'] = ssl_key
self.__set_inference_ssl_from_file(configuration_protected_settings)
else:
logger.warning(
'SSL is not enabled. Allowing insecure connections to the deployed services.')
Expand Down Expand Up @@ -335,83 +400,82 @@ def _lock_resource(cmd, lock_scope, lock_level='CanNotDelete'):


def _get_relay_connection_str(
cmd, subscription_id, resource_group_name, cluster_name, cluster_location, auth_rule_name) -> Tuple[str, str, str]:
cmd, subscription_id, resource_group_name, cluster_name, cluster_location, auth_rule_name, get_key_only=False) -> Tuple[str, str, str]:
relay_client: azure.mgmt.relay.RelayManagementClient = get_mgmt_service_client(
cmd.cli_ctx, azure.mgmt.relay.RelayManagementClient)

cluster_id = '{}-{}-{}-relay'.format(cluster_name, subscription_id, resource_group_name)
# create namespace
relay_namespace_name = _get_valid_name(
cluster_id, suffix_len=6, max_len=50)
relay_namespace_params = azure.mgmt.relay.models.RelayNamespace(
location=cluster_location, tags=resource_tag)

async_poller = relay_client.namespaces.create_or_update(
resource_group_name, relay_namespace_name, relay_namespace_params)
while True:
async_poller.result(15)
if async_poller.done():
break

# create hybrid connection
hybrid_connection_name = cluster_name
hybrid_connection_object = relay_client.hybrid_connections.create_or_update(
resource_group_name, relay_namespace_name, hybrid_connection_name, requires_client_authorization=True)

# relay_namespace_ojbect = relay_client.namespaces.get(resource_group_name, relay_namespace_name)
# relay_namespace_resource_id = relay_namespace_ojbect.id
# _lock_resource(cmd, lock_scope=relay_namespace_resource_id)

# create authorization rule
auth_rule_rights = [azure.mgmt.relay.models.AccessRights.manage,
azure.mgmt.relay.models.AccessRights.send, azure.mgmt.relay.models.AccessRights.listen]
relay_client.hybrid_connections.create_or_update_authorization_rule(
resource_group_name, relay_namespace_name, hybrid_connection_name, auth_rule_name, rights=auth_rule_rights)
hc_resource_id = ''
if not get_key_only:
# create namespace
relay_namespace_params = azure.mgmt.relay.models.RelayNamespace(
location=cluster_location, tags=resource_tag)

async_poller = relay_client.namespaces.create_or_update(
resource_group_name, relay_namespace_name, relay_namespace_params)
while True:
async_poller.result(15)
if async_poller.done():
break

# create hybrid connection
hybrid_connection_object = relay_client.hybrid_connections.create_or_update(
resource_group_name, relay_namespace_name, hybrid_connection_name, requires_client_authorization=True)
hc_resource_id = hybrid_connection_object.id

# create authorization rule
auth_rule_rights = [azure.mgmt.relay.models.AccessRights.manage,
azure.mgmt.relay.models.AccessRights.send, azure.mgmt.relay.models.AccessRights.listen]
relay_client.hybrid_connections.create_or_update_authorization_rule(
resource_group_name, relay_namespace_name, hybrid_connection_name, auth_rule_name, rights=auth_rule_rights)

# get connection string
key: azure.mgmt.relay.models.AccessKeys = relay_client.hybrid_connections.list_keys(
resource_group_name, relay_namespace_name, hybrid_connection_name, auth_rule_name)
return f'{key.primary_connection_string}', hybrid_connection_object.id, hybrid_connection_name
return f'{key.primary_connection_string}', hc_resource_id, hybrid_connection_name


def _get_service_bus_connection_string(cmd, subscription_id, resource_group_name, cluster_name, cluster_location,
topic_sub_mapping: Dict[str, str]) -> Tuple[str, str]:
topic_sub_mapping: Dict[str, str], get_key_only=False) -> Tuple[str, str]:
service_bus_client: azure.mgmt.servicebus.ServiceBusManagementClient = get_mgmt_service_client(
cmd.cli_ctx, azure.mgmt.servicebus.ServiceBusManagementClient)
cluster_id = '{}-{}-{}-service-bus'.format(cluster_name,
subscription_id, resource_group_name)
service_bus_namespace_name = _get_valid_name(
cluster_id, suffix_len=6, max_len=50)

# create namespace
service_bus_sku = azure.mgmt.servicebus.models.SBSku(
name=azure.mgmt.servicebus.models.SkuName.standard.name)
service_bus_namespace = azure.mgmt.servicebus.models.SBNamespace(
location=cluster_location,
sku=service_bus_sku,
tags=resource_tag)
async_poller = service_bus_client.namespaces.begin_create_or_update(
resource_group_name, service_bus_namespace_name, service_bus_namespace)
while True:
async_poller.result(15)
if async_poller.done():
break

for topic_name, service_bus_subscription_name in topic_sub_mapping.items():
# create topic
topic = azure.mgmt.servicebus.models.SBTopic(max_size_in_megabytes=5120, default_message_time_to_live='P60D')
service_bus_client.topics.create_or_update(
resource_group_name, service_bus_namespace_name, topic_name, topic)

# create subscription
sub = azure.mgmt.servicebus.models.SBSubscription(
max_delivery_count=1, default_message_time_to_live='P14D', lock_duration='PT30S')
service_bus_client.subscriptions.create_or_update(
resource_group_name, service_bus_namespace_name, topic_name, service_bus_subscription_name, sub)
if not get_key_only:
# create namespace
service_bus_sku = azure.mgmt.servicebus.models.SBSku(
name=azure.mgmt.servicebus.models.SkuName.standard.name)
service_bus_namespace = azure.mgmt.servicebus.models.SBNamespace(
location=cluster_location,
sku=service_bus_sku,
tags=resource_tag)
async_poller = service_bus_client.namespaces.begin_create_or_update(
resource_group_name, service_bus_namespace_name, service_bus_namespace)
while True:
async_poller.result(15)
if async_poller.done():
break

for topic_name, service_bus_subscription_name in topic_sub_mapping.items():
# create topic
topic = azure.mgmt.servicebus.models.SBTopic(max_size_in_megabytes=5120, default_message_time_to_live='P60D')
service_bus_client.topics.create_or_update(
resource_group_name, service_bus_namespace_name, topic_name, topic)

# create subscription
sub = azure.mgmt.servicebus.models.SBSubscription(
max_delivery_count=1, default_message_time_to_live='P14D', lock_duration='PT30S')
service_bus_client.subscriptions.create_or_update(
resource_group_name, service_bus_namespace_name, topic_name, service_bus_subscription_name, sub)

service_bus_object = service_bus_client.namespaces.get(resource_group_name, service_bus_namespace_name)
service_bus_resource_id = service_bus_object.id
# _lock_resource(cmd, service_bus_resource_id)

# get connection string
auth_rules = service_bus_client.namespaces.list_authorization_rules(
Expand All @@ -423,26 +487,23 @@ def _get_service_bus_connection_string(cmd, subscription_id, resource_group_name


def _get_log_analytics_ws_connection_string(
cmd, subscription_id, resource_group_name, cluster_name, cluster_location) -> Tuple[str, str]:
cmd, subscription_id, resource_group_name, cluster_name, cluster_location, get_key_only=False) -> Tuple[str, str]:
log_analytics_ws_client: azure.mgmt.loganalytics.LogAnalyticsManagementClient = get_mgmt_service_client(
cmd.cli_ctx, azure.mgmt.loganalytics.LogAnalyticsManagementClient)

# create workspace
cluster_id = '{}-{}-{}'.format(cluster_name, subscription_id, resource_group_name)
log_analytics_ws_name = _get_valid_name(cluster_id, suffix_len=6, max_len=63)
log_analytics_ws = azure.mgmt.loganalytics.models.Workspace(location=cluster_location, tags=resource_tag)
async_poller = log_analytics_ws_client.workspaces.begin_create_or_update(
resource_group_name, log_analytics_ws_name, log_analytics_ws)
customer_id = ''
# log_analytics_ws_resource_id = ''
while True:
log_analytics_ws_object = async_poller.result(15)
if async_poller.done():
customer_id = log_analytics_ws_object.customer_id
# log_analytics_ws_resource_id = log_analytics_ws_object.id
break

# _lock_resource(cmd, log_analytics_ws_resource_id)
if not get_key_only:
log_analytics_ws = azure.mgmt.loganalytics.models.Workspace(location=cluster_location, tags=resource_tag)
async_poller = log_analytics_ws_client.workspaces.begin_create_or_update(
resource_group_name, log_analytics_ws_name, log_analytics_ws)
while True:
log_analytics_ws_object = async_poller.result(15)
if async_poller.done():
customer_id = log_analytics_ws_object.customer_id
break

# get workspace shared keys
shared_key = log_analytics_ws_client.shared_keys.get_shared_keys(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t
)
return extension, name, create_identity

def Update(self, auto_upgrade_minor_version, release_train, version, configuration_settings,
def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings,
configuration_protected_settings):
"""Default validations & defaults for Update
Must create and return a valid 'PatchExtension' object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def Create(self, cmd, client, resource_group_name: str, cluster_name: str, name:
pass

@abstractmethod
def Update(self, auto_upgrade_minor_version: bool, release_train: str, version: str,
def Update(self, cmd, resource_group_name: str, cluster_name: str, auto_upgrade_minor_version: bool, release_train: str, version: str,
configuration_settings: dict, configuration_protected_settings: dict) -> PatchExtension:
pass

Expand Down