Skip to content

Commit

Permalink
ENH Improve script output display
Browse files Browse the repository at this point in the history
  • Loading branch information
arjoly committed Jan 12, 2015
1 parent fa0ceee commit b0dba06
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions benchmarks/bench_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,15 @@
10 classes - digits from 0 to 9 from their raw images. By contrast to the
covertype dataset, the feature space is homogenous.
Example of output :
[..]
Classification performance:
===========================
Classifier train-time test-time error-rate
--------------------------------------------
Nystroem-SVM 118.4512s 0.8624s 0.0231
ExtraTrees 47.5039s 0.5075s 0.0288
RandomForest 46.2317s 0.4342s 0.0304
SampledRBF-SVM 137.0131s 0.8688s 0.0488
CART 21.0593s 0.0134s 0.1214
Classifier train-time test-time error-rate
------------------------------------------------------------
Nystroem-SVM 115.31s 1.23s 0.0227
ExtraTrees 55.90s 1.27s 0.0288
"""
from __future__ import division, print_function
Expand Down Expand Up @@ -83,9 +80,9 @@ def load_data(dtype=np.float32, order='F'):
'ExtraTrees': ExtraTreesClassifier(n_estimators=100),
'RandomForest': RandomForestClassifier(n_estimators=100),
'Nystroem-SVM':
make_pipeline(Nystroem(gamma=0.031, n_components=1000), LinearSVC(C=100)),
make_pipeline(Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
'SampledRBF-SVM':
make_pipeline(RBFSampler(gamma=0.031, n_components=1000), LinearSVC(C=100))
make_pipeline(RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100))
}


Expand Down Expand Up @@ -152,12 +149,12 @@ def load_data(dtype=np.float32, order='F'):
print()
print("Classification performance:")
print("===========================")
print("%s %s %s %s"
% ("Classifier".ljust(16), "train-time", "test-time", "error-rate"))
print("-" * 44)
print("{0: <24} {1: >10} {2: >11} {3: >12}"
"".format("Classifier ", "train-time", "test-time", "error-rate"))
print("-" * 60)
for name in sorted(args["classifiers"], key=error.get):
print("%s %s %s %s" % (name.ljust(16),
("%.4fs" % train_time[name]).center(10),
("%.4fs" % test_time[name]).center(10),
("%.4f" % error[name]).center(10)))

print("{0: <23} {1: >10.2f}s {2: >10.2f}s {3: >12.4f}"
"".format(name, train_time[name], test_time[name], error[name]))

print()

0 comments on commit b0dba06

Please sign in to comment.