Skip to content

Commit a965dc4

Browse files
tachyonicClock authored and hmgomes committed
feat: add restart_stream=False as an option in evaluators
1 parent d82e05d commit a965dc4

File tree

6 files changed

+183
-46
lines changed

6 files changed

+183
-46
lines changed

src/capymoa/evaluation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
PredictionIntervalWindowedEvaluator,
1212
AnomalyDetectionEvaluator,
1313
)
14+
from . import results
1415

1516
__all__ = [
1617
"prequential_evaluation",
@@ -24,4 +25,5 @@
2425
"PredictionIntervalEvaluator",
2526
"PredictionIntervalWindowedEvaluator",
2627
"AnomalyDetectionEvaluator",
28+
"results"
2729
]

src/capymoa/evaluation/evaluation.py

Lines changed: 92 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from typing import Optional
2+
from typing import Any, Dict, Union
23

34
import pandas as pd
45
import numpy as np
@@ -9,14 +10,11 @@
910
import os
1011

1112
from capymoa.stream import Schema, Stream
12-
from capymoa.base import (
13-
AnomalyDetector,
14-
ClassifierSSL,
15-
MOAPredictionIntervalLearner
16-
)
13+
from capymoa.base import AnomalyDetector, ClassifierSSL, MOAPredictionIntervalLearner
1714

1815
from capymoa.evaluation.results import PrequentialResults
1916
from capymoa._utils import _translate_metric_name
17+
from capymoa.base import Classifier, Regressor
2018

2119
from com.yahoo.labs.samoa.instances import Instances, Attribute, DenseInstance
2220
from moa.core import InstanceExample
@@ -820,25 +818,51 @@ def stop_time_measuring(start_wallclock_time, start_cpu_time):
820818

821819

