Skip to content

Commit 7fe5d33

Browse files
tachyonicClock authored and hmgomes committed
feat: add ocl streams
1 parent 7b324e2 commit 7fe5d33

File tree

13 files changed

+943
-171
lines changed

13 files changed

+943
-171
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
toc_object_entries_show_parents = "hide"
5757
autosummary_ignore_module_all = False
5858
autosummary_generate = True
59-
autodoc_member_order = "bysource"
59+
autodoc_member_order = "groupwise"
6060
autodoc_class_signature = "separated"
6161

6262
templates_path = ["_templates"]

notebooks/03_pytorch.ipynb

Lines changed: 31 additions & 43 deletions
Large diffs are not rendered by default.

src/capymoa/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
Sensor,
3434
)
3535
from ._utils import get_download_dir
36-
from . import downloader
36+
from . import downloader, ocl
3737

3838
__all__ = [
3939
"Bike",
@@ -51,4 +51,5 @@
5151
"Sensor",
5252
"downloader",
5353
"get_download_dir",
54+
"ocl",
5455
]

src/capymoa/datasets/ocl.py

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
"""This module contains built-in datastreams for online continual learning (OCL).
2+
3+
In OCL datastreams are irreversible sequences of examples following a
4+
non-stationary data distribution. Learners in OCL can only learn from a single
5+
pass through the datastream but are expected to perform well on any portion of
6+
the datastream.
7+
8+
Portions of the datastream where the data distribution is relatively stationary
9+
are called *tasks*.
10+
11+
A common way to construct an OCL dataset for experimentation is to group the
12+
classes of a usual classification dataset into tasks. Known as the
13+
*class-incremental* scenario, the learner is presented with a sequence of tasks
14+
where each task contains a new subset of the classes.
15+
16+
For example :class:`SplitMNIST` splits the MNIST dataset into five tasks where each
17+
task contains two classes:
18+
19+
>>> from capymoa.datasets.ocl import SplitMNIST
20+
>>> scenario = SplitMNIST()
21+
>>> scenario.task_schedule
22+
[{1, 4}, {5, 7}, {9, 3}, {0, 8}, {2, 6}]
23+
24+
25+
To get the usual CapyMOA stream object for training:
26+
27+
>>> instance = scenario.train_stream.next_instance()
28+
>>> instance
29+
LabeledInstance(
30+
Schema(SplitMNISTTrain),
31+
x=[0. 0. 0. ... 0. 0. 0.],
32+
y_index=1,
33+
y_label='1'
34+
)
35+
36+
CapyMOA streams flatten the data into a feature vector:
37+
38+
>>> instance.x.shape
39+
(784,)
40+
41+
You can access the PyTorch datasets for each task:
42+
43+
>>> x, y = scenario.test_tasks[0][0]
44+
>>> x.shape
45+
torch.Size([1, 28, 28])
46+
>>> y
47+
1
48+
"""
49+
50+
from pathlib import Path
51+
from typing import Any, Callable, Optional, Tuple, Sequence, Set
52+
from capymoa.datasets import get_download_dir
53+
from capymoa.ocl.util.data import partition_by_schedule, class_incremental_schedule
54+
from capymoa.stream import TorchClassifyStream, Stream, ConcatStream
55+
from capymoa.instance import LabeledInstance
56+
from capymoa.stream._stream import Schema
57+
import torch
58+
from torchvision import datasets
59+
from torch.utils.data import Dataset
60+
from torch import Tensor
61+
from torchvision.transforms import ToTensor, Normalize, Compose
62+
from abc import abstractmethod, ABC
63+
64+
65+
class _BuiltInCIScenario(ABC):
    """Abstract base class for built-in class incremental OCL datasets.

    This abstract base class makes it easy to define built-in class-incremental
    continual learning datasets. Subclasses set :attr:`num_classes`,
    :attr:`default_task_count`, :attr:`mean` and :attr:`std`, and implement
    :meth:`_download_dataset`.
    """

    train_tasks: Sequence[Dataset[Tuple[Tensor, Tensor]]]
    """A sequence of PyTorch datasets representing the training tasks.

    Use the :attr:`train_stream` instead. Unlike CapyMOA
    :class:`capymoa.stream.Stream` objects, :class:`torch.utils.data.Dataset`
    are not intended for OCL. This attribute is intended for evaluation and
    debugging.
    """

    test_tasks: Sequence[Dataset[Tuple[Tensor, Tensor]]]
    """A sequence of PyTorch datasets containing the test tasks."""

    train_stream: Stream[LabeledInstance]
    """A stream of labeled instances for training."""

    test_stream: Stream[LabeledInstance]
    """A stream of labeled instances for testing."""

    task_schedule: Sequence[Set[int]]
    """A sequence of sets containing the classes for each task.

    In online continual learning your learner may not have access to this
    attribute. It is provided for evaluation and debugging.
    """

    num_classes: int
    """The number of classes in the dataset."""

    default_task_count: int
    """The default number of tasks in the dataset."""

    mean: Sequence[float]
    """The mean of the features in the dataset used for normalization."""

    std: Sequence[float]
    """The standard deviation of the features in the dataset used for normalization."""

    default_train_transform: Callable[[Any], Tensor] = ToTensor()
    """The default transform to apply to the training dataset."""

    default_test_transform: Callable[[Any], Tensor] = ToTensor()
    """The default transform to apply to the test dataset."""

    schema: Schema
    """A schema describing the format of the data."""

    def __init__(
        self,
        num_tasks: Optional[int] = None,
        shuffle_tasks: bool = True,
        seed: int = 0,
        directory: Path = get_download_dir(),
        auto_download: bool = True,
        train_transform: Optional[Callable[[Any], Tensor]] = None,
        test_transform: Optional[Callable[[Any], Tensor]] = None,
        normalize_features: bool = False,
    ):
        """Create a new online continual learning datamodule.

        :param num_tasks: The number of tasks to partition the dataset into,
            defaults to :attr:`default_task_count`.
        :param shuffle_tasks: Should the contents and order of the tasks be
            shuffled, defaults to True.
        :param seed: Seed for shuffling the tasks, defaults to 0.
        :param directory: The directory to download the dataset to, defaults to
            :func:`capymoa.datasets.get_download_dir`.
        :param auto_download: Should the dataset be automatically downloaded
            if it does not exist, defaults to True.
        :param train_transform: A transform to apply to the training dataset,
            defaults to :attr:`default_train_transform`.
        :param test_transform: A transform to apply to the test dataset,
            defaults to :attr:`default_test_transform`.
        :param normalize_features: Should the features be normalized using
            :attr:`mean` and :attr:`std`. This normalization step is applied
            after all other transformations, to both the training and the test
            data.
        """
        # Sanity-check that the concrete subclass filled in its configuration.
        assert self.num_classes
        assert self.default_task_count
        assert self.mean
        assert self.std

        if num_tasks is None:
            num_tasks = self.default_task_count
        if train_transform is None:
            train_transform = self.default_train_transform
        if test_transform is None:
            test_transform = self.default_test_transform

        if normalize_features:
            normalize = Normalize(self.mean, self.std)
            train_transform = Compose([train_transform, normalize])
            # The test data must be scaled exactly like the training data,
            # otherwise a learner is evaluated on a different feature scale
            # than the one it was trained on.
            test_transform = Compose([test_transform, normalize])

        # Decide which classes belong to which task.
        generator = torch.Generator().manual_seed(seed)
        self.task_schedule = class_incremental_schedule(
            self.num_classes, num_tasks, shuffle=shuffle_tasks, generator=generator
        )

        # Download the dataset and partition it into tasks
        train_dataset = self._download_dataset(
            True, directory, auto_download, train_transform
        )
        test_dataset = self._download_dataset(
            False, directory, auto_download, test_transform
        )
        self.train_tasks = partition_by_schedule(train_dataset, self.task_schedule)
        self.test_tasks = partition_by_schedule(test_dataset, self.task_schedule)

        # Create streams for training and testing
        dataset_prefix = self.__class__.__name__
        self.train_stream = _tasks_to_stream(
            self.train_tasks,
            num_classes=self.num_classes,
            shuffle=True,
            # Offset so the stream shuffle does not reuse the schedule's seed.
            seed=seed + 1,
            dataset_name=f"{dataset_prefix}Train",
        )
        self.test_stream = _tasks_to_stream(
            self.test_tasks,
            num_classes=self.num_classes,
            shuffle=False,
            dataset_name=f"{dataset_prefix}Test",
        )
        self.schema = self.train_stream.get_schema()

    @classmethod
    @abstractmethod
    def _download_dataset(
        cls,
        train: bool,
        directory: Path,
        auto_download: bool,
        transform: Optional[Any],
    ) -> Dataset[Tuple[Tensor, Tensor]]:
        """Return the underlying PyTorch dataset for one split.

        :param train: True for the training split, False for the test split.
        :param directory: The directory the dataset is stored in.
        :param auto_download: Should the dataset be downloaded if it is missing.
        :param transform: The transform applied to each example.
        :return: The dataset that is later partitioned into tasks.
        """
