-
Notifications
You must be signed in to change notification settings - Fork 0
/
fault.py
executable file
·184 lines (148 loc) · 4.79 KB
/
fault.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env python3
import binascii
import aes
import des
from aes_recover import recover_aes_key
from des_recover import recover_des_key, recover_initial_des_key
from utils import hex
import random
import sys
import struct
import argparse
def read_faults(fname, tr):
    """Read one faulted output per line from ``fname`` and convert each line.

    Args:
        fname: Path of a text file containing one faulted output per line.
        tr: Callable applied to every stripped, non-blank line, e.g.
            ``binascii.unhexlify`` or ``lambda x: int(x, 16)``.

    Returns:
        A list holding ``tr(line)`` for each non-blank line, in file order.
    """
    with open(fname, "r", encoding="utf-8") as f:
        # Strip both ends so leading spaces / CRLF endings never reach the hex
        # parsers, and skip blank lines: int("", 16) raises and
        # unhexlify("") would silently add an empty "fault".
        return [tr(line) for line in (raw.strip() for raw in f) if line]
def artificial_aes():
    """Self-test the AES fault attack against a locally simulated cipher.

    Encrypts a fixed plaintext under a fixed key, prints the round keys and
    the reference ciphertext, collects eight faulted ciphertexts from
    ``encrypt_block_with_fault``, recovers the final round key with
    ``recover_aes_key`` and asserts it matches the cipher's last round key.
    """
    plaintext = binascii.unhexlify("000102030405060708090a0b0c0d0e0f")
    key = binascii.unhexlify("bf05bd81f5497eef74dae9478eead746")
    print(f"Plaintext: {hex(plaintext)}")

    cipher = aes.AES(key)
    print("Round keys:")
    for round_no, round_key in enumerate(cipher._key_matrices):
        print(f"[{round_no:02}]: {hex(round_key)}")
    print("----")

    reference = cipher.encrypt_block(plaintext)
    print(f"ref: {hex(reference)}")

    # 8 faults are enough in this super simple case
    faulted_outputs = []
    for n in range(8):
        # Fault arguments are (n % 4, 0, n); their exact meaning is defined by
        # aes.AES.encrypt_block_with_fault — presumably position and value.
        out = cipher.encrypt_block_with_fault(plaintext, n % 4, 0, n)
        faulted_outputs.append(out)
        print(f"{hex(out)}")

    recovered = recover_aes_key(reference, faulted_outputs)
    print(f"Recovered round 10 key: {hex(recovered)}")
    assert hex(recovered) == hex(cipher._key_matrices[-1])
    print("SUCCESS")
def real_aes(ref, faults):
    """Recover and print the AES round 10 key from captured fault outputs.

    Args:
        ref: Non-faulted ciphertext (``main`` passes the bytes produced by
            ``binascii.unhexlify``).
        faults: Faulted ciphertexts for the same plaintext, in the same form.

    Returns:
        The key returned by ``recover_aes_key``.
    """
    k = recover_aes_key(ref, faults)
    print(f"Recovered round 10 key: {hex(k)}")
    return k
def artificial_des():
    """Self-test the DES fault attack against locally simulated encryptions.

    Runs 32 faulted encryptions of a fixed plaintext with the fault injected
    at round 15, recovers the last round key with ``recover_des_key`` and
    checks it against ``des.keySchedule``. It then reverses it to the original
    key with ``recover_initial_des_key`` and checks that as well.
    """
    plaintext = 0x0102030405060708
    key = 0x1C8529CEA240AE4F
    reference = des.DES(plaintext, key)

    fault_round = 15
    faulted = []
    # idx is the fourth des.DES argument (presumably the faulted bit
    # position — TODO confirm against des.DES).
    for idx in range(32):
        print(f"Faulting at round {fault_round}")
        faulted.append(des.DES(plaintext, key, fault_round, idx))

    last_round_key = recover_des_key(reference, faulted)
    subkeys = des.keySchedule(key)
    print(f"Recovered last round key: {hex(last_round_key)}")
    assert subkeys[-1] == last_round_key

    initial_key = recover_initial_des_key(last_round_key, plaintext, reference)
    print(f"Real key: {hex(initial_key)}")
    assert initial_key == key
