Skip to content

Commit 710a977

Browse files
committed
Update _Dist/NeuralNetworks
1 parent 15bda06 commit 710a977

File tree

2 files changed

+3
-17
lines changed

2 files changed

+3
-17
lines changed

_Dist/NeuralNetworks/Base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def transform_arr(arr):
711711
return self
712712

713713
def visualize2d(self, x, y, padding=0.1, dense=200, title=None,
714-
show_org=False, draw_background=True, emphasize=None, extra=None):
714+
scatter=True, show_org=False, draw_background=True, emphasize=None, extra=None):
715715
axis, labels = np.asarray(x).T, np.asarray(y)
716716

717717
print("=" * 30 + "\n" + str(self))
@@ -764,7 +764,8 @@ def get_base(_nx, _ny):
764764
plt.pcolormesh(xy_xf, xy_yf, z, cmap=plt.cm.Pastel1)
765765
else:
766766
plt.contour(xf, yf, z, c='k-', levels=[0])
767-
plt.scatter(axis[0], axis[1], c=colors)
767+
if scatter:
768+
plt.scatter(axis[0], axis[1], c=colors)
768769
if emphasize is not None:
769770
indices = np.array([False] * len(axis[0]))
770771
indices[np.asarray(emphasize)] = True

_Dist/NeuralNetworks/d_Traditional2NN/Toolbox.py

-15
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,3 @@ def _transform(self):
172172
w2 *= max_route_length
173173
self._transform_ws = [w1, w2, w3]
174174
self._transform_bs = [b]
175-
176-
177-
if __name__ == '__main__':
178-
from Util.Util import DataUtil
179-
180-
centers = (1, 1)
181-
slopes = (0.5, -2)
182-
x, y = DataUtil.gen_x_set(1000, centers, slopes, one_hot=False)
183-
x_test, y_test = DataUtil.gen_x_set(100, centers, slopes, one_hot=False)
184-
185-
(x, y), (x_test, y_test), *_ = DataUtil.get_dataset("cifar10", "../../../_Data/cifar10.txt", 400, quantized=True)
186-
187-
DT2NN().fit(x, y, x_test, y_test).scatter2d(
188-
x, y, padding=0.1, title="X set"
189-
).visualize2d(x, y).visualize2d(x_test, y_test, padding=1)

0 commit comments

Comments
 (0)