Skip to content

Commit e1d63f6

Browse files
committed
Add saving drift information to a file
Saving clock drift information and loading it on startup greatly improves synchronization speed.
1 parent 782c447 commit e1d63f6

File tree

1 file changed

+78
-16
lines changed

1 file changed

+78
-16
lines changed

ntpclient.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# ntpclient.py
22

3-
import sys
3+
import os
44
import usocket as socket
55
import ustruct as struct
66
from machine import RTC, Pin
@@ -13,15 +13,25 @@
1313
# (date(2000, 1, 1) - date(1970, 1, 1)).days * 24*60*60
1414
UNIX_DELTA = 946684800
1515

16-
MIN_POLL = 64
17-
MAX_POLL = 1024
16+
# Poll interval
17+
_MIN_POLL = 64 # never poll faster than every 32 seconds
18+
_MAX_POLL = 1024 # default maximum poll interval
19+
_POLL_INC_AT = 50 # increase interval when the delta per second
20+
# falls below this number of microseconds
21+
_POLL_DEC_AT = 200 # decrease interval when the delta per second
22+
# grows above this number
1823

19-
# Internally we use a struct tm based timestamp format, which is
20-
# a tuple composed of (sec, usec) based on epoch 2000-01-01.
24+
# Drift file configuration
25+
_DRIFT_FILE_VERSION = 1
26+
_DRIFT_NUM_MAX = 200 # Aggregate when we have this many samples
27+
_DRIFT_NUM_AVG = 100 # Aggregate down to this many and save drift file
2128

