diff --git a/requirements.txt b/requirements.txt index e0882c4adbc..d9af5a423a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ argcomplete==1.3.0 colorama==0.3.7 jmespath mock==1.3.0 +paramiko==2.0.2 pip pygments==2.1.3 pylint==1.5.4 diff --git a/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_actions.py b/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_actions.py index 86688f7a79c..78fb78969ee 100644 --- a/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_actions.py +++ b/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_actions.py @@ -12,12 +12,15 @@ from azure.cli.core.application import APPLICATION from azure.cli.core.commands.parameters import get_one_of_subscription_locations from azure.cli.core.commands.arm import resource_exists +import azure.cli.core._logging as _logging from six.moves.urllib.request import urlopen #pylint: disable=import-error from ._factory import _compute_client_factory from ._vm_utils import read_content_if_is_file +logger = _logging.get_az_logger(__name__) + class VMImageFieldAction(argparse.Action): #pylint: disable=too-few-public-methods def __call__(self, parser, namespace, values, option_string=None): image = values @@ -251,7 +254,61 @@ def _handle_container_ssh_file(**kwargs): return args = kwargs['args'] - - args.ssh_key_value = read_content_if_is_file(args.ssh_key_value) + string_or_file = args.ssh_key_value + content = string_or_file + if os.path.exists(string_or_file): + logger.info('Use existing SSH public key file: %s', string_or_file) + with open(string_or_file, 'r') as f: + content = f.read() + elif not _is_valid_ssh_rsa_public_key(content) and args.generate_ssh_keys: + #figure out appropriate file names: + #'base_name'(with private keys), and 'base_name.pub'(with public keys) + public_key_filepath = string_or_file + if public_key_filepath[-4:].lower() == '.pub': + private_key_filepath = public_key_filepath[:-4] + else: + private_key_filepath = public_key_filepath + '.private' + logger.warning('Creating SSH key files: %s,%s', private_key_filepath, public_key_filepath) + content = _generate_ssh_keys(private_key_filepath, public_key_filepath) + args.ssh_key_value = content + +def _generate_ssh_keys(private_key_filepath, public_key_filepath): + import paramiko + + ssh_dir, _ = os.path.split(private_key_filepath) + if not os.path.exists(ssh_dir): + os.makedirs(ssh_dir) + os.chmod(ssh_dir, 0o700) + + key = paramiko.RSAKey.generate(2048) + key.write_private_key_file(private_key_filepath) + os.chmod(private_key_filepath, 0o600) + + with open(public_key_filepath, 'w') as public_key_file: + public_key = '%s %s' % (key.get_name(), key.get_base64()) + public_key_file.write(public_key) + os.chmod(public_key_filepath, 0o644) + + return public_key + +def _is_valid_ssh_rsa_public_key(openssh_pubkey): + #http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-validation-using-a-regular-expression #pylint: disable=line-too-long + #A "good enough" check is to see if the key starts with the correct header. + import struct + try: + from base64 import decodebytes as base64_decode + except ImportError: + #deprecated and redirected to decodebytes in Python 3 + from base64 import decodestring as base64_decode + parts = openssh_pubkey.split() + if len(parts) < 2: + return False + key_type = parts[0] + key_string = parts[1] + + data = base64_decode(key_string.encode())#pylint:disable=deprecated-method + int_len = 4 + str_len = struct.unpack('>I', data[:int_len])[0] # this should return 7 + return data[int_len:int_len+str_len] == key_type.encode() APPLICATION.register(APPLICATION.COMMAND_PARSER_PARSED, _handle_container_ssh_file) diff --git a/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_params.py b/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_params.py index e58f3b94deb..c81422a3128 100644 --- a/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_params.py +++ b/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/_params.py @@ -73,6 +73,7 @@ def get_vm_size_completion_list(prefix, action, parsed_args, **kwargs):#pylint: register_cli_argument('acs', 'orchestrator_type', **enum_choice_list(ContainerServiceOchestratorTypes)) register_cli_argument('acs', 'admin_username', admin_username_type) register_cli_argument('acs', 'ssh_key_value', required=False, help='SSH key file value or key file path.', default=os.path.join(os.path.expanduser('~'), '.ssh', 'id_rsa.pub'), completer=FilesCompleter()) +register_extra_cli_argument('acs create', 'generate_ssh_keys', action='store_true', help='Generate SSH public and private key files if missing') register_cli_argument('acs', 'container_service_name', options_list=('--name', '-n'), help='The name of the container service', completer=get_resource_name_completion_list('Microsoft.ContainerService/ContainerServices')) register_cli_argument('acs create', 'agent_vm_size', completer=get_vm_size_completion_list) register_cli_argument('acs update', 'agent_count', type=int, help='The number of agents for the cluster') diff --git a/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/tests/test_vm_actions.py b/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/tests/test_vm_actions.py new file mode 100644 index 00000000000..32273a2c103 --- /dev/null +++ b/src/command_modules/azure-cli-vm/azure/cli/command_modules/vm/tests/test_vm_actions.py @@ -0,0 +1,58 @@ +#--------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +#--------------------------------------------------------------------------------------------- + +import os +import tempfile +import unittest +import mock + +from azure.cli.command_modules.vm._actions import (_handle_container_ssh_file, + _is_valid_ssh_rsa_public_key) + +class TestAcsActions(unittest.TestCase): + def test_generate_specfied_ssh_key_files(self): + _, private_key_file = tempfile.mkstemp() + public_key_file = private_key_file + '.pub' + args = mock.MagicMock() + args.ssh_key_value = public_key_file + args.generate_ssh_keys = True + + #1 verify we generate key files if not existing + _handle_container_ssh_file(command='acs create', args=args) + + generated_public_key_string = args.ssh_key_value + self.assertTrue(bool(args.ssh_key_value)) + self.assertTrue(_is_valid_ssh_rsa_public_key(generated_public_key_string)) + self.assertTrue(os.path.isfile(private_key_file)) + + #2 verify we load existing key files + # for convinience we will reuse the generated file in the previous step + args2 = mock.MagicMock() + args2.ssh_key_value = generated_public_key_string + args2.generate_ssh_keys = False + _handle_container_ssh_file(command='acs create', args=args2) + #we didn't regenerate + self.assertEqual(generated_public_key_string, args.ssh_key_value) + + #3 verify we do not generate unless told so + _, private_key_file2 = tempfile.mkstemp() + public_key_file2 = private_key_file2 + '.pub' + args3 = mock.MagicMock() + args3.ssh_key_value = public_key_file2 + args3.generate_ssh_keys = False + _handle_container_ssh_file(command='acs create', args=args3) + #still a file name + self.assertEqual(args3.ssh_key_value, public_key_file2) + + #4 verify file naming if the pub file doesn't end with .pub + _, public_key_file4 = tempfile.mkstemp() + public_key_file4 += '1' #make it nonexisting + args4 = mock.MagicMock() + args4.ssh_key_value = public_key_file4 + args4.generate_ssh_keys = True + _handle_container_ssh_file(command='acs create', args=args4) + self.assertTrue(os.path.isfile(public_key_file4 + '.private')) + self.assertTrue(os.path.isfile(public_key_file4)) + diff --git a/src/command_modules/azure-cli-vm/setup.py b/src/command_modules/azure-cli-vm/setup.py index 0f5fd802015..ee89df22399 100644 --- a/src/command_modules/azure-cli-vm/setup.py +++ b/src/command_modules/azure-cli-vm/setup.py @@ -29,6 +29,7 @@ 'azure-mgmt-resource==0.30.0rc6', 'azure-storage==0.33.0', 'azure-cli-core', + 'paramiko' ] with open('README.rst', 'r', encoding='utf-8') as f: