Skip to content

Commit e0d3052

Browse files
committed
Update e_SVM
1 parent ef362be commit e0d3052

File tree

4 files changed

+39
-0
lines changed

4 files changed

+39
-0
lines changed

e_SVM/KP.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _get_grads(self, x_batch, y_batch, y_pred, sample_weight_batch, *args):
5858
]
5959
return np.sum(err[mask])
6060

61+
6162
if __name__ == '__main__':
6263
# xs, ys = DataUtil.gen_two_clusters(center=5, dis=1, scale=2, one_hot=False)
6364
xs, ys = DataUtil.gen_spiral(20, 4, 2, 2, one_hot=False)

e_SVM/Perceptron.py

+36
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,39 @@ def predict(self, x, get_raw_results=False, **kwargs):
6060
if get_raw_results:
6161
return rs
6262
return np.sign(rs).astype(np.float32)
63+
64+
65+
class Perceptron2(Perceptron):
66+
def fit(self, x, y, sample_weight=None, lr=None, epoch=None, animation_params=None):
67+
if sample_weight is None:
68+
sample_weight = self._params["sample_weight"]
69+
if lr is None:
70+
lr = self._params["lr"]
71+
if epoch is None:
72+
epoch = self._params["epoch"]
73+
*animation_properties, animation_params = self._get_animation_params(animation_params)
74+
75+
x, y = np.atleast_2d(x), np.asarray(y)
76+
if sample_weight is None:
77+
sample_weight = np.ones(len(y))
78+
else:
79+
sample_weight = np.asarray(sample_weight) * len(y)
80+
81+
self._w = np.random.random(x.shape[1])
82+
self._b = 0.
83+
ims = []
84+
bar = ProgressBar(max_value=epoch, name="Perceptron")
85+
for i in range(epoch):
86+
y_pred = self.predict(x, True)
87+
err = -y * y_pred * sample_weight
88+
idx = np.argmax(err)
89+
if err[idx] < 0:
90+
bar.terminate()
91+
break
92+
w_norm = np.linalg.norm(self._w)
93+
delta = lr * y[idx] * sample_weight[idx] / w_norm
94+
self._w += delta * (x[idx] - y_pred[idx] * self._w / w_norm ** 2)
95+
self._b += delta
96+
self._handle_animation(i, x, y, ims, animation_params, *animation_properties)
97+
bar.update()
98+
self._handle_mp4(ims, animation_properties)

e_SVM/TestLinear.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@ def main():
4343

4444
perceptron.show_timing_log()
4545

46+
4647
if __name__ == '__main__':
4748
main()

e_SVM/TestSVM.py

+1
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,6 @@ def main():
110110

111111
svm.show_timing_log()
112112

113+
113114
if __name__ == '__main__':
114115
main()

0 commit comments

Comments
 (0)