Skip to content

Commit 5935ccd

Browse files
committed
Update bases
1 parent 62ae239 commit 5935ccd

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

Util/Bases.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,6 @@ def evaluate(self, x, y, metrics=None, tar=0, prefix="Acc", **kwargs):
319319
def get_2d_plot(self, x, y, padding=1, dense=200, title=None,
320320
draw_background=False, emphasize=None, extra=None, **kwargs):
321321
axis, labels = np.asarray(x).T, np.asarray(y)
322-
decision_function = lambda _xx: self.predict(_xx)
323322
nx, ny, padding = dense, dense, padding
324323
x_min, x_max = np.min(axis[0]), np.max(axis[0]) # type: float
325324
y_min, y_max = np.min(axis[1]), np.max(axis[1]) # type: float
@@ -337,7 +336,7 @@ def get_base(_nx, _ny):
337336
return _xf, _yf, np.c_[n_xf.ravel(), n_yf.ravel()]
338337

339338
xf, yf, base_matrix = get_base(nx, ny)
340-
z = decision_function(base_matrix).reshape((nx, ny))
339+
z = self.predict(base_matrix).reshape((nx, ny))
341340

342341
if labels.ndim == 1:
343342
if not self._plot_label_dic:
@@ -382,8 +381,6 @@ def visualize2d(self, x, y, padding=0.1, dense=200, title=None,
382381
axis, labels = np.asarray(x).T, np.asarray(y)
383382

384383
print("=" * 30 + "\n" + str(self))
385-
decision_function = lambda xx: self.predict(xx, **kwargs)
386-
387384
nx, ny, padding = dense, dense, padding
388385
x_min, x_max = np.min(axis[0]), np.max(axis[0])
389386
y_min, y_max = np.min(axis[1]), np.max(axis[1])
@@ -403,7 +400,7 @@ def get_base(_nx, _ny):
403400
xf, yf, base_matrix = get_base(nx, ny)
404401

405402
t = time.time()
406-
z = decision_function(base_matrix).reshape((nx, ny))
403+
z = self.predict(base_matrix, **kwargs).reshape((nx, ny))
407404
print("Decision Time: {:8.6} s".format(time.time() - t))
408405

409406
print("Drawing figures...")
@@ -456,7 +453,9 @@ def visualize3d(self, x, y, padding=0.1, dense=100, title=None,
456453
axis, labels = np.asarray(x).T, np.asarray(y)
457454

458455
print("=" * 30 + "\n" + str(self))
459-
decision_function = lambda xx: self.predict(xx, **kwargs)
456+
457+
def decision_function(xx):
458+
return self.predict(xx, **kwargs)
460459

461460
nx, ny, nz, padding = dense, dense, dense, padding
462461
x_min, x_max = np.min(axis[0]), np.max(axis[0])
@@ -728,6 +727,9 @@ def batch_training(self, x, y, batch_size, train_repeat, *args):
728727
return float((epoch_cost / train_repeat).data.numpy()[0])
729728

730729
def _predict(self, x, get_raw_results=False, **kwargs):
730+
"""
731+
:rtype: np.ndarray
732+
"""
731733
pass
732734

733735
def predict(self, x, get_raw_results=False, **kwargs):

0 commit comments

Comments
 (0)