diff --git a/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/custom.py b/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/custom.py index af1504e49d3..c21090f5bee 100644 --- a/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/custom.py +++ b/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/custom.py @@ -361,9 +361,9 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na if raw=true :raises: :class:`CloudError` """ + subscription_id = _get_subscription_id() if not dns_name_prefix: # Use subscription id to provide uniqueness and prevent DNS name clashes - subscription_id = _get_subscription_id() dns_name_prefix = '{}-{}-{}'.format(name, resource_group_name, subscription_id[0:6]) register_providers() @@ -377,7 +377,7 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na client = _graph_client_factory() if not service_principal: # --service-principal not specified, try to load it from local disk - principalObj = load_acs_service_principal() + principalObj = load_acs_service_principal(subscription_id) if principalObj: service_principal = principalObj.get('service_principal') client_secret = principalObj.get('client_secret') @@ -391,7 +391,7 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na service_principal = _build_service_principal(client, name, url, client_secret) logger.info('Created a service principal: %s', service_principal) - store_acs_service_principal(client_secret, service_principal) + store_acs_service_principal(subscription_id, client_secret, service_principal) # Either way, update the role assignment, this fixes things if we fail part-way through _add_role_assignment('Owner', service_principal) else: @@ -404,25 +404,35 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na ops = get_mgmt_service_client(ACSClient).acs return ops.create_or_update(resource_group_name, deployment_name, dns_name_prefix, name, ssh_key_value, content_version=content_version, admin_username=admin_username, agent_count=agent_count, agent_vm_size=agent_vm_size, location=location, master_count=master_count, orchestrator_type=orchestrator_type, tags=tags, custom_headers=custom_headers, raw=raw, operation_config=operation_config) -def store_acs_service_principal(client_secret, service_principal): +def store_acs_service_principal(subscription_id, client_secret, service_principal, config_path=os.path.join(get_config_dir(), 'acsServicePrincipal.json')): obj = {} if client_secret: obj['client_secret'] = client_secret if service_principal: obj['service_principal'] = service_principal - configPath = os.path.join(get_config_dir(), 'acsServicePrincipal.json') - with os.fdopen(os.open(configPath, os.O_RDWR|os.O_CREAT|os.O_TRUNC, 0o600), + fullConfig = load_acs_service_principals(config_path=config_path) + if not fullConfig: + fullConfig = {} + fullConfig[subscription_id] = obj + + with os.fdopen(os.open(config_path, os.O_RDWR|os.O_CREAT|os.O_TRUNC, 0o600), 'w+') as spFile: - json.dump(obj, spFile) + json.dump(fullConfig, spFile) + +def load_acs_service_principal(subscription_id, config_path=os.path.join(get_config_dir(), 'acsServicePrincipal.json')): + config = load_acs_service_principals(config_path) + if not config: + return None + return config.get(subscription_id) -def load_acs_service_principal(): - configPath = os.path.join(get_config_dir(), 'acsServicePrincipal.json') - if not os.path.exists(configPath): +def load_acs_service_principals(config_path): + if not os.path.exists(config_path): return None - fd = os.open(configPath, os.O_RDONLY) + fd = os.open(config_path, os.O_RDONLY) try: - return json.loads(os.fdopen(fd).read()) + with os.fdopen(fd) as f: + return json.loads(f.read()) except: #pylint: disable=bare-except return None diff --git a/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/tests/test_service_principals.py b/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/tests/test_service_principals.py index b5fc0b4b0d7..aaa623312ed 100644 --- a/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/tests/test_service_principals.py +++ b/src/command_modules/azure-cli-acs/azure/cli/command_modules/acs/tests/test_service_principals.py @@ -3,13 +3,73 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import unittest import mock +import os +import tempfile +import unittest from azure.cli.core._util import CLIError -from azure.cli.command_modules.acs.custom import _validate_service_principal +from azure.cli.command_modules.acs.custom import _validate_service_principal, load_acs_service_principal, store_acs_service_principal class AcsServicePrincipalTest(unittest.TestCase): + def test_load_non_existent_service_principal(self): + principal = load_acs_service_principal('some-id', config_path='non-existent-file.json') + self.assertIsNone(principal) + + def test_round_trip_one_subscription(self): + store_file = tempfile.NamedTemporaryFile(delete=False) + store_file.close() + + service_principal = '12345' + sub_id = '67890' + client_secret = 'foobar' + + store_acs_service_principal( + sub_id, client_secret, service_principal, config_path=store_file.name) + obj = load_acs_service_principal(sub_id, config_path=store_file.name) + + self.assertIsNotNone(obj) + self.assertEqual(obj.get('service_principal'), service_principal) + self.assertEqual(obj.get('client_secret'), client_secret) + + os.remove(store_file.name) + + def test_round_trip_multi_subscription(self): + store_file = tempfile.NamedTemporaryFile(delete=False) + store_file.close() + + principals = [ + ('12345', '67890', 'foobar'), + ('abcde', 'fghij', 'foobaz'), + ] + + # Store them all + for principal in principals: + store_acs_service_principal( + principal[0], principal[1], principal[2], config_path=store_file.name) + + # Make sure it worked + for principal in principals: + obj = load_acs_service_principal(principal[0], config_path=store_file.name) + self.assertIsNotNone(obj, 'expected non-None for {}'.format(principal[0])) + self.assertEqual(obj.get('service_principal'), principal[2]) + self.assertEqual(obj.get('client_secret'), principal[1]) + + # Change one + new_principal = 'foo' + new_secret = 'bar' + store_acs_service_principal( + principals[0][0], new_secret, new_principal, config_path=store_file.name) + obj = load_acs_service_principal(principals[0][0], config_path=store_file.name) + self.assertIsNotNone(obj, 'expected non-None for {}'.format(principals[0][0])) + self.assertEqual(obj.get('service_principal'), new_principal) + self.assertEqual(obj.get('client_secret'), new_secret) + + + os.remove(store_file.name) + + + def test_validate_service_principal_ok(self): client = mock.MagicMock() client.service_principals = mock.Mock()