Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNIST example #242

Merged
merged 55 commits into from
Feb 28, 2019
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
92b1169
WIP - Initial commit to new branch
ctrl-z-9000-times Dec 1, 2018
39c34ad
MNIST -> 80% accuracy
ctrl-z-9000-times Dec 10, 2018
cca50fc
Merge branch 'master' into mnist
ctrl-z-9000-times Dec 18, 2018
ec8c24e
Merge branch 'sp-stats' into mnist
ctrl-z-9000-times Dec 26, 2018
fbeabc7
Score 94%
ctrl-z-9000-times Dec 26, 2018
1dfe98f
SDR-Classifier fix missing header.
ctrl-z-9000-times Dec 26, 2018
81fd5fe
MNIST CMake materials
ctrl-z-9000-times Dec 26, 2018
752a122
Merge branch 'master_community' into mnist
breznak Jan 31, 2019
3193892
SpatialPooler: revert to upstream master version
breznak Feb 1, 2019
77acef0
MNIST: update example
breznak Feb 1, 2019
b944b90
Connections: add asserts
breznak Feb 1, 2019
ac3b020
Merge remote-tracking branch 'community/master' into mnist-example
breznak Feb 7, 2019
9a58b03
Update merge resolution
breznak Feb 7, 2019
f973f24
Pull MNIST experiment
ctrl-z-9000-times Feb 4, 2019
d05d392
MNIST: working example
breznak Feb 8, 2019
780083a
Connections: raisePermanenceToThreshold bug
breznak Feb 8, 2019
d2d4633
SpatialPooler: formating of doc
breznak Feb 9, 2019
49677ba
MNIST: try SP params
breznak Feb 9, 2019
1bf08eb
examples/Hotgym rename to HotgymMain
breznak Feb 9, 2019
af35e2a
example/MNIST make namespace examples
breznak Feb 9, 2019
c6fc009
SDRClassifier: add initialize() method
breznak Feb 9, 2019
25b076e
Serializable: include filesystem helper header
breznak Feb 9, 2019
97e204d
example: make classes MNIST, HelloSPTP, use namespace
breznak Feb 9, 2019
0a46993
CMake: download MNIST repo during configure WIP
breznak Feb 9, 2019
382cc84
fixed up the external part of this PR
dkeeney Feb 15, 2019
4e0f026
Made mnist-example optional.
dkeeney Feb 17, 2019
3221315
Merge remote-tracking branch 'community/master' into mnist-example
breznak Feb 18, 2019
26b12c8
MNIST: use 3rd party repo for data and data access methods
breznak Feb 18, 2019
4f03d5e
Revert "Made mnist-example optional."
breznak Feb 18, 2019
f2c8718
MNIST small fixes
breznak Feb 18, 2019
508282e
MNIST: external repo updated
breznak Feb 18, 2019
c0aaefc
fixes
breznak Feb 18, 2019
62a63a9
Merge branch 'master_community' into mnist-example
breznak Feb 18, 2019
7c1712a
Mnist cleanup
breznak Feb 19, 2019
f9eae0c
MNIST: try full headers
breznak Feb 22, 2019
c60fdb2
cleanup
breznak Feb 22, 2019
d5031ed
TMP Cmake disable shared so build
breznak Feb 22, 2019
684bf70
MNIST tuning params, WIP
breznak Feb 22, 2019
7250138
MNIST param tuning 2
breznak Feb 22, 2019
920eff7
MNIST scores > 30%
ctrl-z-9000-times Feb 24, 2019
acaca33
Merge branch 'master_community' into mnist-example
breznak Feb 25, 2019
149a113
cmake fixes
breznak Feb 25, 2019
f728658
Merge remote-tracking branch 'community/mnist-example' into mnist-exa…
breznak Feb 25, 2019
c93486b
CI: skip TP performance on Windows
breznak Feb 25, 2019
d0fe52e
debugging details
breznak Feb 26, 2019
d75188e
CMAke: do not bundle examples with main library
breznak Feb 26, 2019
b928fc3
fix: skip performance test for Windows in CI
breznak Feb 26, 2019
a14c5f4
Connections: resolve merge conflicts
breznak Feb 26, 2019
907fc3d
another try for passing performance CI on Windows
breznak Feb 26, 2019
4bf889b
Merge branch 'master_community' into mnist-example
breznak Feb 26, 2019
4cee012
MNIST: use 2D input, SP and smaller columns
breznak Feb 26, 2019
41a68ee
Revert "TMP Cmake disable shared so build"
breznak Feb 26, 2019
a4b2732
MNIST: tuned params 45%
breznak Feb 26, 2019
4cc6cd8
MNIST: stript unlearned cols on inference
breznak Feb 26, 2019
d26299a
MNIST: local inh 43%
breznak Feb 26, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion external/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,6 @@ if(match)
endif()



