Skip to content

Commit

Permalink
update param
Browse files Browse the repository at this point in the history
  • Loading branch information
lizeyan committed Jun 5, 2017
1 parent 5103ba8 commit 42f73bf
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions run_dbn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def main():
expanded_train_data = np.expand_dims(flatten_train_data.reshape((-1,) + IMAGE_SIZE), -1)
expanded_test_data = np.expand_dims(flatten_test_data.reshape((-1, ) + IMAGE_SIZE), -1)

dbn = SupervisedDBNClassification(hidden_layers_structure=[256, 256], learning_rate_rbm=0.0005, learning_rate=0.001, n_epochs_rbm=50, n_iter_backprop=100, batch_size=128, activation_function='relu', dropout_p=0.2)
dbn = SupervisedDBNClassification(hidden_layers_structure=[4096, 4096], learning_rate_rbm=0.001, learning_rate=0.01, n_epochs_rbm=20, n_iter_backprop=100, batch_size=128, activation_function='relu', dropout_p=0.2)
dbn.fit(flatten_train_data, train_label)
evaluate(np.asarray(list(dbn.predict(flatten_test_data))), test_label, "DBN")

Expand Down Expand Up @@ -55,10 +55,11 @@ def example():
batch_size=32,
activation_function='relu',
dropout_p=0.2)
print(X_train.shape, Y_train.shape)
classifier.fit(X_train, Y_train)

# Test
Y_pred = classifier.predict(X_test)
Y_pred = np.asarray(list(classifier.predict(X_test)))
print('Done.\nAccuracy: %f' % accuracy_score(Y_test, Y_pred))

if __name__ == '__main__':
Expand Down

0 comments on commit 42f73bf

Please sign in to comment.