21
21
svm = SVM (** copy .deepcopy (base_params ))
22
22
nn = Basic (** copy .deepcopy (base_params ))
23
23
linear_svm = LinearSVM (** copy .deepcopy (base_params ))
24
- train_set , test_set = DataUtil .gen_noisy_linear (1000 , 2 , 2 , one_hot = False )
24
+ train_set , cv_set , test_set = DataUtil .gen_special_linear (1000 , 2 , 2 , 2 , one_hot = False )
25
25
26
26
27
27
class TestSVM (unittest .TestCase ):
28
28
def test_00_train (self ):
29
29
self .assertIsInstance (
30
- svm .fit (* train_set , * test_set , verbose = 0 ), SVM ,
30
+ svm .fit (* train_set , * cv_set , verbose = 0 ), SVM ,
31
31
msg = "Train failed"
32
32
)
33
33
34
34
def test_01_predict (self ):
35
35
self .assertIs (svm .predict (train_set [0 ]).dtype , np .dtype ("float32" ), "Predict failed" )
36
- self .assertIs (svm .predict_classes (test_set [0 ]).dtype , np .dtype ("int32" ), "Predict classes failed" )
36
+ self .assertIs (svm .predict_classes (cv_set [0 ]).dtype , np .dtype ("int32" ), "Predict classes failed" )
37
37
38
38
def test_02_evaluate (self ):
39
- self .assertEqual (len (svm .evaluate (* train_set , * test_set )), 3 , "Evaluation failed" )
39
+ self .assertEqual (len (svm .evaluate (* train_set , * cv_set )), 3 , "Evaluation failed" )
40
40
41
41
def test_03_save (self ):
42
42
self .assertIsInstance (svm .save (), SVM , msg = "Save failed" )
@@ -48,14 +48,14 @@ def test_04_load(self):
48
48
49
49
def test_05_re_predict (self ):
50
50
self .assertIs (svm .predict (train_set [0 ]).dtype , np .dtype ("float32" ), "Re-Predict failed" )
51
- self .assertIs (svm .predict_classes (test_set [0 ]).dtype , np .dtype ("int32" ), "Re-Predict classes failed" )
51
+ self .assertIs (svm .predict_classes (cv_set [0 ]).dtype , np .dtype ("int32" ), "Re-Predict classes failed" )
52
52
53
53
def test_06_re_evaluate (self ):
54
- self .assertEqual (len (svm .evaluate (* train_set , * test_set )), 3 , "Re-Evaluation failed" )
54
+ self .assertEqual (len (svm .evaluate (* train_set , * cv_set )), 3 , "Re-Evaluation failed" )
55
55
56
56
def test_07_re_train (self ):
57
57
self .assertIsInstance (
58
- svm .fit (* train_set , * test_set , verbose = 0 ), SVM ,
58
+ svm .fit (* train_set , * cv_set , verbose = 0 ), SVM ,
59
59
msg = "Re-Train failed"
60
60
)
61
61
@@ -66,16 +66,16 @@ def test_99_clear_cache(self):
66
66
class TestBasicNN (unittest .TestCase ):
67
67
def test_00_train (self ):
68
68
self .assertIsInstance (
69
- nn .fit (* train_set , * test_set , verbose = 0 ), Basic ,
69
+ nn .fit (* train_set , * cv_set , verbose = 0 ), Basic ,
70
70
msg = "Train failed"
71
71
)
72
72
73
73
def test_01_predict (self ):
74
74
self .assertIs (nn .predict (train_set [0 ]).dtype , np .dtype ("float32" ), "Predict failed" )
75
- self .assertIs (nn .predict_classes (test_set [0 ]).dtype , np .dtype ("int32" ), "Predict classes failed" )
75
+ self .assertIs (nn .predict_classes (cv_set [0 ]).dtype , np .dtype ("int32" ), "Predict classes failed" )
76
76
77
77
def test_02_evaluate (self ):
78
- self .assertEqual (len (nn .evaluate (* train_set , * test_set )), 3 , "Evaluation failed" )
78
+ self .assertEqual (len (nn .evaluate (* train_set , * cv_set )), 3 , "Evaluation failed" )
79
79
80
80
def test_03_save (self ):
81
81
self .assertIsInstance (nn .save (), Basic , msg = "Save failed" )
@@ -87,14 +87,14 @@ def test_04_load(self):
87
87
88
88
def test_05_re_predict (self ):
89
89
self .assertIs (nn .predict (train_set [0 ]).dtype , np .dtype ("float32" ), "Re-Predict failed" )
90
- self .assertIs (nn .predict_classes (test_set [0 ]).dtype , np .dtype ("int32" ), "Re-Predict classes failed" )
90
+ self .assertIs (nn .predict_classes (cv_set [0 ]).dtype , np .dtype ("int32" ), "Re-Predict classes failed" )
91
91
92
92
def test_06_re_evaluate (self ):
93
- self .assertEqual (len (nn .evaluate (* train_set , * test_set )), 3 , "Re-Evaluation failed" )
93
+ self .assertEqual (len (nn .evaluate (* train_set , * cv_set )), 3 , "Re-Evaluation failed" )
94
94
95
95
def test_07_re_train (self ):
96
96
self .assertIsInstance (
97
- nn .fit (* train_set , * test_set , verbose = 0 ), Basic ,
97
+ nn .fit (* train_set , * cv_set , verbose = 0 ), Basic ,
98
98
msg = "Re-Train failed"
99
99
)
100
100
@@ -105,16 +105,17 @@ def test_99_clear_cache(self):
105
105
class TestLinearSVM (unittest .TestCase ):
106
106
def test_00_train (self ):
107
107
self .assertIsInstance (
108
- linear_svm .fit (* train_set , * test_set , verbose = 0 ), LinearSVM ,
108
+ linear_svm .fit (* train_set , * cv_set , verbose = 0 ), LinearSVM ,
109
109
msg = "Train failed"
110
110
)
111
111
112
112
def test_01_predict (self ):
113
113
self .assertIs (linear_svm .predict (train_set [0 ]).dtype , np .dtype ("float32" ), "Predict failed" )
114
+ self .assertIs (linear_svm .predict_classes (cv_set [0 ]).dtype , np .dtype ("int32" ), "Predict classes failed" )
114
115
self .assertIs (linear_svm .predict_classes (test_set [0 ]).dtype , np .dtype ("int32" ), "Predict classes failed" )
115
116
116
117
def test_02_evaluate (self ):
117
- self .assertEqual (len (linear_svm .evaluate (* train_set , * test_set )), 3 , "Evaluation failed" )
118
+ self .assertEqual (len (linear_svm .evaluate (* train_set , * cv_set , * test_set )), 3 , "Evaluation failed" )
118
119
119
120
def test_03_save (self ):
120
121
self .assertIsInstance (linear_svm .save (), LinearSVM , msg = "Save failed" )
@@ -126,14 +127,15 @@ def test_04_load(self):
126
127
127
128
def test_05_re_predict (self ):
128
129
self .assertIs (linear_svm .predict (train_set [0 ]).dtype , np .dtype ("float32" ), "Re-Predict failed" )
130
+ self .assertIs (linear_svm .predict_classes (cv_set [0 ]).dtype , np .dtype ("int32" ), "Re-Predict classes failed" )
129
131
self .assertIs (linear_svm .predict_classes (test_set [0 ]).dtype , np .dtype ("int32" ), "Re-Predict classes failed" )
130
132
131
133
def test_06_re_evaluate (self ):
132
- self .assertEqual (len (linear_svm .evaluate (* train_set , * test_set )), 3 , "Re-Evaluation failed" )
134
+ self .assertEqual (len (linear_svm .evaluate (* train_set , * cv_set , * test_set )), 3 , "Re-Evaluation failed" )
133
135
134
136
def test_07_re_train (self ):
135
137
self .assertIsInstance (
136
- linear_svm .fit (* train_set , * test_set , verbose = 0 ), LinearSVM ,
138
+ linear_svm .fit (* train_set , * cv_set , verbose = 0 ), LinearSVM ,
137
139
msg = "Re-Train failed"
138
140
)
139
141
0 commit comments