Skip to content

Commit

Permalink
ACS: Generate SSH keys if required (Azure#1106)
Browse files Browse the repository at this point in the history
Generate SSH key pair if the key file is missing and users want to generate
  • Loading branch information
yugangw-msft authored Oct 22, 2016
1 parent a07d03f commit 02e7b7c
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ argcomplete==1.3.0
colorama==0.3.7
jmespath
mock==1.3.0
paramiko==2.0.2
pip
pygments==2.1.3
pylint==1.5.4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
from azure.cli.core.application import APPLICATION
from azure.cli.core.commands.parameters import get_one_of_subscription_locations
from azure.cli.core.commands.arm import resource_exists
import azure.cli.core._logging as _logging

from six.moves.urllib.request import urlopen #pylint: disable=import-error

from ._factory import _compute_client_factory
from ._vm_utils import read_content_if_is_file

logger = _logging.get_az_logger(__name__)

class VMImageFieldAction(argparse.Action): #pylint: disable=too-few-public-methods
def __call__(self, parser, namespace, values, option_string=None):
image = values
Expand Down Expand Up @@ -251,7 +254,61 @@ def _handle_container_ssh_file(**kwargs):
return

args = kwargs['args']

args.ssh_key_value = read_content_if_is_file(args.ssh_key_value)
string_or_file = args.ssh_key_value
content = string_or_file
if os.path.exists(string_or_file):
logger.info('Use existing SSH public key file: %s', string_or_file)
with open(string_or_file, 'r') as f:
content = f.read()
elif not _is_valid_ssh_rsa_public_key(content) and args.generate_ssh_keys:
#figure out appropriate file names:
#'base_name'(with private keys), and 'base_name.pub'(with public keys)
public_key_filepath = string_or_file
if public_key_filepath[-4:].lower() == '.pub':
private_key_filepath = public_key_filepath[:-4]
else:
private_key_filepath = public_key_filepath + '.private'
logger.warning('Creating SSH key files: %s,%s', private_key_filepath, public_key_filepath)
content = _generate_ssh_keys(private_key_filepath, public_key_filepath)
args.ssh_key_value = content

def _generate_ssh_keys(private_key_filepath, public_key_filepath):
import paramiko

ssh_dir, _ = os.path.split(private_key_filepath)
if not os.path.exists(ssh_dir):
os.makedirs(ssh_dir)
os.chmod(ssh_dir, 0o700)

key = paramiko.RSAKey.generate(2048)
key.write_private_key_file(private_key_filepath)
os.chmod(private_key_filepath, 0o600)

with open(public_key_filepath, 'w') as public_key_file:
public_key = '%s %s' % (key.get_name(), key.get_base64())
public_key_file.write(public_key)
os.chmod(public_key_filepath, 0o644)

return public_key

def _is_valid_ssh_rsa_public_key(openssh_pubkey):
#http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-validation-using-a-regular-expression #pylint: disable=line-too-long
#A "good enough" check is to see if the key starts with the correct header.
import struct
try:
from base64 import decodebytes as base64_decode
except ImportError:
#deprecated and redirected to decodebytes in Python 3
from base64 import decodestring as base64_decode
parts = openssh_pubkey.split()
if len(parts) < 2:
return False
key_type = parts[0]
key_string = parts[1]

data = base64_decode(key_string.encode())#pylint:disable=deprecated-method
int_len = 4
str_len = struct.unpack('>I', data[:int_len])[0] # this should return 7
return data[int_len:int_len+str_len] == key_type.encode()

APPLICATION.register(APPLICATION.COMMAND_PARSER_PARSED, _handle_container_ssh_file)
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get_vm_size_completion_list(prefix, action, parsed_args, **kwargs):#pylint:
register_cli_argument('acs', 'orchestrator_type', **enum_choice_list(ContainerServiceOchestratorTypes))
register_cli_argument('acs', 'admin_username', admin_username_type)
register_cli_argument('acs', 'ssh_key_value', required=False, help='SSH key file value or key file path.', default=os.path.join(os.path.expanduser('~'), '.ssh', 'id_rsa.pub'), completer=FilesCompleter())
register_extra_cli_argument('acs create', 'generate_ssh_keys', action='store_true', help='Generate SSH public and private key files if missing')
register_cli_argument('acs', 'container_service_name', options_list=('--name', '-n'), help='The name of the container service', completer=get_resource_name_completion_list('Microsoft.ContainerService/ContainerServices'))
register_cli_argument('acs create', 'agent_vm_size', completer=get_vm_size_completion_list)
register_cli_argument('acs update', 'agent_count', type=int, help='The number of agents for the cluster')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#---------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#---------------------------------------------------------------------------------------------

import os
import tempfile
import unittest
import mock

from azure.cli.command_modules.vm._actions import (_handle_container_ssh_file,
_is_valid_ssh_rsa_public_key)

class TestAcsActions(unittest.TestCase):
def test_generate_specfied_ssh_key_files(self):
_, private_key_file = tempfile.mkstemp()
public_key_file = private_key_file + '.pub'
args = mock.MagicMock()
args.ssh_key_value = public_key_file
args.generate_ssh_keys = True

#1 verify we generate key files if not existing
_handle_container_ssh_file(command='acs create', args=args)

generated_public_key_string = args.ssh_key_value
self.assertTrue(bool(args.ssh_key_value))
self.assertTrue(_is_valid_ssh_rsa_public_key(generated_public_key_string))
self.assertTrue(os.path.isfile(private_key_file))

#2 verify we load existing key files
# for convinience we will reuse the generated file in the previous step
args2 = mock.MagicMock()
args2.ssh_key_value = generated_public_key_string
args2.generate_ssh_keys = False
_handle_container_ssh_file(command='acs create', args=args2)
#we didn't regenerate
self.assertEqual(generated_public_key_string, args.ssh_key_value)

#3 verify we do not generate unless told so
_, private_key_file2 = tempfile.mkstemp()
public_key_file2 = private_key_file2 + '.pub'
args3 = mock.MagicMock()
args3.ssh_key_value = public_key_file2
args3.generate_ssh_keys = False
_handle_container_ssh_file(command='acs create', args=args3)
#still a file name
self.assertEqual(args3.ssh_key_value, public_key_file2)

#4 verify file naming if the pub file doesn't end with .pub
_, public_key_file4 = tempfile.mkstemp()
public_key_file4 += '1' #make it nonexisting
args4 = mock.MagicMock()
args4.ssh_key_value = public_key_file4
args4.generate_ssh_keys = True
_handle_container_ssh_file(command='acs create', args=args4)
self.assertTrue(os.path.isfile(public_key_file4 + '.private'))
self.assertTrue(os.path.isfile(public_key_file4))

1 change: 1 addition & 0 deletions src/command_modules/azure-cli-vm/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'azure-mgmt-resource==0.30.0rc6',
'azure-storage==0.33.0',
'azure-cli-core',
'paramiko'
]

with open('README.rst', 'r', encoding='utf-8') as f:
Expand Down

0 comments on commit 02e7b7c

Please sign in to comment.