206+
207+
208+
def _tasks_to_stream(
    tasks: Sequence[Dataset[Tuple[Tensor, Tensor]]],
    num_classes: int,
    shuffle: bool = False,
    seed: int = 0,
    class_names: Optional[Sequence[str]] = None,
    dataset_name: str = "OnlineContinualLearningDatastream",
) -> Stream[LabeledInstance]:
    """Wrap every task in its own stream and chain them into a single stream.

    :param tasks: A sequence of PyTorch datasets, one per task, in the order
        they should appear in the resulting stream.
    :param num_classes: The number of classes in the dataset.
    :param shuffle: Forwarded to each per-task stream as ``shuffle``
        (presumably shuffles the instances inside a task, not the task
        order -- confirm against :class:`TorchClassifyStream`), defaults to
        False.
    :param seed: Forwarded to each per-task stream as ``shuffle_seed``,
        defaults to 0.
    :param class_names: The names of the classes, defaults to None.
    :param dataset_name: The name of the dataset, defaults to
        "OnlineContinualLearningDatastream".
    :return: A stream of labeled instances for classification.
    """
    per_task_streams = []
    for task_dataset in tasks:
        # Every task shares the same configuration; only the data differs.
        task_stream = TorchClassifyStream(
            task_dataset,
            num_classes=num_classes,
            shuffle=shuffle,
            shuffle_seed=seed,
            class_names=class_names,
            dataset_name=dataset_name,
        )
        per_task_streams.append(task_stream)
    return ConcatStream(per_task_streams)