822820
def prequential_evaluation(
823-
stream,
824-
learner,
825-
max_instances=None,
826-
window_size=1000,
827-
store_predictions=False,
828-
store_y=False,
829-
optimise=True
830-
):
831-
"""
832-
Calculates the metrics cumulatively (i.e. test-then-train) and in a window-fashion (i.e. windowed prequential
833-
evaluation). Returns both evaluators so that the user has access to metrics from both evaluators.
821+
stream: Stream,
822+
learner: Union[Classifier, Regressor],
823+
max_instances: Optional[int] = None,
824+
window_size: int = 1000,
825+
store_predictions: bool = False,
826+
store_y: bool = False,
827+
optimise: bool = True,
828+
restart_stream: bool = True,
829+
) -> PrequentialResults:
830+
"""Run and evaluate a learner on a stream using prequential evaluation.
831+
832+
Calculates the metrics cumulatively (i.e. test-then-train) and in a
833+
window-fashion (i.e. windowed prequential evaluation). Returns both
834+
evaluators so that the user has access to metrics from both evaluators.
835+
836+
:param stream: A data stream to evaluate the learner on. Will be restarted if
837+
``restart_stream`` is True.
838+
:param learner: The learner to evaluate.
839+
:param max_instances: The number of instances to evaluate before exiting. If
840+
None, the evaluation will continue until the stream is empty.
841+
:param window_size: The size of the window used for windowed evaluation,
842+
defaults to 1000
843+
:param store_predictions: Store the learner's prediction in a list, defaults
844+
to False
845+
:param store_y: Store the ground truth targets in a list, defaults to False
846+
:param optimise: If True and the learner is compatible, the evaluator will
847+
use a Java native evaluation loop, defaults to True.
848+
:param restart_stream: If False, evaluation will continue from the current
849+
position in the stream, defaults to True. Not restarting the stream is
850+
useful for switching between learners or evaluators, without starting
851+
from the beginning of the stream.
852+
:return: An object containing the results of the evaluation: windowed metrics,
853+
cumulative metrics, ground truth targets, and predictions.
834854
"""
835-
stream.restart()
855+
if restart_stream:
856+
stream.restart()
836857
if _is_fast_mode_compilable(stream, learner, optimise):
837-
return _prequential_evaluation_fast(stream, learner,
838-
max_instances,
839-
window_size,
840-
store_y=store_y,
841-
store_predictions=store_predictions)
858+
return _prequential_evaluation_fast(
859+
stream,
860+
learner,
861+
max_instances,
862+
window_size,
863+
store_y=store_y,
864+
store_predictions=store_predictions,
865+
)
842866

843867
predictions = None
844868
if store_predictions:
@@ -880,7 +904,7 @@ def prequential_evaluation(
880904
schema=stream.get_schema(), window_size=window_size
881905
)
882906
while stream.has_more_instances() and (
883-
max_instances is None or instancesProcessed <= max_instances
907+
max_instances is None or instancesProcessed <= max_instances
884908
):
885909
instance = stream.next_instance()
886910

@@ -933,25 +957,55 @@ def prequential_evaluation(
933957
return results
934958

935959

936-
# TODO: Include store_predictions and store_y logic
937960
def prequential_ssl_evaluation(
938-
stream,
939-
learner,
940-
max_instances=None,
941-
window_size=1000,
942-
initial_window_size=0,
943-
delay_length=0,
944-
label_probability=0.01,
945-
random_seed=1,
946-
store_predictions=False,
947-
store_y=False,
948-
optimise=True,
961+
stream: Stream,
962+
learner: Union[ClassifierSSL, Classifier],
963+
max_instances: Optional[int] = None,
964+
window_size: int = 1000,
965+
initial_window_size: int = 0,
966+
delay_length: int = 0,
967+
label_probability: float = 0.01,
968+
random_seed: int = 1,
969+
store_predictions: bool = False,
970+
store_y: bool = False,
971+
optimise: bool = True,
972+
restart_stream: bool = True,
949973
):
950-
"""
951-
If the learner is not an SSL learner, then it will be trained only on the labeled instances.
974+
"""Run and evaluate a learner on a semi-supervised stream using prequential evaluation.
975+
976+
:param stream: A data stream to evaluate the learner on. Will be restarted if
977+
``restart_stream`` is True.
978+
:param learner: The learner to evaluate. If the learner is an SSL learner,
979+
it will be trained on both labeled and unlabeled instances. If the
980+
learner is not an SSL learner, then it will be trained only on the
981+
labeled instances.
982+
:param max_instances: The number of instances to evaluate before exiting.
983+
If None, the evaluation will continue until the stream is empty.
984+
:param window_size: The size of the window used for windowed evaluation,
985+
defaults to 1000
986+
:param initial_window_size: Not implemented yet
987+
:param delay_length: If greater than zero, the labeled (selected with probability ``label_probability``)
988+
instances will appear as unlabeled before reappearing as labeled after
989+
``delay_length`` instances, defaults to 0
990+
:param label_probability: The proportion of instances that will be labeled,
991+
must be in the range [0, 1], defaults to 0.01
992+
:param random_seed: A random seed to define the random state that decides
993+
which instances are labeled and which are not, defaults to 1.
994+
:param store_predictions: Store the learner's prediction in a list, defaults
995+
to False
996+
:param store_y: Store the ground truth targets in a list, defaults to False
997+
:param optimise: If True and the learner is compatible, the evaluator will
998+
use a Java native evaluation loop, defaults to True.
999+
:param restart_stream: If False, evaluation will continue from the current
1000+
position in the stream, defaults to True. Not restarting the stream is
1001+
useful for switching between learners or evaluators, without starting
1002+
from the beginning of the stream.
1003+
:return: An object containing the results of the evaluation: windowed metrics,
1004+
cumulative metrics, ground truth targets, and predictions.
9521005
"""
9531006

954-
stream.restart()
1007+
if restart_stream:
1008+
stream.restart()
9551009

