From fc1780090683a3bd7d336d3250826f85362964b4 Mon Sep 17 00:00:00 2001 From: Roland Dobai Date: Tue, 19 Jan 2021 19:29:53 +0100 Subject: [PATCH] espsecure: Don't allow empty output and re-writing input --- espsecure.py | 97 ++++++++++++++++++++++++++++++++++++++---- test/test_espsecure.py | 19 +++++++++ 2 files changed, 108 insertions(+), 8 deletions(-) diff --git a/espsecure.py b/espsecure.py index 1240ec5af5..3e6056d295 100755 --- a/espsecure.py +++ b/espsecure.py @@ -19,6 +19,7 @@ import argparse import hashlib +import operator import os import struct import sys @@ -38,6 +39,13 @@ import esptool +try: + _string_type = basestring +except NameError: + # this has to be done with exception in order to avoid flake8 error + # Python 3 + _string_type = str + def get_chunks(source, chunk_len): """ Returns an iterator over 'chunk_len' chunks of 'source' """ @@ -81,6 +89,9 @@ def digest_secure_bootloader(args): """ Calculate the digest of a bootloader image, in the same way the hardware secure boot engine would do so. Can be used with a pre-loaded key to update a secure bootloader. """ + _check_output_is_not_input(args.keyfile, args.output) + _check_output_is_not_input(args.image, args.output) + _check_output_is_not_input(args.iv, args.output) if args.iv is not None: print("WARNING: --iv argument is for TESTING PURPOSES ONLY") iv = args.iv.read(128) @@ -216,6 +227,8 @@ def _get_sbv2_rsa_primitives(public_key): def sign_data(args): + _check_output_is_not_input(args.keyfile, args.output) + _check_output_is_not_input(args.datafile, args.output) if args.version == '1': return sign_secure_boot_v1(args) elif args.version == '2': @@ -448,6 +461,7 @@ def verify_signature_v2(args): def extract_public_key(args): + _check_output_is_not_input(args.keyfile, args.public_keyfile) if args.version == "1": """ Load an ECDSA private key and extract the embedded public key as raw binary data. """ sk = _load_ecdsa_signing_key(args.keyfile) @@ -519,6 +533,7 @@ def _digest_rsa_public_key(keyfile): def digest_rsa_public_key(args): + _check_output_is_not_input(args.keyfile, args.output) public_key_digest = _digest_rsa_public_key(args.keyfile) with open(args.output, "wb") as f: print("Writing the public key digest of %s to %s." % (args.keyfile.name, args.output)) @@ -526,6 +541,7 @@ def digest_rsa_public_key(args): def digest_private_key(args): + _check_output_is_not_input(args.keyfile, args.digest_file) sk = _load_ecdsa_signing_key(args.keyfile) repr(sk.to_string()) digest = hashlib.sha256() @@ -783,6 +799,8 @@ def _split_blocks(text, block_len=16): def decrypt_flash_data(args): + _check_output_is_not_input(args.keyfile, args.output) + _check_output_is_not_input(args.encrypted_file, args.output) if args.aes_xts: return _flash_encryption_operation_aes_xts(args.output, args.encrypted_file, args.address, args.keyfile, True) else: @@ -790,12 +808,69 @@ def decrypt_flash_data(args): def encrypt_flash_data(args): + _check_output_is_not_input(args.keyfile, args.output) + _check_output_is_not_input(args.plaintext_file, args.output) if args.aes_xts: return _flash_encryption_operation_aes_xts(args.output, args.plaintext_file, args.address, args.keyfile, False) else: return _flash_encryption_operation_esp32(args.output, args.plaintext_file, args.address, args.keyfile, args.flash_crypt_conf, False) +def _samefile(p1, p2): + try: + return os.path.samefile(p1, p2) + except (OSError, AttributeError): + # AttributeError - Python 2.7 on Windows doesn't know os.path.samefile() + # OSError (FileNotFoundError under Python 3) + return os.path.normcase(os.path.normpath(p1)) == os.path.normcase(os.path.normpath(p2)) + + +def _check_output_is_not_input(input_file, output_file): + i = getattr(input_file, 'name', input_file) + o = getattr(output_file, 'name', output_file) + # i & o should be string containing the path to files if espsecure was invoked from command line + # i & o still can be something else when espsecure was imported and the functions used directly (e.g. io.BytesIO()) + check_f = _samefile if isinstance(i, _string_type) and isinstance(o, _string_type) else operator.eq + if check_f(i, o): + raise esptool.FatalError('The input "{}" and output "{}" should not be the same!'.format(i, o)) + + +class OutFileType(object): + """ + This class is a replacement of argparse.FileType('wb'). It doesn't create a file immediately but only during the + first write. This allows us to do some checking before, e.g. that we are not overwriting the input. + + argparse.FileType('w')('-') returns STDOUT but argparse.FileType('wb') is not. + + The file object is not closed on failure just like in the case of argparse.FileType('w'). + """ + def __init__(self): + self.path = None + self.file_obj = None + + def __call__(self, path): + self.path = path + return self + + def __repr__(self): + return '{}({})'.format(type(self).__name__, self.path) + + def write(self, payload): + if len(payload) > 0: + if not self.file_obj: + self.file_obj = open(self.path, 'wb') + self.file_obj.write(payload) + + def close(self): + if self.file_obj: + self.file_obj.close() + self.file_obj = None + + @property + def name(self): + return self.path + + def main(custom_commandline=None): """ Main function for espsecure @@ -849,7 +924,7 @@ def main(custom_commandline=None): p.add_argument('--version', '-v', help="Version of the secure boot signing scheme to use.", choices=["1", "2"], default="1") p.add_argument('--keyfile', '-k', help="Private key file (PEM format) to extract the public verification key from.", type=argparse.FileType('rb'), required=True) - p.add_argument('public_keyfile', help="File to save new public key into", type=argparse.FileType('wb')) + p.add_argument('public_keyfile', help="File to save new public key into", type=OutFileType()) p = subparsers.add_parser('digest_rsa_public_key', help='Generate an SHA-256 digest of the public key. ' 'This digest is burned into the eFuse and asserts the legitimacy of the public key for Secure boot v2.') @@ -866,19 +941,19 @@ def main(custom_commandline=None): required=True) p.add_argument('--keylen', '-l', help="Length of private key digest file to generate (in bits). 3/4 Coding Scheme requires 192 bit key.", choices=[192, 256], default=256, type=int) - p.add_argument('digest_file', help="File to write 32 byte digest into", type=argparse.FileType('wb')) + p.add_argument('digest_file', help="File to write 32 byte digest into", type=OutFileType()) p = subparsers.add_parser('generate_flash_encryption_key', help='Generate a development-use 32 byte flash encryption key with random data.') p.add_argument('--keylen', '-l', help="Length of private key digest file to generate (in bits). 3/4 Coding Scheme requires 192 bit key.", choices=[192, 256], default=256, type=int) - p.add_argument('key_file', help="File to write 24 or 32 byte digest into", type=argparse.FileType('wb')) + p.add_argument('key_file', help="File to write 24 or 32 byte digest into", type=OutFileType()) p = subparsers.add_parser('decrypt_flash_data', help='Decrypt some data read from encrypted flash (using known key)') p.add_argument('encrypted_file', help="File with encrypted flash contents", type=argparse.FileType('rb')) p.add_argument('--aes_xts', '-x', help="Decrypt data using AES-XTS as used on ESP32-S2 and ESP32-C3", action='store_true') p.add_argument('--keyfile', '-k', help="File with flash encryption key", type=argparse.FileType('rb'), required=True) - p.add_argument('--output', '-o', help="Output file for plaintext data.", type=argparse.FileType('wb'), + p.add_argument('--output', '-o', help="Output file for plaintext data.", type=OutFileType(), required=True) p.add_argument('--address', '-a', help="Address offset in flash that file was read from.", required=True, type=esptool.arg_auto_int) p.add_argument('--flash_crypt_conf', help="Override FLASH_CRYPT_CONF efuse value (default is 0XF).", required=False, default=0xF, type=esptool.arg_auto_int) @@ -887,7 +962,7 @@ def main(custom_commandline=None): p.add_argument('--aes_xts', '-x', help="Encrypt data using AES-XTS as used on ESP32-S2 and ESP32-C3", action='store_true') p.add_argument('--keyfile', '-k', help="File with flash encryption key", type=argparse.FileType('rb'), required=True) - p.add_argument('--output', '-o', help="Output file for encrypted data.", type=argparse.FileType('wb'), + p.add_argument('--output', '-o', help="Output file for encrypted data.", type=OutFileType(), required=True) p.add_argument('--address', '-a', help="Address offset in flash where file will be flashed.", required=True, type=esptool.arg_auto_int) p.add_argument('--flash_crypt_conf', help="Override FLASH_CRYPT_CONF efuse value (default is 0XF).", required=False, default=0xF, type=esptool.arg_auto_int) @@ -899,9 +974,15 @@ def main(custom_commandline=None): parser.print_help() parser.exit(1) - # each 'operation' is a module-level function of the same name - operation_func = globals()[args.operation] - operation_func(args) + try: + # each 'operation' is a module-level function of the same name + operation_func = globals()[args.operation] + operation_func(args) + finally: + for arg_name in vars(args): + obj = getattr(args, arg_name) + if isinstance(obj, OutFileType): + obj.close() def _main(): diff --git a/test/test_espsecure.py b/test/test_espsecure.py index c1e2ccdeb4..1d8b53915b 100755 --- a/test/test_espsecure.py +++ b/test/test_espsecure.py @@ -477,6 +477,25 @@ def test_padding(self): self.assertEqual(ciphertext_full_block.getvalue(), ciphertext.getvalue()) +class DigestTests(EspSecureTestCase): + + def test_digest_private_key(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + self.addCleanup(os.remove, f.name) + outfile_name = f.name + + self.run_espsecure('digest_private_key --keyfile secure_images/ecdsa_secure_boot_signing_key.pem {}'.format(outfile_name)) + + with open(outfile_name, 'rb') as f: + self.assertEqual(f.read(), binascii.unhexlify('7b7b53708fc89d5e0b2df2571fb8f9d778f61a422ff1101a22159c4b34aad0aa')) + + def test_digest_private_key_with_invalid_output(self): + fname = 'secure_images/ecdsa_secure_boot_signing_key.pem' + + with self.assertRaises(subprocess.CalledProcessError): + self.run_espsecure('digest_private_key --keyfile {} {}'.format(fname, fname)) + + if __name__ == '__main__': print("Running espsecure tests...") print("Using espsecure %s at %s" % (esptool.__version__, os.path.abspath(espsecure.__file__)))