#################
# mnist data
include(mnist_data.cmake)
1 change: 1 addition & 0 deletions external/bootstrap.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ set(EXTERNAL_INCLUDES
${yaml-cpp_INCLUDE_DIRS}
${Boost_INCLUDE_DIRS}
${eigen_INCLUDE_DIRS}
${mnist_INCLUDE_DIRS}
${REPOSITORY_DIR}/external/common/include
)

49 changes: 49 additions & 0 deletions external/mnist_data.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -----------------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2016, Numenta, Inc. Unless you have purchased from
# Numenta, Inc. a separate commercial license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# -----------------------------------------------------------------------------

# Fetch MNIST dataset from online archive
#
if(EXISTS ${REPOSITORY_DIR}/build/ThirdParty/share/mnist.zip)
set(URL ${REPOSITORY_DIR}/build/ThirdParty/share/mnist.zip)
else()
set(URL "https://github.com/wichtounet/mnist/archive/master.zip")
breznak marked this conversation as resolved.
Show resolved Hide resolved
set(HASH "855cb8c60f84e2fc6bea08c4a9df9a3cbd6230bddc55def635a938665c512ffc")
endif()

message(STATUS "obtaining MNIST data")
include(DownloadProject/DownloadProject.cmake)
breznak marked this conversation as resolved.
Show resolved Hide resolved
download_project(PROJ mnist
PREFIX ${EP_BASE}/mnist_data
URL ${URL}
URL_HASH SHA256=${HASH}
UPDATE_DISCONNECTED 1
# QUIET
)

# No build. This is a data only package
# But we do need to run its CMakeLists.txt to unpack the files.

add_subdirectory(${mnist_SOURCE_DIR}/example/ ${mnist_BINARY_DIR})
FILE(APPEND "${EXPORT_FILE_NAME}" "mnist_INCLUDE_DIRS@@@${mnist_SOURCE_DIR}/include\n")
FILE(APPEND "${EXPORT_FILE_NAME}" "mnist_SOURCE_DIR@@@${mnist_SOURCE_DIR}\n")
breznak marked this conversation as resolved.
Show resolved Hide resolved

# includes will be found with #include <mnist/mnist_reader_less.hpp>
# data will be found in folder ${mnist_SOURCE_DIR}
48 changes: 34 additions & 14 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ set(utils_files
)

set(examples_files
examples/hotgym/Hotgym.cpp
# examples/hotgym/Hotgym.cpp # contains conflicting main()
examples/hotgym/HelloSPTP.cpp
examples/hotgym/HelloSPTP.hpp
# examples/mnist/MNIST_SP.cpp
)

