
Commit f966e32

feat: add shrub ensembles for online classification
Buschjäger, S., Hess, S., & Morik, K. J. (2022, June). Shrub ensembles for online classification. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 36, No. 6, pp. 6123-6131).
1 parent ac7e578 commit f966e32

File tree

9 files changed: +593 −7 lines changed


docs/conf.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@
 ]
 
 nitpick_ignore_regex = [
+    ("py:class", r".*\._[\w_]*"),  # Ignore private classes from nitpick errors
     ("py:class", r"abc\..*"),
     ("py:class", r"com\..*"),
     ("py:class", r"java\..*"),

src/capymoa/classifier/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,7 @@
 from ._dynamic_weighted_majority import DynamicWeightedMajority
 from ._csmote import CSMOTE
 from ._weightedknn import WeightedkNN
+from ._shrubs_classifier import ShrubsClassifier
 
 __all__ = [
     "AdaptiveRandomForestClassifier",
@@ -41,5 +42,6 @@
     "SAMkNN",
     "DynamicWeightedMajority",
     "CSMOTE",
-    "WeightedkNN"
+    "WeightedkNN",
+    "ShrubsClassifier",
 ]
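
With the export above in place, the new classifier can be imported directly from capymoa.classifier. A minimal smoke test (a sketch, assuming a CapyMOA installation that includes this commit):

from capymoa.classifier import ShrubsClassifier
from capymoa.datasets import ElectricityTiny

stream = ElectricityTiny()
learner = ShrubsClassifier(stream.get_schema())
print(learner)  # prints "ShrubsClassifier"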
src/capymoa/classifier/_shrubs_classifier.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
from __future__ import annotations
from typing import Literal

import numpy as np

from capymoa.base import Classifier
from capymoa.stream._stream import Schema
from capymoa.classifier._shrubs_ensemble import _ShrubEnsembles
from sklearn.tree import DecisionTreeClassifier


class ShrubsClassifier(_ShrubEnsembles, Classifier):
    """ShrubsClassifier

    This class implements the ShrubEnsembles algorithm for classification: an
    ensemble classifier that continuously adds decision trees trained over a
    sliding window, while pruning unnecessary trees away via proximal
    (stochastic) gradient descent, which allows the ensemble to adapt to
    concept drift.

    Reference:

    `Shrub Ensembles for Online Classification
    Sebastian Buschjäger, Sibylle Hess, and Katharina Morik
    In Proceedings of the Thirty-Sixth AAAI Conference on Artificial Intelligence (AAAI-22), Jan 2022.
    <https://aaai.org/papers/06123-shrub-ensembles-for-online-classification/>`_

    Example usage:

    >>> from capymoa.datasets import ElectricityTiny
    >>> from capymoa.classifier import ShrubsClassifier
    >>> from capymoa.evaluation import prequential_evaluation
    >>> stream = ElectricityTiny()
    >>> schema = stream.get_schema()
    >>> learner = ShrubsClassifier(schema)
    >>> results = prequential_evaluation(stream, learner, max_instances=1000)
    >>> results["cumulative"].accuracy()
    85.5...
    """

    def __init__(
        self,
        schema: Schema,
        loss: Literal["mse", "ce", "h2"] = "ce",
        step_size: float | Literal["adaptive"] = "adaptive",
        ensemble_regularizer: Literal["hard-L0", "L0", "L1", "none"] = "hard-L0",
        l_ensemble_reg: float | int = 32,
        l_l2_reg: float = 0,
        l_tree_reg: float = 0,
        normalize_weights: bool = True,
        burnin_steps: int = 5,
        update_leaves: bool = False,
        batch_size: int = 32,
        sk_dt: DecisionTreeClassifier = DecisionTreeClassifier(
            splitter="best", criterion="gini", max_depth=None, random_state=1234
        ),
    ):
        """Initializes the ShrubEnsembles classifier with the given parameters.

        :param loss: The loss function to be used. Supported values are
            ``"mse"``, ``"ce"``, and ``"h2"``.
        :param step_size: The step size (i.e. the learning rate of SGD) for
            updating the model. Can be a float or ``"adaptive"``. Adaptive
            reduces the step size as more estimators are added, i.e. it sets
            the step size to ``1.0 / (n_estimators + 1.0)``.
        :param ensemble_regularizer: The regularizer for the weights of the
            ensemble. Supported values are:

            * ``hard-L0``: L0 regularization via the prox-operator.
            * ``L0``: L0 regularization via projection.
            * ``L1``: L1 regularization via projection.
            * ``none``: No regularization.

            Projection can be viewed as a softer regularization that drives the
            weights of each member towards 0, whereas ``hard-L0`` limits the
            number of trees in the entire ensemble.
        :param l_ensemble_reg: The regularization strength. Depending on the
            value of ``ensemble_regularizer``, this parameter has different
            meanings:

            * ``hard-L0``: this parameter represents the total number of trees
              in the ensemble.
            * ``L0`` or ``L1``: this parameter is the regularization strength.
              In these cases the number of trees grows over time, and only
              trees that do not contribute to the ensemble are removed.
            * ``none``: this parameter is ignored.
        :param l_l2_reg: The L2 regularization strength of the weights of each
            tree.
        :param l_tree_reg: The regularization parameter for individual trees.
            Must be greater than or equal to 0. ``l_tree_reg`` controls the
            number of (overly) large trees in the ensemble by penalizing the
            weights of each tree. Formally, the number of nodes of each tree is
            used as an additional regularizer.
        :param normalize_weights: Whether to normalize the weights of the
            ensemble, i.e. make the weights sum to 1.
        :param burnin_steps: The number of burn-in steps before updating the
            model, i.e. the number of SGD steps taken per call of train.
        :param update_leaves: Whether to also update the leaves of the trees
            using SGD.
        :param batch_size: The batch size for training each individual tree.
            Internally, a sliding window is stored. Must be greater than or
            equal to 1.
        :param sk_dt: Base object from which any new decision tree is cloned.
            Note that if you set ``random_state`` to an integer, the exact same
            clone is used for every DT object.
        """
        Classifier.__init__(self, schema, sk_dt.random_state)
        _ShrubEnsembles.__init__(
            self,
            schema,
            loss,
            step_size,
            ensemble_regularizer,
            l_ensemble_reg,
            l_l2_reg,
            l_tree_reg,
            normalize_weights,
            burnin_steps,
            update_leaves,
            batch_size,
            sk_dt,
        )

    def __str__(self):
        return "ShrubsClassifier"

    def _individual_proba(self, X):
        # Compute the per-tree class probabilities, one row per sample.
        if len(X.shape) < 2:
            all_proba = np.zeros(
                shape=(len(self.estimators_), 1, self.n_classes_), dtype=np.float32
            )
        else:
            all_proba = np.zeros(
                shape=(len(self.estimators_), X.shape[0], self.n_classes_),
                dtype=np.float32,
            )

        for i, e in enumerate(self.estimators_):
            if len(X.shape) < 2:
                # A single sample occupies row 0 of the (1, n_classes_) slice.
                all_proba[i, 0, e.classes_.astype(int)] += e.predict_proba(
                    X[np.newaxis, :]
                )
            else:
                proba = e.predict_proba(X)
                # Numpy does some unexpected things with advanced indexing here.
                # Due to e.classes_.astype(int), the last and second-to-last
                # dimensions of all_proba are swapped when doing
                # all_proba[i, :, e.classes_.astype(int)], so proba would also
                # have to be transposed to match. Alternatively, the simpler
                # form of indexing below can be used. Both work fine:
                # all_proba[i, :, e.classes_.astype(int)] += proba.T
                all_proba[i, :, :][:, e.classes_.astype(int)] += proba

        return all_proba

    def predict_proba(self, instance):
        if len(self.estimators_) == 0:
            # No trees trained yet: fall back to a uniform distribution.
            return 1.0 / self.n_classes_ * np.ones(self.n_classes_)
        else:
            all_proba = self._individual_proba(np.array([instance.x]))
            scaled_prob = sum(
                w * p for w, p in zip(self.estimator_weights_, all_proba)
            )
            combined_proba = np.sum(scaled_prob, axis=0)
            return combined_proba

    def predict(self, instance):
        # Return the index of the class with the highest probability.
        return self.predict_proba(instance).argmax(axis=0)
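
The hard-L0 prox-operator described in the docstring is what keeps the ensemble small: after each gradient step on the ensemble weights, only the l_ensemble_reg largest weights survive, and trees whose weight drops to zero can be removed. The following is a minimal illustrative sketch of that idea, not the code from this commit; the function name prox_hard_l0 is hypothetical:

import numpy as np

def prox_hard_l0(weights: np.ndarray, k: int) -> np.ndarray:
    # Keep only the k largest weights (by magnitude); zero out the rest.
    if weights.size <= k:
        return weights
    pruned = np.zeros_like(weights)
    top_k = np.argsort(np.abs(weights))[-k:]  # indices of the k largest weights
    pruned[top_k] = weights[top_k]
    return pruned

# Example: after an SGD step, prune the ensemble down to at most 3 trees.
w = np.array([0.05, 0.40, 0.01, 0.30, 0.15, 0.09])
print(prox_hard_l0(w, k=3))  # [0.   0.4  0.   0.3  0.15 0.  ]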

0 commit comments