def real_des(plain, ref, faults):
    """Recover the last DES round key from fault outputs, then the initial key.

    Args:
        plain: Plaintext (int) that produced ``ref``, or ``None`` when unknown.
            Without it only the last round key can be recovered.
        ref: Non-faulted ciphertext as an int.
        faults: Faulted ciphertexts (ints) for the same plaintext.

    Returns:
        The initial key from ``recover_initial_des_key``, or ``None`` when no
        plaintext was supplied or no matching key was found.
    """
    k = recover_des_key(ref, faults)
    print(f"Recovered round key: {hex(k)}")
    # Compare with None explicitly: an all-zero plaintext (0) is valid input
    # but falsy, and must not be mistaken for "no plaintext supplied".
    if plain is None:
        print("No plaintext supplied, can't recover the initial key (--plain)")
        return None
    real_k = recover_initial_des_key(k, plain, ref)
    if real_k is None:
        print("Couldn't find the initial key")
    else:
        print(f"Real key: {hex(real_k)}")
    return real_k
def _build_parser():
    """Return the argument parser for the fault-analysis command line."""
    parser = argparse.ArgumentParser(description="AES and DES fault analysis")
    parser.add_argument(
        "--aes", help="AES fault injection", default=False, action="store_true"
    )
    parser.add_argument(
        "--des", help="DES fault injection", default=False, action="store_true"
    )
    parser.add_argument(
        "--faults",
        help="File with one faulted output per line (hex)",
        type=str,
    )
    parser.add_argument(
        "--ref",
        help="Non faulted output for the plaintext (hex)",
        required=True,
        type=str,
    )
    parser.add_argument(
        "--plain",
        help="Plaintext that produces the reference output (hex)",
        type=str,
    )
    parser.add_argument(
        "--reverse",
        help="Reverse the key schedule from a final round key",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--final-key", help="The final round key for --reverse", type=str
    )
    return parser


def _validate_args(parser, args):
    """Reject invalid flag combinations via ``parser.error`` (stderr, exit 2)."""
    if not args.des and not args.aes:
        parser.error("--des or --aes are required.")
    if args.des and args.aes:
        parser.error("Can't combine --aes and --des.")
    if args.aes and args.reverse:
        parser.error("AES key schedule reverse not supported")
    if not args.reverse and not args.faults:
        parser.error("--faults is required when not using --reverse")
    if args.reverse:
        if not args.final_key:
            parser.error("--reverse needs --final-key")
        if args.des and not args.plain:
            parser.error("DES --reverse needs --plain")


def main():
    """Command-line entry point: parse arguments and run the requested mode.

    Modes:
      * ``--des --reverse --final-key K --plain P --ref R``: reverse the DES
        key schedule from the last round key ``K``.
      * ``--aes --faults FILE --ref R``: recover the AES round 10 key.
      * ``--des --faults FILE --ref R [--plain P]``: recover the last DES
        round key and, when ``--plain`` is given, the initial key.

    Invalid flag combinations, malformed hex values and unreadable fault
    files are reported through ``argparse``: the message goes to stderr and
    the process exits with status 2.
    """
    parser = _build_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    # Parse every user-supplied value up front so bad input becomes a clean
    # usage error instead of a traceback. binascii.Error subclasses
    # ValueError; OSError covers a missing or unreadable --faults file.
    try:
        if args.reverse:
            # Only DES gets here: _validate_args rejects --aes --reverse.
            plain = int(args.plain, 16)
            final_key = int(args.final_key, 16)
            ref = int(args.ref, 16)
        elif args.aes:
            faults = read_faults(args.faults, binascii.unhexlify)
            ref = binascii.unhexlify(args.ref)
        else:
            faults = read_faults(args.faults, lambda x: int(x, 16))
            ref = int(args.ref, 16)
            plain = int(args.plain, 16) if args.plain else None
    except (ValueError, OSError) as err:
        parser.error(f"invalid input: {err}")

    if args.reverse:
        init_k = recover_initial_des_key(final_key, plain, ref)
        if init_k is not None:
            print(f"Round 0 DES key: {hex(init_k)}")
        else:
            print("Couldn't find a matching DES round 0 key")
    elif args.aes:
        real_aes(ref, faults)
    else:
        real_des(plain, ref, faults)
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()