Skip to content

Commit

Permalink
changes for test() in adaboost.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiran02 committed Mar 15, 2019
1 parent 284399a commit e212d59
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions roc_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
def test():
# D = np.mat(np.ones((5, 1)) / 5)
# data_mat, class_labels = load_sim_data()
# print(data_mat.shape)
# result = build_stump(data_mat, class_labels, D)
# print(result)
# classifier_array, agg_class_est = ada_boost_train_ds(data_mat, class_labels, 9)
# print(classifier_array, agg_class_est)
data_mat, class_labels = load_data_set('../../../input/7.AdaBoost/horseColicTraining2.txt')
print(data_mat.shape, len(class_labels))
weak_class_arr, agg_class_est = ada_boost_train_ds(data_mat, class_labels, 40)
print(weak_class_arr, '\n-----\n', agg_class_est.T)
'''
agg_class_est是m*1维的矩阵,需先对其转置,再执行plot_roc()
'''
plot_roc(agg_class_est.T, class_labels)
data_arr_test, label_arr_test = load_data_set("../../../input/7.AdaBoost/horseColicTest2.txt")
m = np.shape(data_arr_test)[0]
predicting10 = ada_classify(data_arr_test, weak_class_arr)
err_arr = np.mat(np.ones((m, 1)))
# 测试:计算总样本数,错误样本数,错误率
print(m,
err_arr[predicting10 != np.mat(label_arr_test).T].sum(),
err_arr[predicting10 != np.mat(label_arr_test).T].sum() / m
)

0 comments on commit e212d59

Please sign in to comment.