forked from Netnod/nts-poc-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ntske_server.py
executable file
·481 lines (383 loc) · 14.8 KB
/
ntske_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
#! /usr/bin/python3
from __future__ import division, print_function, unicode_literals
import os
import sys
import socket
import traceback
import binascii
import struct
import syslog
import signal
from socketserver import ThreadingTCPServer, TCPServer, BaseRequestHandler
from pooling import ThreadPoolTCPServer
from sslwrapper import SSLWrapper
from constants import *
from ntske_record import *
from nts import NTSCookie
from server_helper import ServerHelper
from util import hexlify
from threading import Timer
import logging
# Send log output to stderr: the root logger lets every level from DEBUG up
# through, while the stderr handler itself only emits INFO and above.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
root = logging.getLogger()
root.setLevel(logging.DEBUG)
root.addHandler(handler)

# This module targets Python 3 only.
assert sys.version_info[0] == 3

# Debug verbosity; at 2 or higher the handler prints key material.
DEBUG = 0
# When true, repeated NPN/AEAD negotiation records only produce a notice
# instead of a bad-request error.
ALLOW_MULTIPLE = 0

# Protocol IDs, see the IANA Network Time Security Next Protocols registry
SUPPORTED_PROTOCOLS = set((
    0,  # NTPv4
))

