Skip to content

Commit

Permalink
[ACS] Update service principal creation so that it is subscription sp…
Browse files Browse the repository at this point in the history
…ecific. (#1630)

Add unit tests for load/store.
  • Loading branch information
brendandburns authored and derekbekoe committed Jan 6, 2017
1 parent 69d026a commit 6d6f73f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,9 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na
if raw=true
:raises: :class:`CloudError<msrestazure.azure_exceptions.CloudError>`
"""
subscription_id = _get_subscription_id()
if not dns_name_prefix:
# Use subscription id to provide uniqueness and prevent DNS name clashes
subscription_id = _get_subscription_id()
dns_name_prefix = '{}-{}-{}'.format(name, resource_group_name, subscription_id[0:6])

register_providers()
Expand All @@ -377,7 +377,7 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na
client = _graph_client_factory()
if not service_principal:
# --service-principal not specified, try to load it from local disk
principalObj = load_acs_service_principal()
principalObj = load_acs_service_principal(subscription_id)
if principalObj:
service_principal = principalObj.get('service_principal')
client_secret = principalObj.get('client_secret')
Expand All @@ -391,7 +391,7 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na

service_principal = _build_service_principal(client, name, url, client_secret)
logger.info('Created a service principal: %s', service_principal)
store_acs_service_principal(client_secret, service_principal)
store_acs_service_principal(subscription_id, client_secret, service_principal)
# Either way, update the role assignment, this fixes things if we fail part-way through
_add_role_assignment('Owner', service_principal)
else:
Expand All @@ -404,25 +404,35 @@ def acs_create(resource_group_name, deployment_name, name, ssh_key_value, dns_na
ops = get_mgmt_service_client(ACSClient).acs
return ops.create_or_update(resource_group_name, deployment_name, dns_name_prefix, name, ssh_key_value, content_version=content_version, admin_username=admin_username, agent_count=agent_count, agent_vm_size=agent_vm_size, location=location, master_count=master_count, orchestrator_type=orchestrator_type, tags=tags, custom_headers=custom_headers, raw=raw, operation_config=operation_config)

def store_acs_service_principal(client_secret, service_principal):
def store_acs_service_principal(subscription_id, client_secret, service_principal, config_path=os.path.join(get_config_dir(), 'acsServicePrincipal.json')):
obj = {}
if client_secret:
obj['client_secret'] = client_secret
if service_principal:
obj['service_principal'] = service_principal

configPath = os.path.join(get_config_dir(), 'acsServicePrincipal.json')
with os.fdopen(os.open(configPath, os.O_RDWR|os.O_CREAT|os.O_TRUNC, 0o600),
fullConfig = load_acs_service_principals(config_path=config_path)
if not fullConfig:
fullConfig = {}
fullConfig[subscription_id] = obj

with os.fdopen(os.open(config_path, os.O_RDWR|os.O_CREAT|os.O_TRUNC, 0o600),
'w+') as spFile:
json.dump(obj, spFile)
json.dump(fullConfig, spFile)

def load_acs_service_principal(subscription_id, config_path=os.path.join(get_config_dir(), 'acsServicePrincipal.json')):
config = load_acs_service_principals(config_path)
if not config:
return None
return config.get(subscription_id)

def load_acs_service_principal():
configPath = os.path.join(get_config_dir(), 'acsServicePrincipal.json')
if not os.path.exists(configPath):
def load_acs_service_principals(config_path):
if not os.path.exists(config_path):
return None
fd = os.open(configPath, os.O_RDONLY)
fd = os.open(config_path, os.O_RDONLY)
try:
return json.loads(os.fdopen(fd).read())
with os.fdopen(fd) as f:
return json.loads(f.read())
except: #pylint: disable=bare-except
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,73 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import unittest
import mock
import os
import tempfile
import unittest

from azure.cli.core._util import CLIError
from azure.cli.command_modules.acs.custom import _validate_service_principal
from azure.cli.command_modules.acs.custom import _validate_service_principal, load_acs_service_principal, store_acs_service_principal

class AcsServicePrincipalTest(unittest.TestCase):
def test_load_non_existent_service_principal(self):
principal = load_acs_service_principal('some-id', config_path='non-existent-file.json')
self.assertIsNone(principal)

def test_round_trip_one_subscription(self):
store_file = tempfile.NamedTemporaryFile(delete=False)
store_file.close()

service_principal = '12345'
sub_id = '67890'
client_secret = 'foobar'

store_acs_service_principal(
sub_id, client_secret, service_principal, config_path=store_file.name)
obj = load_acs_service_principal(sub_id, config_path=store_file.name)

self.assertIsNotNone(obj)
self.assertEqual(obj.get('service_principal'), service_principal)
self.assertEqual(obj.get('client_secret'), client_secret)

os.remove(store_file.name)

def test_round_trip_multi_subscription(self):
store_file = tempfile.NamedTemporaryFile(delete=False)
store_file.close()

principals = [
('12345', '67890', 'foobar'),
('abcde', 'fghij', 'foobaz'),
]

# Store them all
for principal in principals:
store_acs_service_principal(
principal[0], principal[1], principal[2], config_path=store_file.name)

# Make sure it worked
for principal in principals:
obj = load_acs_service_principal(principal[0], config_path=store_file.name)
self.assertIsNotNone(obj, 'expected non-None for {}'.format(principal[0]))
self.assertEqual(obj.get('service_principal'), principal[2])
self.assertEqual(obj.get('client_secret'), principal[1])

# Change one
new_principal = 'foo'
new_secret = 'bar'
store_acs_service_principal(
principals[0][0], new_secret, new_principal, config_path=store_file.name)
obj = load_acs_service_principal(principals[0][0], config_path=store_file.name)
self.assertIsNotNone(obj, 'expected non-None for {}'.format(principals[0][0]))
self.assertEqual(obj.get('service_principal'), new_principal)
self.assertEqual(obj.get('client_secret'), new_secret)


os.remove(store_file.name)



def test_validate_service_principal_ok(self):
client = mock.MagicMock()
client.service_principals = mock.Mock()
Expand Down

0 comments on commit 6d6f73f

Please sign in to comment.