9561010
if _is_fast_mode_compilable(stream, learner, optimise):
9571011
return _prequential_ssl_evaluation_fast(stream,

src/capymoa/evaluation/results.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,9 +30,13 @@ def __init__(self,
3030
self._predictions = predictions
3131
self._other_metrics = other_metrics
3232
# attributes
33-
self.learner = learner
34-
self.stream = stream
33+
#: The name of the learner
34+
self.learner: str = learner
35+
#: The stream used to evaluate the learner
36+
self.stream: Stream = stream
37+
#: The cumulative evaluator
3538
self.cumulative = cumulative_evaluator
39+
#: The windowed evaluator
3640
self.windowed = windowed_evaluator
3741

3842
def __getitem__(self, key):

src/capymoa/stream/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from ._stream import Stream, Schema, ARFFStream, stream_from_file, CSVStream
1+
from ._stream import Stream, Schema, ARFFStream, stream_from_file, CSVStream, NumpyStream
22
from .PytorchStream import PytorchStream
33
from . import drift, generator, preprocessing
44

@@ -12,4 +12,5 @@
1212
"drift",
1313
"generator",
1414
"preprocessing",
15+
"NumpyStream"
1516
]

src/capymoa/stream/_stream.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import typing
22
import warnings
3-
from typing import Dict, Optional, Sequence
3+
from typing import Dict, Optional, Sequence, Union
44

55
import numpy as np
66
from numpy.lib import recfunctions as rfn
@@ -380,7 +380,7 @@ def __init__(
380380
def has_more_instances(self):
381381
return self.arff_instances_data.numInstances() > self.current_instance_index
382382

383-
def next_instance(self) -> Instance:
383+
def next_instance(self) -> Union[LabeledInstance, RegressionInstance]:
384384
# Return None if all instances have been read already.
385385
if not self.has_more_instances():
386386
return None

tests/test_evaluation.py

Lines changed: 79 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1-
from capymoa.evaluation.evaluation import prequential_evaluation_anomaly
2-
from capymoa.stream.generator import SEA
1+
from contextlib import nullcontext
2+
from itertools import product
3+
from capymoa.evaluation.evaluation import _is_fast_mode_compilable, prequential_evaluation_anomaly
4+
from capymoa.regressor import KNNRegressor
5+
from capymoa.stream.generator import SEA, HyperPlaneRegression, RandomTreeGenerator
36
from capymoa.classifier import NaiveBayes, HoeffdingTree
47
from capymoa.evaluation import (prequential_evaluation,
58
prequential_evaluation_multiple_learners,
@@ -171,4 +174,77 @@ def test_prequential_evaluation_anomaly():
171174
assert results_1st_run['windowed'].auc() == pytest.approx(
172175
results_2nd_run['windowed'].auc(), abs=0.001
173176
), f"prequential_evaluation_anomaly same synthetic stream: Expected AUC of " \
174-
f"{results_1st_run['windowed'].auc():0.3f} got {results_2nd_run['windowed'].auc(): 0.3f}"
177+
f"{results_1st_run['windowed'].auc():0.3f} got {results_2nd_run['windowed'].auc(): 0.3f}"
178+
179+
180+
181+
182+
@pytest.mark.parametrize(
183+
["restart_stream", "optimise", "regression", "evaluation"],
184+
list(
185+
product(
186+
[True, False],
187+
[True, False],
188+
[True, False],
189+
[
190+
prequential_evaluation,
191+
prequential_ssl_evaluation,
192+
],
193+
)
194+
),
195+
)
196+
def test_restart_stream_flag(restart_stream, optimise, regression, evaluation):
197+
"""Ensure that the stream is restarted when the restart_stream flag is set to True"""
198+
expect_error = False
199+
# Some configurations are not supported by some evaluation methods.
200+
# When these are eventually supported, this test will need to be updated.
201+
202+
# Create a stream and learner
203+
stream = (
204+
HyperPlaneRegression() if regression else RandomTreeGenerator(num_classes=10)
205+
)
206+
207+
# This evaluation function does not yet support regression
208+
if evaluation == prequential_ssl_evaluation and regression:
209+
expect_error = True
210+
211+
if not regression:
212+
learner = NaiveBayes(
213+
schema=stream.get_schema()
214+
) # The type of model is not important
215+
else:
216+
learner = KNNRegressor(schema=stream.get_schema())
217+
assert _is_fast_mode_compilable(
218+
stream, learner, True
219+
), "Fast mode should always be compilable for this test"
220+
221+
def _take_y(num_instances):
222+
if regression:
223+
return [stream.next_instance().y_value for _ in range(num_instances)]
224+
else:
225+
return [stream.next_instance().y_index for _ in range(num_instances)]
226+
227+
# Store targets from the stream for use in assertions later.
228+
y_stream = _take_y(20)
229+
stream.restart() # Must restart the stream to get the same instances again
230+
231+
# Consume the first 10 instances
232+
_take_y(10)
233+
with pytest.raises((RuntimeError, ValueError)) if expect_error else nullcontext():
234+
# Consume either the next 5 instances or the same 5 instances again
235+
# depending on the ``restart_stream`` flag
236+
evaluation(
237+
stream=stream,
238+
learner=learner,
239+
max_instances=5,
240+
optimise=optimise,
241+
restart_stream=restart_stream,
242+
)
243+
244+
# If the stream is restarted, the next 5 instances should be the same as those
245+
# we remembered. Otherwise, they should be different.
246+
y_remaining = _take_y(5)
247+
if restart_stream == True:
248+
assert y_remaining == y_stream[5:10]
249+
else:
250+
assert y_remaining == y_stream[15:20]

0 commit comments

Comments
 (0)