
Commit 01aaa83

Update SVM
1 parent a2814f1 commit 01aaa83

3 files changed: +52 −16, +3 −3, +2 −2 (total +57 −21 lines)

e_SVM/LinearSVM.py (+52 −16)
@@ -1,6 +1,7 @@
 import numpy as np
 import tensorflow as tf

+from NN.Basic.Optimizers import OptFactory
 from NN.TF.Optimizers import OptFactory as TFOptFac

 from Util.Timing import Timing
@@ -21,20 +22,29 @@ class LinearSVM(ClassifierBase):
     def __init__(self, **kwargs):
         super(LinearSVM, self).__init__(**kwargs)
         self._w = self._b = None
+        self._optimizer = self._model_parameters = None

         self._params["c"] = kwargs.get("c", 1)
-        self._params["lr"] = kwargs.get("lr", 0.001)
+        self._params["lr"] = kwargs.get("lr", 0.01)
+        self._params["optimizer"] = kwargs.get("optimizer", "Adam")
+        self._params["batch_size"] = kwargs.get("batch_size", 128)
         self._params["epoch"] = kwargs.get("epoch", 10 ** 4)
-        self._params["tol"] = kwargs.get("tol", 1e-3)
+        self._params["tol"] = kwargs.get("tol", 1e-6)

     @LinearSVMTiming.timeit(level=1, prefix="[API] ")
-    def fit(self, x, y, sample_weight=None, c=None, lr=None, epoch=None, tol=None, animation_params=None):
+    def fit(self, x, y, sample_weight=None, c=None, lr=None, optimizer=None,
+            batch_size=None, epoch=None, tol=None, animation_params=None):
         if sample_weight is None:
             sample_weight = self._params["sample_weight"]
         if c is None:
             c = self._params["c"]
         if lr is None:
             lr = self._params["lr"]
+        if optimizer is None:
+            optimizer = self._params["optimizer"]
+        if batch_size is None:
+            batch_size = self._params["batch_size"]
+        batch_size = min(len(x), batch_size)
         if epoch is None:
             epoch = self._params["epoch"]
         if tol is None:
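
The two new keywords fall back to the values stored in `self._params` when omitted. A minimal usage sketch of the updated `fit` signature on synthetic two-cluster data (the `e_SVM.LinearSVM` import path and the toy data are assumptions, not part of the commit):

```python
import numpy as np
from e_SVM.LinearSVM import LinearSVM  # import path assumed from the repo layout

# synthetic, linearly separable two-cluster data with labels in {-1, +1}
x = np.vstack([np.random.randn(100, 2) + 2, np.random.randn(100, 2) - 2])
y = np.array([1] * 100 + [-1] * 100)

svm = LinearSVM()
# optimizer / batch_size / lr / tol are the hyper-parameters touched by this commit
svm.fit(x, y, optimizer="Adam", batch_size=64, epoch=10 ** 3, tol=1e-6)
print(svm.predict(x[:5]))  # hard predictions in {-1, +1}
```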
@@ -47,30 +57,56 @@ def fit(self, x, y, sample_weight=None, c=None, lr=None, epoch=None, tol=None, animation_params=None):
             sample_weight = np.asarray(sample_weight) * len(y)

         self._w = np.zeros(x.shape[1])
-        self._b = 0
+        self._b = np.zeros(1)
+        self._model_parameters = (self._w, self._b)
+        self._optimizer = OptFactory().get_optimizer_by_name(
+            optimizer, self._model_parameters, lr, epoch
+        )
         ims = []
+
+        train_repeat = self._get_train_repeat(x, batch_size)
+        args = (c, lr, sample_weight, tol)
+
         bar = ProgressBar(max_value=epoch, name="LinearSVM")
         for i in range(epoch):
-            err = (1 - self.predict(x, get_raw_results=True) * y) * sample_weight
-            indices = np.random.permutation(len(y))
-            idx = indices[np.argmax(err[indices])]
-            if err[idx] <= tol:
+            if c * self.batch_training(
+                x, y, batch_size, train_repeat, *args
+            ) + np.linalg.norm(self._w) <= tol:
                 bar.terminate()
                 break
-            delta = lr * c * y[idx] * sample_weight[idx]
-            self._w *= 1 - lr
-            self._w += delta * x[idx]
-            self._b += delta
             self._handle_animation(i, x, y, ims, animation_params, *animation_properties)
             bar.update()
         self._handle_mp4(ims, animation_properties)

+    @LinearSVMTiming.timeit(level=2, prefix="[Core] ")
+    def batch_training(self, x, y, batch_size, train_repeat, *args):
+        c, lr, sample_weight, tol = args
+        epoch_loss = 0.
+        for _ in range(train_repeat):
+            self._w *= 1 - lr
+            if train_repeat != 1:
+                batch = np.random.choice(len(x), batch_size)
+                x_batch, y_batch, sample_weight_batch = x[batch], y[batch], sample_weight[batch]
+            else:
+                x_batch, y_batch, sample_weight_batch = x, y, sample_weight
+            err = (1 - self.predict(x_batch, True) * y_batch) * sample_weight_batch
+            mask = err > 0
+            if not np.any(mask):
+                continue
+            epoch_loss += np.max(err)
+            delta = lr * c * y_batch[mask] * sample_weight_batch[mask]
+            dw = np.mean(delta[..., None] * x_batch[mask], axis=0)
+            db = np.mean(delta)
+            self._w += self._optimizer.run(0, dw)
+            self._b += self._optimizer.run(1, db)
+        return epoch_loss
+
     @LinearSVMTiming.timeit(level=1, prefix="[API] ")
     def predict(self, x, get_raw_results=False, **kwargs):
         rs = np.sum(self._w * x, axis=1) + self._b
-        if not get_raw_results:
-            return np.sign(rs)
-        return rs
+        if get_raw_results:
+            return rs
+        return np.sign(rs)


 class TFLinearSVM(TFClassifierBase):
@@ -81,7 +117,7 @@ def __init__(self, **kwargs):
         self._w = self._b = None

         self._params["c"] = kwargs.get("c", 1)
-        self._params["lr"] = kwargs.get("lr", 0.001)
+        self._params["lr"] = kwargs.get("lr", 0.01)
         self._params["batch_size"] = kwargs.get("batch_size", 128)
         self._params["epoch"] = kwargs.get("epoch", 10 ** 4)
         self._params["tol"] = kwargs.get("tol", 1e-3)
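
The core of the new `batch_training` method is a weighted hinge-loss subgradient step: `self._w *= 1 - lr` applies the L2 shrinkage, and only margin violators (`err > 0`) contribute to `dw` and `db`, which are then routed through the optimizer built by `OptFactory`. A standalone sketch of one such step, with a plain gradient step standing in for that optimizer (the `linear_svm_step` helper is hypothetical, for illustration only):

```python
import numpy as np

def linear_svm_step(w, b, x_batch, y_batch, sample_weight, c=1.0, lr=0.01):
    """One hinge-loss update in the spirit of LinearSVM.batch_training.

    The commit feeds dw/db through an OptFactory optimizer (e.g. Adam);
    here the raw lr-scaled deltas are applied directly instead.
    """
    w = w * (1 - lr)  # shrink w: gradient of the L2 regularization term
    err = (1 - (x_batch.dot(w) + b) * y_batch) * sample_weight
    mask = err > 0    # only samples violating the margin get a gradient
    if not np.any(mask):
        return w, b, 0.0
    delta = lr * c * y_batch[mask] * sample_weight[mask]
    w = w + np.mean(delta[:, None] * x_batch[mask], axis=0)
    b = b + np.mean(delta)
    return w, b, float(np.max(err))  # max violation is accumulated into epoch_loss
```

In `fit`, training now stops once `c * epoch_loss + np.linalg.norm(self._w)` drops below `tol`, replacing the old per-sample `err[idx] <= tol` test.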

e_SVM/Perceptron.py (+3 −3)
@@ -36,9 +36,9 @@ def fit(self, x, y, sample_weight=None, lr=None, epoch=None, animation_params=None):
         ims = []
         bar = ProgressBar(max_value=epoch, name="Perceptron")
         for i in range(epoch):
-            y_pred = self.predict(x, True)
-            idx = np.argmax(np.maximum(0, -y_pred * y) * sample_weight)
-            if y_pred[idx] * y[idx] > 0:
+            err = -y * self.predict(x, True) * sample_weight
+            idx = np.argmax(err)
+            if err[idx] < 0:
                 bar.terminate()
                 break
             delta = lr * y[idx] * sample_weight[idx]
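
The rewritten Perceptron stopping test folds the sign check into one error vector: `err_i = -y_i * f(x_i) * sample_weight_i` is positive exactly for misclassified samples, so a negative maximum means every point is already on the correct side. A small sketch of that check (the `perceptron_converged` helper is hypothetical):

```python
import numpy as np

def perceptron_converged(w, b, x, y, sample_weight):
    """Mirrors the new stopping test in Perceptron.fit: err is positive only
    for misclassified points, so a negative maximum signals convergence."""
    err = -y * (x.dot(w) + b) * sample_weight
    idx = int(np.argmax(err))
    return err[idx] < 0, idx  # (converged?, index of the worst violator otherwise)
```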

e_SVM/README.md (+2 −2)
@@ -11,9 +11,9 @@ Implemented `Tensorflow` & `PyTorch` backend for `LinearSVM` & `SVM`
 ![Perceptron on Two Clusters](https://cdn.rawgit.com/carefree0910/Resources/d269faeb/Backgrounds/Perceptron.gif)

 ### LinearSVM
-![LinearSVM on Two Clusters](https://cdn.rawgit.com/carefree0910/Resources/d269faeb/Lines/LinearSVM.gif)
+![LinearSVM on Two Clusters](https://cdn.rawgit.com/carefree0910/Resources/83441596/Lines/LinearSVM.gif)

-![LinearSVM on Two Clusters](https://cdn.rawgit.com/carefree0910/Resources/d269faeb/Backgrounds/LinearSVM.gif)
+![LinearSVM on Two Clusters](https://cdn.rawgit.com/carefree0910/Resources/83441596/Backgrounds/LinearSVM.gif)

 ### TFLinearSVM
 ![TFLinearSVM on Two Clusters](https://cdn.rawgit.com/carefree0910/Resources/d269faeb/Lines/TFLinearSVM.gif)