239+
240+
241+
class SplitMNIST(_BuiltInCIScenario):
    """Split MNIST dataset for online class incremental learning.

    **References:**

    #. LeCun, Y., Cortes, C., & Burges, C. (2010). MNIST handwritten digit
       database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist
    """

    num_classes = 10
    default_task_count = 5
    # Single-channel statistics used when feature normalization is enabled.
    mean = [0.1307]
    std = [0.3081]

    @classmethod
    def _download_dataset(
        cls,
        train: bool,
        directory: Path,
        auto_download: bool,
        transform: Optional[Any],
    ) -> Dataset[Tuple[Tensor, Tensor]]:
        """Return the torchvision MNIST dataset for the requested split.

        :param train: True for the training split, False for the test split.
        :param directory: The directory the dataset is stored in.
        :param auto_download: Should the dataset be downloaded if it is missing.
        :param transform: The transform applied to each image.
        :return: The MNIST dataset.
        """
        return datasets.MNIST(
            directory,
            train=train,
            download=auto_download,
            transform=transform,
        )
269+
270+
271+
class SplitFashionMNIST(_BuiltInCIScenario):
    """Split Fashion MNIST dataset for online class incremental learning.

    **References:**

    #. Xiao, H., Rasul, K., & Vollgraf, R. (2017, August 28). Fashion-MNIST:
       a Novel Image Dataset for Benchmarking Machine Learning Algorithms.
    """

    num_classes = 10
    default_task_count = 5
    # Single-channel statistics used when feature normalization is enabled.
    mean = [0.286]
    std = [0.353]

    @classmethod
    def _download_dataset(
        cls,
        train: bool,
        directory: Path,
        auto_download: bool,
        transform: Optional[Any],
    ) -> Dataset[Tuple[Tensor, Tensor]]:
        """Return the torchvision Fashion-MNIST dataset for the requested split.

        :param train: True for the training split, False for the test split.
        :param directory: The directory the dataset is stored in.
        :param auto_download: Should the dataset be downloaded if it is missing.
        :param transform: The transform applied to each image.
        :return: The Fashion-MNIST dataset.
        """
        return datasets.FashionMNIST(
            directory,
            train=train,
            download=auto_download,
            transform=transform,
        )