2229
# time_add_us() -
2330
# Adds a number of microseconds to a timestamp.
2431
# Returns a timestamp.
32+
#
33+
# Internally we use a struct tm based timestamp format, which is
34+
# a tuple composed of (sec, usec) based on epoch 2000-01-01.
2535
def time_add_us(ts, us):
2636
usec = ts[1] + us
2737
if usec < 0:
@@ -38,16 +48,16 @@ def time_diff_us(ts1, ts2):
3848
# ntpclient -
3949
# Class implementing the uasyncio based NTP client
4050
class ntpclient:
41-
def __init__(self, host = 'pool.ntp.org', poll = MAX_POLL,
51+
def __init__(self, host = 'pool.ntp.org', poll = _MAX_POLL,
4252
adj_interval = 2, debug = False,
43-
max_startup_delta = 1.0):
53+
max_startup_delta = 1.0, drift_file = None):
4454
self.host = host
4555
self.sock = None
4656
self.addr = None
4757
self.rstr = None
4858
self.wstr = None
4959
self.req_poll = poll
50-
self.poll = MIN_POLL
60+
self.poll = _MIN_POLL
5161
self.max_startup_delta = int(max_startup_delta * 1000000)
5262
self.rtc = RTC()
5363
self.last_delta = None
@@ -57,11 +67,53 @@ def __init__(self, host = 'pool.ntp.org', poll = MAX_POLL,
5767
self.adj_interval = adj_interval
5868
self.adj_sum = 0
5969
self.adj_num = 0
70+
self.drift_file = drift_file
6071
self.debug = debug
6172

6273
asyncio.create_task(self._poll_task())
6374
asyncio.create_task(self._adj_task())
6475

76+
def drift_save(self):
77+
# This is called every time we increase the polling interval
78+
# or aggregate the drift summary data.
79+
if self.drift_file is None:
80+
return
81+
82+
try:
83+
tmp = self.drift_file + '.tmp'
84+
with open(tmp, 'w') as fd:
85+
fd.write("version = {}\n".format(_DRIFT_FILE_VERSION))
86+
fd.write("drift_sum = {}\n".format(self.drift_sum))
87+
fd.write("drift_num = {}\n".format(self.drift_num))
88+
os.rename(tmp, self.drift_file)
89+
except Exception as ex:
90+
print("ntpclient: drift_save():", ex)
91+
if self.debug:
92+
print("ntpclient: saved {}".format(self.drift_file))
93+
94+
def drift_load(self):
95+
if self.drift_file is None:
96+
return
97+
98+
try:
99+
with open(self.drift_file, 'r') as fd:
100+
info = {}
101+
exec(fd.read(), globals(), info)
102+
if info['version'] > _DRIFT_FILE_VERSION:
103+
print("ntpclient: WARNING - drift file version is {} "
104+
"- expected {}".format(info['version'],
105+
_DRIFT_FILE_VERSION))
106+
self.drift_sum = info['drift_sum']
107+
self.drift_num = info['drift_num']
108+
except Exception as ex:
109+
print("ntpclient: drift_load():", ex)
110+
return
111+
if self.debug:
112+
print("ntpclient: loaded drift data {}/{}"
113+
" = {}".format(self.drift_sum, self.drift_num,
114+
self.drift_sum // self.drift_num))
115+
116+
65117
async def _poll_server(self):
66118
# We try to stay with the same server as long as possible. Only
67119
# lookup the address on startup or after errors.
@@ -120,6 +172,9 @@ async def _poll_server(self):
120172
return (delay, time_diff_us(tnow, t2), t2)
121173

122174
async def _poll_task(self):
175+
# Try loading an existing drift file
176+
self.drift_load()
177+
123178
# Try to get a first server reading
124179
while True:
125180
try:
@@ -181,7 +236,7 @@ async def _poll_task(self):
181236
self.sock.close()
182237
self.sock = None
183238
self.addr = None
184-
self.poll = MIN_POLL
239+
self.poll = _MIN_POLL
185240
continue
186241

187242
if self.last_delta is None:
@@ -195,9 +250,14 @@ async def _poll_task(self):
195250
drift = (self.adj_sum + corr) // self.adj_num
196251
self.drift_sum += drift
197252
self.drift_num += 1
198-
if self.drift_num >= 200:
199-
self.drift_sum = (self.drift_sum // self.drift_num) * 100
200-
self.drift_num = 100
253+
if self.drift_num >= _DRIFT_NUM_MAX:
254+
# When we have 200 samples we aggregate the data down to
255+
# 100 samples in order to give an actual change in the
256+
# drift a chance to change our average.
257+
self.drift_sum = (self.drift_sum // self.drift_num) \
258+
* _DRIFT_NUM_AVG
259+
self.drift_num = _DRIFT_NUM_AVG
260+
self.drift_save()
201261
if self.debug:
202262
print("ntpclient: drift average adjusted to {0}/{1}".format(
203263
self.drift_sum, self.drift_num))
@@ -209,16 +269,18 @@ async def _poll_task(self):
209269
# per adj_interval is below or above a certain threshold.
210270
# This means we poll less if we think we are close to
211271
# the server and more often while homing in.
272+
delta_per_sec = delta // self.adj_num // self.adj_interval
212273
if self.poll < self.req_poll and self.drift_num > 25:
213-
if abs(delta // self.adj_num) < 100:
274+
if abs(delta_per_sec) < _POLL_INC_AT:
214275
self.poll <<= 1
215-
elif self.poll > MIN_POLL:
216-
if abs(delta // self.adj_num) > 300:
276+
self.drift_save()
277+
elif self.poll > _MIN_POLL:
278+
if abs(delta_per_sec) > _POLL_DEC_AT:
217279
self.poll >>= 1
218280
if self.debug:
219281
print("ntpclient: state at", utime.localtime())
220282
print("ntpclient: delta:", delta,
221-
"per adj:", delta // self.adj_num)
283+
"per_sec:", delta_per_sec)
222284
print("ntpclient: drift_sum:", self.drift_sum,
223285
"num:", self.drift_num, "avg:", avg_drift)
224286
print("ntpclient: new adj_delta:", self.adj_delta,

0 commit comments

Comments
 (0)