@@ -319,7 +319,6 @@ def evaluate(self, x, y, metrics=None, tar=0, prefix="Acc", **kwargs):
319
319
def get_2d_plot (self , x , y , padding = 1 , dense = 200 , title = None ,
320
320
draw_background = False , emphasize = None , extra = None , ** kwargs ):
321
321
axis , labels = np .asarray (x ).T , np .asarray (y )
322
- decision_function = lambda _xx : self .predict (_xx )
323
322
nx , ny , padding = dense , dense , padding
324
323
x_min , x_max = np .min (axis [0 ]), np .max (axis [0 ]) # type: float
325
324
y_min , y_max = np .min (axis [1 ]), np .max (axis [1 ]) # type: float
@@ -337,7 +336,7 @@ def get_base(_nx, _ny):
337
336
return _xf , _yf , np .c_ [n_xf .ravel (), n_yf .ravel ()]
338
337
339
338
xf , yf , base_matrix = get_base (nx , ny )
340
- z = decision_function (base_matrix ).reshape ((nx , ny ))
339
+ z = self . predict (base_matrix ).reshape ((nx , ny ))
341
340
342
341
if labels .ndim == 1 :
343
342
if not self ._plot_label_dic :
@@ -382,8 +381,6 @@ def visualize2d(self, x, y, padding=0.1, dense=200, title=None,
382
381
axis , labels = np .asarray (x ).T , np .asarray (y )
383
382
384
383
print ("=" * 30 + "\n " + str (self ))
385
- decision_function = lambda xx : self .predict (xx , ** kwargs )
386
-
387
384
nx , ny , padding = dense , dense , padding
388
385
x_min , x_max = np .min (axis [0 ]), np .max (axis [0 ])
389
386
y_min , y_max = np .min (axis [1 ]), np .max (axis [1 ])
@@ -403,7 +400,7 @@ def get_base(_nx, _ny):
403
400
xf , yf , base_matrix = get_base (nx , ny )
404
401
405
402
t = time .time ()
406
- z = decision_function (base_matrix ).reshape ((nx , ny ))
403
+ z = self . predict (base_matrix , ** kwargs ).reshape ((nx , ny ))
407
404
print ("Decision Time: {:8.6} s" .format (time .time () - t ))
408
405
409
406
print ("Drawing figures..." )
@@ -456,7 +453,9 @@ def visualize3d(self, x, y, padding=0.1, dense=100, title=None,
456
453
axis , labels = np .asarray (x ).T , np .asarray (y )
457
454
458
455
print ("=" * 30 + "\n " + str (self ))
459
- decision_function = lambda xx : self .predict (xx , ** kwargs )
456
+
457
+ def decision_function (xx ):
458
+ return self .predict (xx , ** kwargs )
460
459
461
460
nx , ny , nz , padding = dense , dense , dense , padding
462
461
x_min , x_max = np .min (axis [0 ]), np .max (axis [0 ])
@@ -728,6 +727,9 @@ def batch_training(self, x, y, batch_size, train_repeat, *args):
728
727
return float ((epoch_cost / train_repeat ).data .numpy ()[0 ])
729
728
730
729
def _predict (self , x , get_raw_results = False , ** kwargs ):
730
+ """
731
+ :rtype: np.ndarray
732
+ """
731
733
pass
732
734
733
735
def predict (self , x , get_raw_results = False , ** kwargs ):
0 commit comments