299+
300+
301+
class SplitCIFAR10(_BuiltInCIScenario):
    """Split CIFAR-10 dataset for online class incremental learning.

    **References:**

    #. Krizhevsky, A. (2009). Learning Multiple Layers of Features from Tiny
       Images.
    """

    num_classes = 10
    default_task_count = 5
    # One value per colour channel, used when feature normalization is enabled.
    mean = [0.491, 0.482, 0.447]
    std = [0.247, 0.243, 0.262]

    @classmethod
    def _download_dataset(
        cls,
        train: bool,
        directory: Path,
        auto_download: bool,
        transform: Optional[Any],
    ) -> Dataset[Tuple[Tensor, Tensor]]:
        """Return the torchvision CIFAR-10 dataset for the requested split.

        :param train: True for the training split, False for the test split.
        :param directory: The directory the dataset is stored in.
        :param auto_download: Should the dataset be downloaded if it is missing.
        :param transform: The transform applied to each image.
        :return: The CIFAR-10 dataset.
        """
        return datasets.CIFAR10(
            directory,
            train=train,
            download=auto_download,
            transform=transform,
        )
329+
330+
331+
class SplitCIFAR100(_BuiltInCIScenario):
    """Split CIFAR-100 dataset for online class incremental learning.

    **References:**

    #. Krizhevsky, A. (2009). Learning Multiple Layers of Features from Tiny
       Images.
    """

    num_classes = 100
    default_task_count = 10
    # One value per colour channel, used when feature normalization is enabled.
    mean = [0.507, 0.487, 0.441]
    std = [0.267, 0.256, 0.276]

    @classmethod
    def _download_dataset(
        cls,
        train: bool,
        directory: Path,
        auto_download: bool,
        transform: Optional[Any],
    ) -> Dataset[Tuple[Tensor, Tensor]]:
        """Return the torchvision CIFAR-100 dataset for the requested split.

        :param train: True for the training split, False for the test split.
        :param directory: The directory the dataset is stored in.
        :param auto_download: Should the dataset be downloaded if it is missing.
        :param transform: The transform applied to each image.
        :return: The CIFAR-100 dataset.
        """
        return datasets.CIFAR100(
            directory,
            train=train,
            download=auto_download,
            transform=transform,
        )

src/capymoa/evaluation/evaluation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ def _is_fast_mode_compilable(stream: Stream, learner, optimise=True) -> bool:
5050
return False
5151

5252
"""Check if the stream is compatible with the efficient loops in MOA."""
53-
is_moa_stream = stream.moa_stream is not None and isinstance(
54-
stream.moa_stream, InstanceStream
55-
)
53+
is_moa_stream = isinstance(stream.get_moa_stream(), InstanceStream)
5654
is_moa_learner = hasattr(learner, "moa_learner") and learner.moa_learner is not None
5755

5856
return is_moa_stream and is_moa_learner and optimise

0 commit comments

Comments
 (0)