1
1
# ntpclient.py
2
2
3
- import sys
3
+ import os
4
4
import usocket as socket
5
5
import ustruct as struct
6
6
from machine import RTC , Pin
13
13
# (date(2000, 1, 1) - date(1970, 1, 1)).days * 24*60*60
14
14
UNIX_DELTA = 946684800
15
15
16
- MIN_POLL = 64
17
- MAX_POLL = 1024
16
+ # Poll interval
17
+ _MIN_POLL = 64 # never poll faster than every 32 seconds
18
+ _MAX_POLL = 1024 # default maximum poll interval
19
+ _POLL_INC_AT = 50 # increase interval when the delta per second
20
+ # falls below this number of microseconds
21
+ _POLL_DEC_AT = 200 # decrease interval when the delta per second
22
+ # grows above this number
18
23
19
- # Internally we use a struct tm based timestamp format, which is
20
- # a tuple composed of (sec, usec) based on epoch 2000-01-01.
24
+ # Drift file configuration
25
+ _DRIFT_FILE_VERSION = 1
26
+ _DRIFT_NUM_MAX = 200 # Aggregate when we have this many samples
27
+ _DRIFT_NUM_AVG = 100 # Aggregate down to this many and save drift file
21
28
22
29
# time_add_us() -
23
30
# Adds a number of microseconds to a timestamp.
24
31
# Returns a timestamp.
32
+ #
33
+ # Internally we use a struct tm based timestamp format, which is
34
+ # a tuple composed of (sec, usec) based on epoch 2000-01-01.
25
35
def time_add_us (ts , us ):
26
36
usec = ts [1 ] + us
27
37
if usec < 0 :
@@ -38,16 +48,16 @@ def time_diff_us(ts1, ts2):
38
48
# ntpclient -
39
49
# Class implementing the uasyncio based NTP client
40
50
class ntpclient :
41
- def __init__ (self , host = 'pool.ntp.org' , poll = MAX_POLL ,
51
+ def __init__ (self , host = 'pool.ntp.org' , poll = _MAX_POLL ,
42
52
adj_interval = 2 , debug = False ,
43
- max_startup_delta = 1.0 ):
53
+ max_startup_delta = 1.0 , drift_file = None ):
44
54
self .host = host
45
55
self .sock = None
46
56
self .addr = None
47
57
self .rstr = None
48
58
self .wstr = None
49
59
self .req_poll = poll
50
- self .poll = MIN_POLL
60
+ self .poll = _MIN_POLL
51
61
self .max_startup_delta = int (max_startup_delta * 1000000 )
52
62
self .rtc = RTC ()
53
63
self .last_delta = None
@@ -57,11 +67,53 @@ def __init__(self, host = 'pool.ntp.org', poll = MAX_POLL,
57
67
self .adj_interval = adj_interval
58
68
self .adj_sum = 0
59
69
self .adj_num = 0
70
+ self .drift_file = drift_file
60
71
self .debug = debug
61
72
62
73
asyncio .create_task (self ._poll_task ())
63
74
asyncio .create_task (self ._adj_task ())
64
75
76
+ def drift_save (self ):
77
+ # This is called every time we increase the polling interval
78
+ # or aggregate the drift summary data.
79
+ if self .drift_file is None :
80
+ return
81
+
82
+ try :
83
+ tmp = self .drift_file + '.tmp'
84
+ with open (tmp , 'w' ) as fd :
85
+ fd .write ("version = {}\n " .format (_DRIFT_FILE_VERSION ))
86
+ fd .write ("drift_sum = {}\n " .format (self .drift_sum ))
87
+ fd .write ("drift_num = {}\n " .format (self .drift_num ))
88
+ os .rename (tmp , self .drift_file )
89
+ except Exception as ex :
90
+ print ("ntpclient: drift_save():" , ex )
91
+ if self .debug :
92
+ print ("ntpclient: saved {}" .format (self .drift_file ))
93
+
94
+ def drift_load (self ):
95
+ if self .drift_file is None :
96
+ return
97
+
98
+ try :
99
+ with open (self .drift_file , 'r' ) as fd :
100
+ info = {}
101
+ exec (fd .read (), globals (), info )
102
+ if info ['version' ] > _DRIFT_FILE_VERSION :
103
+ print ("ntpclient: WARNING - drift file version is {} "
104
+ "- expected {}" .format (info ['version' ],
105
+ _DRIFT_FILE_VERSION ))
106
+ self .drift_sum = info ['drift_sum' ]
107
+ self .drift_num = info ['drift_num' ]
108
+ except Exception as ex :
109
+ print ("ntpclient: drift_load():" , ex )
110
+ return
111
+ if self .debug :
112
+ print ("ntpclient: loaded drift data {}/{}"
113
+ " = {}" .format (self .drift_sum , self .drift_num ,
114
+ self .drift_sum // self .drift_num ))
115
+
116
+
65
117
async def _poll_server (self ):
66
118
# We try to stay with the same server as long as possible. Only
67
119
# lookup the address on startup or after errors.
@@ -120,6 +172,9 @@ async def _poll_server(self):
120
172
return (delay , time_diff_us (tnow , t2 ), t2 )
121
173
122
174
async def _poll_task (self ):
175
+ # Try loading an existing drift file
176
+ self .drift_load ()
177
+
123
178
# Try to get a first server reading
124
179
while True :
125
180
try :
@@ -181,7 +236,7 @@ async def _poll_task(self):
181
236
self .sock .close ()
182
237
self .sock = None
183
238
self .addr = None
184
- self .poll = MIN_POLL
239
+ self .poll = _MIN_POLL
185
240
continue
186
241
187
242
if self .last_delta is None :
@@ -195,9 +250,14 @@ async def _poll_task(self):
195
250
drift = (self .adj_sum + corr ) // self .adj_num
196
251
self .drift_sum += drift
197
252
self .drift_num += 1
198
- if self .drift_num >= 200 :
199
- self .drift_sum = (self .drift_sum // self .drift_num ) * 100
200
- self .drift_num = 100
253
+ if self .drift_num >= _DRIFT_NUM_MAX :
254
+ # When we have 200 samples we aggregate the data down to
255
+ # 100 samples in order to give an actual change in the
256
+ # drift a chance to change our average.
257
+ self .drift_sum = (self .drift_sum // self .drift_num ) \
258
+ * _DRIFT_NUM_AVG
259
+ self .drift_num = _DRIFT_NUM_AVG
260
+ self .drift_save ()
201
261
if self .debug :
202
262
print ("ntpclient: drift average adjusted to {0}/{1}" .format (
203
263
self .drift_sum , self .drift_num ))
@@ -209,16 +269,18 @@ async def _poll_task(self):
209
269
# per adj_interval is below or above a certain threshold.
210
270
# This means we poll less if we think we are close to
211
271
# the server and more often while homing in.
272
+ delta_per_sec = delta // self .adj_num // self .adj_interval
212
273
if self .poll < self .req_poll and self .drift_num > 25 :
213
- if abs (delta // self . adj_num ) < 100 :
274
+ if abs (delta_per_sec ) < _POLL_INC_AT :
214
275
self .poll <<= 1
215
- elif self .poll > MIN_POLL :
216
- if abs (delta // self .adj_num ) > 300 :
276
+ self .drift_save ()
277
+ elif self .poll > _MIN_POLL :
278
+ if abs (delta_per_sec ) > _POLL_DEC_AT :
217
279
self .poll >>= 1
218
280
if self .debug :
219
281
print ("ntpclient: state at" , utime .localtime ())
220
282
print ("ntpclient: delta:" , delta ,
221
- "per adj :" , delta // self . adj_num )
283
+ "per_sec :" , delta_per_sec )
222
284
print ("ntpclient: drift_sum:" , self .drift_sum ,
223
285
"num:" , self .drift_num , "avg:" , avg_drift )
224
286
print ("ntpclient: new adj_delta:" , self .adj_delta ,
0 commit comments