Skip to content

Commit

Permalink
run with train instead of test to eval
Browse files Browse the repository at this point in the history
just to make sure it's not some weird bug with that part.

tensorboard log-dir: log_data/May_18/May_18_23:19:37/
  • Loading branch information
PeterMitrano committed May 19, 2017
1 parent 03e3d80 commit 1d738e3
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions pcanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __init__(self, image_batch, hyperparams, info):
self.binary_quantize_viz = tf.reshape(tf.expand_dims(self.binary_quantize, axis=4), [-1, info.IMAGE_W, info.IMAGE_H, 1])
self.binary_encoded_viz = tf.expand_dims(self.binary_encoded[:, 1, :, :], axis=3)
tf.summary.image('quantized', self.binary_quantize_viz, max_outputs=10)
tf.summary.image('encoded', self.binary_encoded_viz, max_outputs=10)

with tf.name_scope("histograms"):
self.n_bins = k = pow(2, l2)
Expand Down Expand Up @@ -199,14 +198,13 @@ def main():
train_summary = tf.summary.merge_all('train')
test_summary = tf.summary.merge_all('test')

# q = sess.run(m.binary_encoded)
# print(q[0])

# extract PCA features from training set
train_pcanet_features, train_labels, summary = sess.run([m.output_features, train_label_batch, merged_summary])
writer.add_summary(summary, 0)

# q = sess.run(m.x_eig1)
# np.savetxt('eig.csv', np.squeeze(q))
# exit(0)

# train linear SVM
svm = LinearSVC(C=1, fit_intercept=False)
svm.fit(train_pcanet_features, train_labels)
Expand All @@ -219,11 +217,12 @@ def main():
# switch to test set, compute PCA filters, and score with learned SVM parameters
scores = []
test_labels = sess.run(test_label_batch)
m.image_batch = test_image_batch
# m.image_batch = test_image_batch
for i in range(4):
test_pcanet_features = sess.run(m.output_features)

score = svm.score(test_pcanet_features, test_labels)
# score = svm.score(test_pcanet_features, test_labels)
score = svm.score(train_pcanet_features, train_labels)
scores.append(score)

print("batch test score:", score)
Expand Down

0 comments on commit 1d738e3

Please sign in to comment.