# Algorithm identifiers, see RFC5297
SUPPORTED_ALGORITHMS = set((
    15,  # AEAD_AES_SIV_CMAC_256
))
def unpack_array(buf):
    """Decode a byte string of big-endian unsigned 16-bit integers.

    Returns a tuple of ints (empty for an empty buffer).

    Raises ValueError if len(buf) is odd.  (An explicit check rather than
    `assert`, so it still fires when Python runs with -O.)
    """
    if len(buf) % 2:
        raise ValueError("buffer length %u is not a multiple of 2" % len(buf))
    # Integer division: a struct repeat count must be an integer, not a float.
    return struct.unpack('>%uH' % (len(buf) // 2), buf)
def pack_array(a):
    """Encode a sequence of ints as big-endian unsigned 16-bit values.

    An empty sequence yields b''.  Raises struct.error if any value does not
    fit in 16 unsigned bits.
    """
    # '>0H' packs to b'' and '>1H' equals '>H', so one call covers every
    # length.  The previous multi-element branch referenced an undefined
    # name (`fmt`) and passed the sequence itself instead of its items.
    return struct.pack('>%uH' % len(a), *a)
def flatten(a):
    """Return a new list holding the items of each inner iterable, in order."""
    merged = []
    for inner in a:
        merged.extend(inner)
    return merged
class NTSKEHandler(BaseRequestHandler):
    """Handle one NTS-KE client connection.

    The handler accepts TLS on the request socket, reads request records until
    an End of Message record, validates them, and sends back negotiation
    records plus, when no error was recorded, new cookies.  A one-line summary
    of each connection is printed and, if the server has syslog enabled, also
    logged there.
    """

    class PrematureEOF(IOError):
        """Raised by recv_all() when the peer closes before enough data arrives.

        Subclasses IOError, so callers that caught the previous plain IOError
        keep working.
        """

    def handle(self):
        """Run the exchange via handle2() and emit a summary line.

        Any status other than 'success' or one starting with 'invalid' is
        normalized to 'invalid'.  If handle2() raises, the status is recorded
        as 'exception' and the exception propagates.
        """
        self.info = {
            'site' : self.server.syslog,
            'client_addr' : self.client_address[0],
            'client_port' : self.client_address[1],
        }
        try:
            status = self.handle2()
            if not isinstance(status, str) or (
                    status != 'success' and not status.startswith('invalid')):
                status = 'invalid'
        except BaseException:
            status = 'exception'
            raise
        finally:
            self.info['status'] = status
            info = ' '.join('%s=%s' % (k, v)
                            for k, v in sorted(self.info.items()))
            print(info)
            if self.server.syslog:
                syslog.syslog(syslog.LOG_INFO | syslog.LOG_USER, info)

    def recv_all(self, s, count):
        """Read exactly `count` bytes from `s`.

        Raises PrematureEOF (an IOError) if the peer closes first.
        """
        buf = bytes()
        while len(buf) < count:
            data = s.recv(count - len(buf))
            if not data:
                raise self.PrematureEOF("short recv")
            buf += data
        return buf

    def _recv_record(self, s):
        """Read one raw record: the 4-byte header plus its declared body."""
        header = self.recv_all(s, 4)
        # Bytes 2-3 of the header hold the big-endian body length.
        body_len = struct.unpack(">H", header[2:4])[0]
        if body_len > 0:
            return header + self.recv_all(s, body_len)
        return header

    def handle2(self):
        """Perform the TLS handshake and the NTS-KE request/response exchange.

        Returns 'success' after the response was sent, otherwise an 'invalid_*'
        status string naming the reason the exchange was abandoned.
        """
        self.keyid, self.key = self.server.helper.get_server_key()
        s = self.server.wrapper.accept(self.request)
        if not s:
            return 'invalid_tls_failure'
        if not self.server.helper.allow_any_alpn:
            alpn_protocol = s.selected_alpn_protocol()
            if alpn_protocol not in [NTS_ALPN_PROTO]:
                return 'invalid_alpn_protocol'
        self.info.update(s.info())
        if DEBUG >= 2:
            print("keyid = unhexlify('''%s''')" % hexlify(self.keyid))
            print("server_key = unhexlify('''%s''')" % hexlify(self.key))
        self.npn_protocols = []
        self.aead_algorithms = []
        self.eom_received = False
        self.errors = set()
        self.warnings = set()
        while True:
            try:
                resp = self._recv_record(s)
            except self.PrematureEOF:
                # The client closed the connection before finishing its
                # request; report it as an invalid request, not a crash.
                print("Premature end of client request", file = sys.stderr)
                return 'invalid_premature_eof'
            record = Record(resp)
            self.process_record(record)
            if record.rec_type == RT_END_OF_MESSAGE:
                break
        c2s_key = s.export_keying_material(self.server.key_label, NTS_TLS_Key_LEN, NTS_TLS_Key_C2S)
        s2c_key = s.export_keying_material(self.server.key_label, NTS_TLS_Key_LEN, NTS_TLS_Key_S2C)
        response = self.get_response(c2s_key, s2c_key)
        s.sendall(b''.join(map(bytes, response)))
        s.shutdown()
        return 'success'

    def error(self, code, message):
        """Print an error and remember `code` for an Error record in the reply."""
        print("error %u: %s" % (code, message), file = sys.stderr)
        self.errors.add(code)

    def warning(self, code, message):
        """Print a warning and remember `code` for a Warning record in the reply."""
        print("warning %u: %s" % (code, message), file = sys.stderr)
        self.warnings.add(code)

    def notice(self, message):
        """Print an informational message; it does not affect the reply."""
        print(message, file = sys.stderr)

    def process_record(self, record):
        """Validate one request record and accumulate its negotiation data.

        Problems are recorded via error()/notice() rather than raised.
        """
        if DEBUG >= 2:
            print(record.critical, record.rec_type, record.body)
        if self.eom_received:
            self.error(ERR_BAD_REQUEST, "Records received after EOM")
            return
        if record.rec_type == RT_END_OF_MESSAGE:
            if not record.critical:
                self.error(ERR_BAD_REQUEST,
                           "EOM record MUST be criticial")
                return
            if len(record.body):
                self.error(ERR_BAD_REQUEST,
                           "EOM record should have zero length body")
                return
            self.eom_received = True
        elif record.rec_type == RT_NEXT_PROTO_NEG:
            if self.npn_protocols:
                if ALLOW_MULTIPLE:
                    self.notice("Multiple NPN records")
                else:
                    self.error(ERR_BAD_REQUEST, "Multiple NPN record")
                    return
            if not record.critical:
                self.error(ERR_BAD_REQUEST, "NPN record MUST be criticial")
                return
            if len(record.body) % 2:
                self.error(ERR_BAD_REQUEST,
                           "NPN record has invalid length")
                return
            if not len(record.body):
                self.error(ERR_BAD_REQUEST,
                           "NPN record MUST specify at least one protocol")
            # An empty record is still stored (as an empty tuple), so
            # get_response() can tell it apart from "no NPN record at all".
            self.npn_protocols.append(unpack_array(record.body))
        elif record.rec_type == RT_AEAD_NEG:
            if self.aead_algorithms:
                if ALLOW_MULTIPLE:
                    self.notice("Multiple AEAD records")
                else:
                    self.error(ERR_BAD_REQUEST, "Multiple AEAD records")
                    return
            if len(record.body) % 2:
                self.error(ERR_BAD_REQUEST,
                           "AEAD record has invalid length")
                return
            if not len(record.body):
                self.error(ERR_BAD_REQUEST,
                           "AEAD record MUST specify at least one algorithm")
            self.aead_algorithms.append(unpack_array(record.body))
        elif record.rec_type == RT_ERROR:
            self.error(ERR_BAD_REQUEST, "Received error record")
        elif record.rec_type == RT_WARNING:
            self.error(ERR_BAD_REQUEST, "Received warning record")
        elif record.rec_type == RT_NEW_COOKIE:
            self.error(ERR_BAD_REQUEST, "Received new cookie record")
        else:
            if record.critical:
                self.error(ERR_UNREC_CRIT, "Received unknown critical record %u" % (
                    record.rec_type))
            else:
                self.notice("Received unknown record %u" % (record.rec_type))

    def get_response(self, c2s_key, s2c_key):
        """Build the list of response records for the processed request.

        c2s_key/s2c_key are the keys exported from the TLS session in
        handle2(); they are passed, with the current server key, to
        NTSCookie().pack() for each of the 8 cookies.  When any error was
        recorded, the reply holds the Error/Warning records, the (possibly
        empty) negotiation records and End of Message, but no cookies.
        """
        protocols = []
        if not self.npn_protocols:
            self.error(ERR_BAD_REQUEST, "No NPN record received")
        elif not flatten(self.npn_protocols):
            pass  # empty NPN record; already reported by process_record()
        else:
            for protocol in flatten(self.npn_protocols):
                if protocol in SUPPORTED_PROTOCOLS:
                    protocols.append(protocol)
                else:
                    self.notice("Unknown NPN %u" % protocol)
            if not protocols:
                self.error(ERR_BAD_REQUEST, "No supported NPN received")
        algorithms = []
        if not self.aead_algorithms:
            self.error(ERR_BAD_REQUEST, "No AEAD record received")
        elif not flatten(self.aead_algorithms):
            pass  # empty AEAD record; already reported by process_record()
        else:
            for algorithm in flatten(self.aead_algorithms):
                if algorithm in SUPPORTED_ALGORITHMS:
                    algorithms.append(algorithm)
                else:
                    self.notice("Unknown AEAD algorithm %u" % algorithm)
            if not algorithms:
                self.error(ERR_BAD_REQUEST, "No supported AEAD algorithms received")
        if not self.eom_received:
            self.error(ERR_BAD_REQUEST, "No EOM record received")
        records = []
        for code in sorted(self.errors):
            records.append(Record.make(True, RT_ERROR, struct.pack(">H", code)))
        for code in sorted(self.warnings):
            records.append(Record.make(True, RT_WARNING, struct.pack(">H", code)))
        # Echo only the first supported protocol/algorithm the client offered.
        if protocols:
            records.append(Record.make(True, RT_NEXT_PROTO_NEG,
                                       struct.pack('>H', protocols[0])))
        else:
            records.append(Record.make(True, RT_NEXT_PROTO_NEG,
                                       b''))
        if algorithms:
            aead_algo = algorithms[0]
            records.append(Record.make(True, RT_AEAD_NEG,
                                       struct.pack('>H', aead_algo)))
        else:
            records.append(Record.make(True, RT_AEAD_NEG, b''))
        if self.errors:
            records.append(Record.make(True, RT_END_OF_MESSAGE))
            return records
        # No errors implies `algorithms` was non-empty, so aead_algo is bound.
        if DEBUG >= 2:
            print("c2s_key = unhexlify('''%s''')" % hexlify(c2s_key))
            print("s2c_key = unhexlify('''%s''')" % hexlify(s2c_key))
        if self.server.ntpv4_server is not None:
            records.append(Record.make(True, RT_NTPV4_SERVER,
                                       self.server.ntpv4_server))
        if self.server.ntpv4_port is not None:
            records.append(Record.make(True, RT_NTPV4_PORT,
                                       struct.pack(">H", self.server.ntpv4_port)))
        for i in range(8):
            cookie = NTSCookie().pack(
                self.keyid, self.key,
                aead_algo, s2c_key, c2s_key)
            records.append(Record.make(False, RT_NEW_COOKIE, cookie))
        records.append(Record.make(True, RT_END_OF_MESSAGE))
        return records
# Base class for NTSKEServer.  socketserver.ThreadingTCPServer (a new thread
# per connection) is a drop-in alternative; the project's ThreadPoolTCPServer
# from the `pooling` module is used instead.
ChosenTCPServer = ThreadPoolTCPServer
class NTSKEServer(ChosenTCPServer):
    """NTS-KE TCP server configured from an ini file through ServerHelper.

    Binds to host '' (all addresses) with address family AF_INET6 on the
    configured NTS-KE port and serves each connection with NTSKEHandler.
    """
    allow_reuse_address = True
    address_family = socket.AF_INET6
    request_queue_size = 200

    def __init__(self, config_path):
        """Load the configuration from `config_path` and bind the listening socket."""
        self.helper = ServerHelper(config_path)
        host = ''
        port = int(self.helper.ntske_port)
        super().__init__((host, port), NTSKEHandler)
        self.ntpv4_server = self.helper.ntpv4_server
        self.ntpv4_port = self.helper.ntpv4_port
        self.key_label = self.helper.key_label
        self.syslog = self.helper.syslog
        if self.syslog:
            syslog.openlog('ntske-server')

    def serve_forever(self, *args, **kwargs):
        """Load TLS and key material, start periodic refresh, then serve.

        Extra arguments (e.g. socketserver's `poll_interval`) are forwarded to
        the base class unchanged; the previous override silently dropped them.
        """
        self.refresh_wrapper()
        return super().serve_forever(*args, **kwargs)

    def sighup(self, signalnumber, frame):
        """Signal handler: reload TLS and key material immediately."""
        print("pid %u received SIGHUP, refreshing" % os.getpid())
        self.refresh()

    def refresh_wrapper(self):
        """Refresh now and reschedule itself in 60 s on a daemon Timer thread."""
        self.refresh()
        t = Timer(60, self.refresh_wrapper)
        t.daemon = True
        t.start()

    def refresh(self):
        """Rebuild the TLS wrapper and reload the server keys.

        Failures are printed, not raised.  If building the TLS wrapper fails,
        the previously installed self.wrapper stays in place.
        NOTE(review): if the very first build fails, self.wrapper is never set
        and every handled connection will fail with AttributeError.
        """
        try:
            wrapper = SSLWrapper()
            if self.helper.allow_tlsv1_2:
                print("Enabling TLSv1.2")
                wrapper.enable_tlsv1_2()
            wrapper.server(self.helper.ntske_server_cert,
                           self.helper.ntske_server_key)
            wrapper.set_alpn_protocols([NTS_ALPN_PROTO])
            self.wrapper = wrapper
        except Exception:
            traceback.print_exc()
        try:
            self.helper.load_server_keys()
        except Exception:
            traceback.print_exc()
def run_mgmt(host, port, server_keys_dir, parent_pid):
    """Run the `mgmt` Flask application on host:port, blocking until it stops.

    server_keys_dir and parent_pid are handed to the mgmt module as module
    globals before it starts.  KeyboardInterrupt ends the server quietly.
    """
    # Silence Flask's development-server startup banner.  Import flask.cli
    # explicitly instead of fishing it out of sys.modules, which only worked
    # as a side effect of importing flask first.
    import flask.cli
    flask.cli.show_server_banner = lambda *args: None
    import mgmt
    mgmt.server_keys_dir = server_keys_dir
    mgmt.parent_pid = parent_pid
    try:
        mgmt.application.run(host = host, port = port)
    except KeyboardInterrupt:
        pass
    print("mgmt", os.getpid(), "stopping...")
def _serve_worker(server):
    """Body of a forked worker process: serve until interrupted, then exit.

    Never returns.  sys.exit() in `finally` keeps the child from falling back
    into main() and forking processes of its own.
    """
    print("child process", os.getpid())
    try:
        try:
            signal.signal(signal.SIGHUP, server.sighup)
            server.serve_forever()
        except KeyboardInterrupt:
            pass
        print("child", os.getpid(), "stopping...")
        server.server_close()
    except Exception:
        traceback.print_exc()
    finally:
        sys.exit(0)


def _serve_mgmt(server, parent_pid):
    """Body of the forked management process: run the web app, then exit.

    Never returns.  SIGHUP is ignored here; the master forwards it to every
    child it forked, and the mgmt process has nothing to refresh.
    """
    print("mgmt process", os.getpid())
    signal.signal(signal.SIGHUP, signal.SIG_IGN)
    try:
        run_mgmt(server.helper.mgmt_host, server.helper.mgmt_port,
                 server.helper.server_keys_dir, parent_pid)
    except Exception:
        traceback.print_exc()
    finally:
        sys.exit(0)


def main():
    """Entry point: usage `ntske_server.py [server.ini]`.

    Binds the NTS-KE server once, forks `processes - 1` workers that share the
    listening socket (the master serves as well), optionally forks a
    management web server, then serves in the master.  SIGHUP to the master
    refreshes it and is forwarded to every child.  On exit, the master waits
    for all children.
    """
    config_path = 'server.ini'
    if len(sys.argv) > 2:
        print("Usage: %s [server.ini]" % sys.argv[0], file = sys.stderr)
        sys.exit(1)
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
    server = NTSKEServer(config_path)
    pids = []
    print("master process", os.getpid())
    for _ in range(server.helper.processes - 1):
        pid = os.fork()
        if pid == 0:
            _serve_worker(server)  # does not return
        else:
            pids.append(pid)
    if server.helper.mgmt_host and server.helper.mgmt_port:
        print("starting mgmt web server on %s:%u"% (
            server.helper.mgmt_host,
            server.helper.mgmt_port))
        parent_pid = os.getpid()
        pid = os.fork()
        if pid == 0:
            _serve_mgmt(server, parent_pid)  # does not return
        else:
            pids.append(pid)

    def master_sighup(signalnumber, frame):
        # Refresh the master, then forward the signal to every child.
        server.sighup(signalnumber, frame)
        for child_pid in pids:
            os.kill(child_pid, signal.SIGHUP)

    try:
        signal.signal(signal.SIGHUP, master_sighup)
        server.serve_forever()
    except KeyboardInterrupt:
        print("keyboardinterrupt")
    print("shutting down...")
    server.server_close()
    # Reap one child per fork (workers and mgmt) before exiting.
    for _ in pids:
        stopped_pid, _status = os.wait()
        print("child", stopped_pid, "has stopped")


if __name__ == "__main__":
    main()