1
1
import numpy as np
2
2
import tensorflow as tf
3
3
4
+ from NN .Basic .Optimizers import OptFactory
4
5
from NN .TF .Optimizers import OptFactory as TFOptFac
5
6
6
7
from Util .Timing import Timing
@@ -21,20 +22,29 @@ class LinearSVM(ClassifierBase):
21
22
def __init__ (self , ** kwargs ):
22
23
super (LinearSVM , self ).__init__ (** kwargs )
23
24
self ._w = self ._b = None
25
+ self ._optimizer = self ._model_parameters = None
24
26
25
27
self ._params ["c" ] = kwargs .get ("c" , 1 )
26
- self ._params ["lr" ] = kwargs .get ("lr" , 0.001 )
28
+ self ._params ["lr" ] = kwargs .get ("lr" , 0.01 )
29
+ self ._params ["optimizer" ] = kwargs .get ("optimizer" , "Adam" )
30
+ self ._params ["batch_size" ] = kwargs .get ("batch_size" , 128 )
27
31
self ._params ["epoch" ] = kwargs .get ("epoch" , 10 ** 4 )
28
- self ._params ["tol" ] = kwargs .get ("tol" , 1e-3 )
32
+ self ._params ["tol" ] = kwargs .get ("tol" , 1e-6 )
29
33
30
34
@LinearSVMTiming .timeit (level = 1 , prefix = "[API] " )
31
- def fit (self , x , y , sample_weight = None , c = None , lr = None , epoch = None , tol = None , animation_params = None ):
35
+ def fit (self , x , y , sample_weight = None , c = None , lr = None , optimizer = None ,
36
+ batch_size = None , epoch = None , tol = None , animation_params = None ):
32
37
if sample_weight is None :
33
38
sample_weight = self ._params ["sample_weight" ]
34
39
if c is None :
35
40
c = self ._params ["c" ]
36
41
if lr is None :
37
42
lr = self ._params ["lr" ]
43
+ if optimizer is None :
44
+ optimizer = self ._params ["optimizer" ]
45
+ if batch_size is None :
46
+ batch_size = self ._params ["batch_size" ]
47
+ batch_size = min (len (x ), batch_size )
38
48
if epoch is None :
39
49
epoch = self ._params ["epoch" ]
40
50
if tol is None :
@@ -47,30 +57,56 @@ def fit(self, x, y, sample_weight=None, c=None, lr=None, epoch=None, tol=None, a
47
57
sample_weight = np .asarray (sample_weight ) * len (y )
48
58
49
59
self ._w = np .zeros (x .shape [1 ])
50
- self ._b = 0
60
+ self ._b = np .zeros (1 )
61
+ self ._model_parameters = (self ._w , self ._b )
62
+ self ._optimizer = OptFactory ().get_optimizer_by_name (
63
+ optimizer , self ._model_parameters , lr , epoch
64
+ )
51
65
ims = []
66
+
67
+ train_repeat = self ._get_train_repeat (x , batch_size )
68
+ args = (c , lr , sample_weight , tol )
69
+
52
70
bar = ProgressBar (max_value = epoch , name = "LinearSVM" )
53
71
for i in range (epoch ):
54
- err = (1 - self .predict (x , get_raw_results = True ) * y ) * sample_weight
55
- indices = np .random .permutation (len (y ))
56
- idx = indices [np .argmax (err [indices ])]
57
- if err [idx ] <= tol :
72
+ if c * self .batch_training (
73
+ x , y , batch_size , train_repeat , * args
74
+ ) + np .linalg .norm (self ._w ) <= tol :
58
75
bar .terminate ()
59
76
break
60
- delta = lr * c * y [idx ] * sample_weight [idx ]
61
- self ._w *= 1 - lr
62
- self ._w += delta * x [idx ]
63
- self ._b += delta
64
77
self ._handle_animation (i , x , y , ims , animation_params , * animation_properties )
65
78
bar .update ()
66
79
self ._handle_mp4 (ims , animation_properties )
67
80
81
+ @LinearSVMTiming .timeit (level = 2 , prefix = "[Core] " )
82
+ def batch_training (self , x , y , batch_size , train_repeat , * args ):
83
+ c , lr , sample_weight , tol = args
84
+ epoch_loss = 0.
85
+ for _ in range (train_repeat ):
86
+ self ._w *= 1 - lr
87
+ if train_repeat != 1 :
88
+ batch = np .random .choice (len (x ), batch_size )
89
+ x_batch , y_batch , sample_weight_batch = x [batch ], y [batch ], sample_weight [batch ]
90
+ else :
91
+ x_batch , y_batch , sample_weight_batch = x , y , sample_weight
92
+ err = (1 - self .predict (x_batch , True ) * y_batch ) * sample_weight_batch
93
+ mask = err > 0
94
+ if not np .any (mask ):
95
+ continue
96
+ epoch_loss += np .max (err )
97
+ delta = lr * c * y_batch [mask ] * sample_weight_batch [mask ]
98
+ dw = np .mean (delta [..., None ] * x_batch [mask ], axis = 0 )
99
+ db = np .mean (delta )
100
+ self ._w += self ._optimizer .run (0 , dw )
101
+ self ._b += self ._optimizer .run (1 , db )
102
+ return epoch_loss
103
+
68
104
@LinearSVMTiming .timeit (level = 1 , prefix = "[API] " )
69
105
def predict (self , x , get_raw_results = False , ** kwargs ):
70
106
rs = np .sum (self ._w * x , axis = 1 ) + self ._b
71
- if not get_raw_results :
72
- return np . sign ( rs )
73
- return rs
107
+ if get_raw_results :
108
+ return rs
109
+ return np . sign ( rs )
74
110
75
111
76
112
class TFLinearSVM (TFClassifierBase ):
@@ -81,7 +117,7 @@ def __init__(self, **kwargs):
81
117
self ._w = self ._b = None
82
118
83
119
self ._params ["c" ] = kwargs .get ("c" , 1 )
84
- self ._params ["lr" ] = kwargs .get ("lr" , 0.001 )
120
+ self ._params ["lr" ] = kwargs .get ("lr" , 0.01 )
85
121
self ._params ["batch_size" ] = kwargs .get ("batch_size" , 128 )
86
122
self ._params ["epoch" ] = kwargs .get ("epoch" , 10 ** 4 )
87
123
self ._params ["tol" ] = kwargs .get ("tol" , 1e-3 )
0 commit comments