Skip to content

Commit 94e799a

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #178 from IFCA/fix-CircularMean
Fix inefficient CircularMean implementation
2 parents f1c655e + 33f470b commit 94e799a

File tree

2 files changed

+41
-52
lines changed

2 files changed

+41
-52
lines changed

frouros/utils/data_structures.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Data structures module."""
22

3-
from typing import Optional, List, Union # noqa: TYP001
4-
3+
from typing import Any, Optional, List, Union
54

65
import numpy as np # type: ignore
76

@@ -115,20 +114,20 @@ def max_len(self, value: int) -> None:
115114
self._max_len = value
116115

117116
@property
118-
def queue(self) -> List[Optional[bool]]:
117+
def queue(self) -> List[Optional[Any]]:
119118
"""Queue property.
120119
121120
:return: queue
122-
:rtype: List[Optional[bool]]
121+
:rtype: List[Optional[Any]]
123122
"""
124123
return self._queue
125124

126125
@queue.setter
127-
def queue(self, value: List[Optional[bool]]) -> None:
126+
def queue(self, value: List[Optional[Any]]) -> None:
128127
"""Queue setter.
129128
130129
:param value: value to be set
131-
:type value: List[Optional[bool]]
130+
:type value: List[Optional[Any]]
132131
:raises ValueError: Value error exception
133132
"""
134133
if not isinstance(value, list):
@@ -151,10 +150,10 @@ def clear(self) -> None:
151150
self.last = -1
152151
self.queue = [None] * self.max_len
153152

154-
def dequeue(self) -> bool:
153+
def dequeue(self) -> Any:
155154
"""Dequeue oldest element.
156155
157-
:rtype: bool
156+
:rtype: value: Any
158157
:raises EmptyQueue: Empty queue error exception
159158
"""
160159
if self.is_empty():
@@ -164,17 +163,19 @@ def dequeue(self) -> bool:
164163
self.count -= 1
165164
return element # type: ignore
166165

167-
def enqueue(self, value: Union[np.ndarray, float]) -> None:
166+
def enqueue(self, value: Union[np.ndarray, int, float]) -> Optional[Any]:
168167
"""Enqueue element/s.
169168
170169
:param value: value to be enqueued
171170
:type value: Union[np.ndarray, float]
171+
:return element: dequeued element
172+
:rtype: Optional[Any]
172173
"""
173-
if self.is_full():
174-
_ = self.dequeue()
174+
element = self.dequeue() if self.is_full() else None
175175
self.last = (self.last + 1) % self.max_len
176176
self.queue[self.last] = value # type: ignore
177177
self.count += 1
178+
return element
178179

179180
def is_empty(self) -> bool:
180181
"""Check if queue is empty.
@@ -205,13 +206,13 @@ def maintain_last_element(self) -> None:
205206
self.first = self.last
206207
self.count = 1
207208

208-
def __getitem__(self, idx: int) -> float:
209+
def __getitem__(self, idx: int) -> Any:
209210
"""Get queue item by position.
210211
211212
:param idx: position index
212213
:type idx: int
213214
:return: queue item
214-
:rtype: float
215+
:rtype: Any
215216
"""
216217
return self.queue[idx] # type: ignore
217218

frouros/utils/stats.py

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -90,46 +90,33 @@ def update(self, value: Union[int, float]) -> None:
9090
if not isinstance(value, (int, float)):
9191
raise TypeError("value must be of type int or float.")
9292
self.num_values += 1
93-
self.mean += (value - self.mean) / self.num_values
93+
self.mean += self.incremental_op(
94+
value=value,
95+
element=self.mean,
96+
size=self.num_values,
97+
)
98+
99+
@staticmethod
100+
def incremental_op(
101+
value: Union[int, float],
102+
element: Union[int, float],
103+
size: int,
104+
) -> float:
105+
"""Incremental operation."""
106+
return (value - element) / size
94107

95108
def get(self) -> float:
96109
"""Get method."""
97110
return self.mean
98111

99112

100-
class CircularMean(IncrementalStat):
113+
class CircularMean(Mean):
101114
"""Circular mean class."""
102115

103116
def __init__(self, size: int) -> None:
104-
"""Init method.
105-
106-
:param size: size of the circular mean
107-
:type size: int
108-
"""
109-
self.size = size
110-
self.mean = 0.0
111-
self.queue = CircularQueue(max_len=self.size)
112-
113-
@property
114-
def size(self) -> int:
115-
"""Size property.
116-
117-
:return: size value
118-
:rtype: int
119-
"""
120-
return self._size
121-
122-
@size.setter
123-
def size(self, value: int) -> None:
124-
"""Size setter.
125-
126-
:param value: value to be set
127-
:type value: int
128-
:raises ValueError: Value error exception
129-
"""
130-
if value < 0:
131-
raise ValueError("size must be greater of equal than 0.")
132-
self._size = value
117+
"""Init method."""
118+
super().__init__()
119+
self.queue = CircularQueue(max_len=size)
133120

134121
def update(self, value: Union[int, float]) -> None:
135122
"""Update the mean value sequentially.
@@ -138,14 +125,15 @@ def update(self, value: Union[int, float]) -> None:
138125
:type value: int
139126
:raises TypeError: Type error exception
140127
"""
141-
# FIXME: Inefficient implementation # pylint: disable=fixme
142-
self.queue.enqueue(value=value)
143-
queue = np.array(self.queue.queue)
144-
self.mean = np.mean(queue[queue != np.array(None)])
145-
146-
def get(self) -> float:
147-
"""Get method."""
148-
return self.mean
128+
if not isinstance(value, (int, float)):
129+
raise TypeError("value must be of type int or float.")
130+
element = self.queue.enqueue(value=value)
131+
self.num_values = len(self.queue)
132+
self.mean += self.incremental_op(
133+
value=value,
134+
element=self.mean if element is None else element,
135+
size=self.num_values,
136+
)
149137

150138

151139
class EWMA(IncrementalStat):

0 commit comments

Comments
 (0)