Skip to content

Commit 3d15fa0

Browse files
committed
Fix an unstable output issue of periodic sequence due to precision error.
1 parent facdc4e commit 3d15fa0

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

sdks/python/apache_beam/ml/ts/util.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,19 @@ def __init__(
159159
self._data = data
160160
self._interval = interval
161161
self._repeat = repeat
162-
self._duration = len(self._data) * interval
162+
163+
# In `ImpulseSeqGenRestrictionProvider`, the total number of counts
164+
# (i.e. total_outputs) is computed by ceil((end - start) / interval),
165+
# where end is start + duration * interval.
166+
# Due to precision error of arithmetic operations, even if duration is set
167+
# to len(self._data), (end - start) / interval could be a little bit smaller
168+
# or bigger than len(self._data).
169+
# In case of being bigger, total_outputs would be len(self._data) + 1,
170+
# as the ceil() operation is used.
171+
# Assuming that the precision error is no bigger than 1%, by subtracting
172+
# a small amount, we ensure that the result after ceil is stable even if
173+
# the precision error is present.
174+
self._duration = len(self._data) * interval - 0.01 * interval
163175
self._max_duration = max_duration if max_duration is not None else float(
164176
"inf")
165177

sdks/python/apache_beam/ml/ts/util_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ def test_timestamped_value(self):
7676
self.assertGreaterEqual(end - start, 3)
7777
self.assertLessEqual(end - start, 7)
7878

79+
def test_stable_output(self):
80+
options = PipelineOptions()
81+
data = [(Timestamp(1), 1), (Timestamp(2), 2), (Timestamp(3), 3),
82+
(Timestamp(6), 6), (Timestamp(4), 4), (Timestamp(5), 5),
83+
(Timestamp(7), 7), (Timestamp(8), 8), (Timestamp(9), 9),
84+
(Timestamp(10), 10)]
85+
expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
86+
with beam.Pipeline(options=options) as p:
87+
ret = (p | PeriodicStream(data, interval=0.0001))
88+
assert_that(ret, equal_to(expected))
89+
7990

8091
if __name__ == '__main__':
8192
logging.getLogger().setLevel(logging.WARNING)

0 commit comments

Comments
 (0)