#set up file tabs in Visual Studio
Expand Down Expand Up @@ -270,7 +271,7 @@ add_library(${src_objlib} OBJECT
${regions_files}
${types_files}
${utils_files}
${examples_files}
${examples_files}
breznak marked this conversation as resolved.
Show resolved Hide resolved
)
# shared libraries need PIC
target_compile_options( ${src_objlib} PUBLIC ${INTERNAL_CXX_FLAGS})
Expand Down Expand Up @@ -312,7 +313,7 @@ merge_static_libraries(${core_library} "${src_combined_nupiccore_source_archives
#
# For Linux, OSx; the .so file will contain all symbols.
#
if(MSVC)
if(true)
breznak marked this conversation as resolved.
Show resolved Hide resolved
# NOTE: this disables shared lib for Windows.
else()
# First extract static libraries into an Object Library
Expand Down Expand Up @@ -357,14 +358,9 @@ add_subdirectory(test)
#
## Setup benchmark_hotgym
#
source_group("examples" FILES
examples/hotgym/Hotgym.cpp
examples/hotgym/HelloSPTP.hpp
)

set(src_executable_hotgym benchmark_hotgym)
add_executable(${src_executable_hotgym} examples/hotgym/Hotgym.cpp examples/hotgym/HelloSPTP.hpp)
if(MSVC)
add_executable(${src_executable_hotgym} examples/hotgym/Hotgym.cpp examples/hotgym/HelloSPTP.hpp examples/hotgym/HelloSPTP.cpp)
if(true)
breznak marked this conversation as resolved.
Show resolved Hide resolved
# for Windows, link with the static library
target_link_libraries(${src_executable_hotgym}
${core_library}
Expand All @@ -390,21 +386,46 @@ add_custom_target(hotgym
VERBATIM)


#########################################################
## MNIST Spatial Pooler Example
#
set(src_executable_mnistsp mnist_sp)
add_executable(${src_executable_mnistsp} examples/mnist/MNIST_SP.cpp)
target_link_libraries(${src_executable_mnistsp}
${core_library}
${COMMON_OS_LIBS}
)
target_compile_options(${src_executable_mnistsp} PUBLIC ${INTERNAL_CXX_FLAGS})
target_compile_definitions(${src_executable_mnistsp} PRIVATE ${COMMON_COMPILER_DEFINITIONS})
# Pass MNIST data directory to main.cpp
target_compile_definitions(${src_executable_mnistsp} PRIVATE MNIST_DATA_LOCATION=${mnist_SOURCE_DIR})
target_include_directories(${src_executable_mnistsp} PRIVATE
${CORE_LIB_INCLUDES}
${EXTERNAL_INCLUDES}
)
add_custom_target(mnist
COMMAND ${src_executable_mnistsp}
DEPENDS ${src_executable_mnistsp}
COMMENT "Executing ${src_executable_mnistsp}"
VERBATIM)


############ INSTALL ######################################
#
# Install targets into CMAKE_INSTALL_PREFIX
#
install(TARGETS
${core_library}
${src_executable_hotgym}
RUNTIME DESTINATION bin
${src_executable_mnistsp}
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)

if(MSVC)
if(true)
breznak marked this conversation as resolved.
Show resolved Hide resolved
else()
install(TARGETS
${src_lib_shared}
${src_lib_shared}
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
Expand All @@ -424,7 +445,6 @@ install(DIRECTORY "${REPOSITORY_DIR}/external/common/include/"
MESSAGE_NEVER
DESTINATION include)



#
# `make package` results in
Expand Down
1 change: 1 addition & 0 deletions src/examples/hotgym/HelloSPTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ using TM = nupic::algorithms::temporal_memory::TemporalMemory;
using nupic::algorithms::anomaly::Anomaly;
using nupic::algorithms::anomaly::AnomalyMode;


// work-load
Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool useTP, bool useBackTM, bool useTM, const UInt COLS, const UInt DIM_INPUT, const UInt CELLS) {
#ifndef NDEBUG
Expand Down
187 changes: 187 additions & 0 deletions src/examples/mnist/MNIST_SP.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/* ---------------------------------------------------------------------
* Copyright (C) 2018, David McDougall.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
* ----------------------------------------------------------------------
*/

/**
* Solving the MNIST dataset with Spatial Pooler.
*
* This consists of a simple black & white image encoder, a spatial pool, and an
* SDR classifier. The task is to recognise images of hand written numbers 0-9.
* This should score at least 95%.
*/

#include <algorithm>
#include <cstdint> //uint8_t
#include <iostream>
#include <vector>

#include <nupic/algorithms/SpatialPooler.hpp>
#include <nupic/algorithms/SDRClassifier.hpp>
breznak marked this conversation as resolved.
Show resolved Hide resolved
#include <nupic/algorithms/ClassifierResult.hpp>
#include <nupic/utils/SdrMetrics.hpp>

#include <mnist/mnist_reader.hpp> // MNIST data itself + read methods, namespace mnist::

namespace examples {

using namespace std;
breznak marked this conversation as resolved.
Show resolved Hide resolved
using namespace nupic;

using nupic::algorithms::spatial_pooler::SpatialPooler;
using nupic::algorithms::sdr_classifier::SDRClassifier;
using nupic::algorithms::cla_classifier::ClassifierResult;

class MNIST {

private:
SpatialPooler sp;
SDR input;
SDR columns;
SDRClassifier clsr;
mnist::MNIST_dataset<std::vector, std::vector<uint8_t>, uint8_t> dataset;

public:
UInt verbosity = 1;
const UInt train_dataset_iterations = 1u;


void setup() {

input.initialize({28 * 28});
sp.initialize(
/* inputDimensions */ input.dimensions,
/* columnDimensions */ {20 * 1000},
/* potentialRadius */ 999999u,
/* potentialPct */ 0.5f,
/* globalInhibition */ true,
/* localAreaDensity */ 0.015f,
/* numActiveColumnsPerInhArea */ -1,
/* stimulusThreshold */ 14u,
/* synPermInactiveDec */ 0.01f,
/* synPermActiveInc */ 0.05f,
/* synPermConnected */ 0.4f,
/* minPctOverlapDutyCycles */ 0.001f,
/* dutyCyclePeriod */ 1402,
/* boostStrength */ 1.0f,
/* seed */ 93u,
/* spVerbosity */ 1u,
/* wrapAround */ true);

columns.initialize({sp.getNumColumns()});

clsr.initialize(
/* steps */ {0},
/* alpha */ .001,
/* actValueAlpha */ .3,
verbosity);

dataset = mnist::read_dataset<std::vector, std::vector, uint8_t, uint8_t>(string("../ThirdParty/mnist_data/mnist-src/")); //from CMake
}

void train() {
// Train

if(verbosity)
cout << "Training for " << (train_dataset_iterations * dataset.training_labels.size())
<< " cycles ..." << endl;
size_t i = 0;

SDR_Metrics inputStats(input, 1402);
SDR_Metrics columnStats(columns, 1402);

for(auto epoch = 0u; epoch < train_dataset_iterations; epoch++) {
NTA_INFO << "epoch " << epoch;
// Shuffle the training data.
vector<UInt> index( dataset.training_labels.size() );
index.assign(dataset.training_labels.cbegin(), dataset.training_labels.cend());
Random().shuffle( index.begin(), index.end() );

for(const auto idx : index) { // index = order of label (shuffeled)
// Get the input & label
const auto image = dataset.training_images.at(idx);
const UInt label = dataset.training_labels.at(idx);

// Compute & Train
input.setDense( image );
sp.compute(input, true, columns);
ClassifierResult result;
clsr.compute(sp.getIterationNum(), columns.getFlatSparse(),
/* bucketIdxList */ {label},
/* actValueList */ {(Real)label},
/* category */ true,
/* learn */ true,
/* infer */ false,
&result);
if( verbosity && (++i % 1000 == 0) ) cout << "." << flush;
}
if( verbosity ) cout << endl;
}
cout << "epoch ended" << endl;
cout << inputStats << endl;
cout << columnStats << endl;
}

void test() {
// Test
Real score = 0;
UInt n_samples = 0;
if(verbosity)
cout << "Testing for " << dataset.test_labels.size() << " cycles ..." << endl;
for(UInt i = 0; i < dataset.test_labels.size(); i++) {
// Get the input & label
const auto image = dataset.test_images.at(i);
const UInt label = dataset.test_labels.at(i);

// Compute
input.setDense( image );
sp.compute(input, false, columns);
ClassifierResult result;
clsr.compute(sp.getIterationNum(), columns.getFlatSparse(),
/* bucketIdxList */ {},
/* actValueList */ {},
/* category */ true,
/* learn */ false,
/* infer */ true,
&result);
// Check results
for(auto iter : result) {
if( iter.first == 0 ) {
const auto *pdf = iter.second;
const auto max = std::max_element(pdf->cbegin(), pdf->cend());
const UInt cls = max - pdf->cbegin();
if(cls == label)
score += 1;
n_samples += 1;
}
}
if( verbosity && i % 1000 == 0 ) cout << "." << flush;
}
if( verbosity ) cout << endl;
cout << "Score: " << 100.0 * score / n_samples << "%% " << endl;
}

}; // End class MNIST
} // End namespace examples

int main(int argc, char **argv) {
examples::MNIST m;
m.setup();
m.train();
m.test();

return 0;
}

Loading