Skip to content

Commit

Permalink
Merge pull request #242 from htm-community/mnist-example
Browse files Browse the repository at this point in the history
MNIST example
  • Loading branch information
breznak authored Feb 28, 2019
2 parents 9123f7c + d26299a commit b39ffa9
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 34 deletions.
4 changes: 3 additions & 1 deletion external/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,6 @@ if(match)
endif()



#################
# mnist data
include(mnist_data.cmake)
1 change: 1 addition & 0 deletions external/bootstrap.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ set(EXTERNAL_INCLUDES
${yaml-cpp_INCLUDE_DIRS}
${Boost_INCLUDE_DIRS}
${eigen_INCLUDE_DIRS}
${mnist_INCLUDE_DIRS}
${REPOSITORY_DIR}/external/common/include
)

49 changes: 49 additions & 0 deletions external/mnist_data.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -----------------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2016, Numenta, Inc. Unless you have purchased from
# Numenta, Inc. a separate commercial license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# -----------------------------------------------------------------------------

# Fetch MNIST dataset from online archive
#
if(EXISTS ${REPOSITORY_DIR}/build/ThirdParty/share/mnist.zip)
set(URL ${REPOSITORY_DIR}/build/ThirdParty/share/mnist.zip)
else()
set(URL "https://github.com/wichtounet/mnist/archive/master.zip")
set(HASH "855cb8c60f84e2fc6bea08c4a9df9a3cbd6230bddc55def635a938665c512ffc")
endif()

message(STATUS "obtaining MNIST data")
include(DownloadProject/DownloadProject.cmake)
download_project(PROJ mnist
PREFIX ${EP_BASE}/mnist_data
URL ${URL}
URL_HASH SHA256=${HASH}
UPDATE_DISCONNECTED 1
# QUIET
)

# No build. This is a data only package
# But we do need to run its CMakeLists.txt to unpack the files.

add_subdirectory(${mnist_SOURCE_DIR}/example/ ${mnist_BINARY_DIR})
FILE(APPEND "${EXPORT_FILE_NAME}" "mnist_INCLUDE_DIRS@@@${mnist_SOURCE_DIR}/include\n")
FILE(APPEND "${EXPORT_FILE_NAME}" "mnist_SOURCE_DIR@@@${mnist_SOURCE_DIR}\n")

# includes will be found with #include <mnist/mnist_reader_less.hpp>
# data will be found in folder ${mnist_SOURCE_DIR}
41 changes: 30 additions & 11 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ set(utils_files
)

set(examples_files
examples/hotgym/Hotgym.cpp
examples/hotgym/Hotgym.cpp # contains conflicting main()
examples/hotgym/HelloSPTP.cpp
examples/hotgym/HelloSPTP.hpp
examples/mnist/MNIST_SP.cpp
)

#set up file tabs in Visual Studio
Expand Down Expand Up @@ -261,7 +262,6 @@ add_library(${src_objlib} OBJECT
${regions_files}
${types_files}
${utils_files}
${examples_files}
)
# shared libraries need PIC
target_compile_options( ${src_objlib} PUBLIC ${INTERNAL_CXX_FLAGS})
Expand Down Expand Up @@ -348,13 +348,8 @@ add_subdirectory(test)
#
## Setup benchmark_hotgym
#
source_group("examples" FILES
examples/hotgym/Hotgym.cpp
examples/hotgym/HelloSPTP.hpp
)

set(src_executable_hotgym benchmark_hotgym)
add_executable(${src_executable_hotgym} examples/hotgym/Hotgym.cpp examples/hotgym/HelloSPTP.hpp)
add_executable(${src_executable_hotgym} examples/hotgym/Hotgym.cpp examples/hotgym/HelloSPTP.hpp examples/hotgym/HelloSPTP.cpp)
if(MSVC)
# for Windows, link with the static library
target_link_libraries(${src_executable_hotgym}
Expand All @@ -381,21 +376,46 @@ add_custom_target(hotgym
VERBATIM)


#########################################################
## MNIST Spatial Pooler Example
#
set(src_executable_mnistsp mnist_sp)
add_executable(${src_executable_mnistsp} examples/mnist/MNIST_SP.cpp)
target_link_libraries(${src_executable_mnistsp}
${core_library}
${COMMON_OS_LIBS}
)
target_compile_options(${src_executable_mnistsp} PUBLIC ${INTERNAL_CXX_FLAGS})
target_compile_definitions(${src_executable_mnistsp} PRIVATE ${COMMON_COMPILER_DEFINITIONS})
# Pass MNIST data directory to main.cpp
target_compile_definitions(${src_executable_mnistsp} PRIVATE MNIST_DATA_LOCATION=${mnist_SOURCE_DIR})
target_include_directories(${src_executable_mnistsp} PRIVATE
${CORE_LIB_INCLUDES}
${EXTERNAL_INCLUDES}
)
add_custom_target(mnist
COMMAND ${src_executable_mnistsp}
DEPENDS ${src_executable_mnistsp}
COMMENT "Executing ${src_executable_mnistsp}"
VERBATIM)


############ INSTALL ######################################
#
# Install targets into CMAKE_INSTALL_PREFIX
#
install(TARGETS
${core_library}
${src_executable_hotgym}
RUNTIME DESTINATION bin
${src_executable_mnistsp}
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)

if(MSVC)
else()
install(TARGETS
${src_lib_shared}
${src_lib_shared}
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
Expand All @@ -415,7 +435,6 @@ install(DIRECTORY "${REPOSITORY_DIR}/external/common/include/"
MESSAGE_NEVER
DESTINATION include)



#
# `make package` results in
Expand Down
1 change: 1 addition & 0 deletions src/examples/hotgym/HelloSPTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ using TM = nupic::algorithms::temporal_memory::TemporalMemory;
using nupic::algorithms::anomaly::Anomaly;
using nupic::algorithms::anomaly::AnomalyMode;


// work-load
Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool useTP, bool useBackTM, bool useTM, const UInt COLS, const UInt DIM_INPUT, const UInt CELLS) {
#ifndef NDEBUG
Expand Down
188 changes: 188 additions & 0 deletions src/examples/mnist/MNIST_SP.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/* ---------------------------------------------------------------------
* Copyright (C) 2018, David McDougall.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
* ----------------------------------------------------------------------
*/

/**
* Solving the MNIST dataset with Spatial Pooler.
*
* This consists of a simple black & white image encoder, a spatial pool, and an
* SDR classifier. The task is to recognise images of hand written numbers 0-9.
* This should score at least 95%.
*/

#include <algorithm>
#include <cstdint> //uint8_t
#include <iostream>
#include <vector>

#include <nupic/algorithms/SpatialPooler.hpp>
#include <nupic/algorithms/SDRClassifier.hpp>
#include <nupic/algorithms/ClassifierResult.hpp>
#include <nupic/utils/SdrMetrics.hpp>

#include <mnist/mnist_reader.hpp> // MNIST data itself + read methods, namespace mnist::

namespace examples {

using namespace std;
using namespace nupic;

using nupic::algorithms::spatial_pooler::SpatialPooler;
using nupic::algorithms::sdr_classifier::SDRClassifier;
using nupic::algorithms::cla_classifier::ClassifierResult;

class MNIST {

private:
SpatialPooler sp;
SDR input;
SDR columns;
SDRClassifier clsr;
mnist::MNIST_dataset<std::vector, std::vector<uint8_t>, uint8_t> dataset;

public:
UInt verbosity = 1;
const UInt train_dataset_iterations = 1u;


void setup() {

input.initialize({28, 28});
sp.initialize(
/* inputDimensions */ input.dimensions,
/* columnDimensions */ {28, 28}, //mostly affects speed, to some threshold accuracy only marginally
/* potentialRadius */ 5u,
/* potentialPct */ 0.5f,
/* globalInhibition */ false,
/* localAreaDensity */ 0.20f, //% active bits, //quite important variable (speed x accuracy)
/* numActiveColumnsPerInhArea */ -1,
/* stimulusThreshold */ 6u,
/* synPermInactiveDec */ 0.005f,
/* synPermActiveInc */ 0.01f,
/* synPermConnected */ 0.4f,
/* minPctOverlapDutyCycles */ 0.001f,
/* dutyCyclePeriod */ 1402,
/* boostStrength */ 2.5f, //boosting does help
/* seed */ 93u,
/* spVerbosity */ 1u,
/* wrapAround */ false); //wrap is false for this problem

columns.initialize({sp.getNumColumns()});

clsr.initialize(
/* steps */ {0},
/* alpha */ .001,
/* actValueAlpha */ .3,
verbosity);

dataset = mnist::read_dataset<std::vector, std::vector, uint8_t, uint8_t>(string("../ThirdParty/mnist_data/mnist-src/")); //from CMake
}

void train() {
// Train

if(verbosity)
cout << "Training for " << (train_dataset_iterations * dataset.training_labels.size())
<< " cycles ..." << endl;
size_t i = 0;

SDR_Metrics inputStats(input, 1402);
SDR_Metrics columnStats(columns, 1402);

for(auto epoch = 0u; epoch < train_dataset_iterations; epoch++) {
NTA_INFO << "epoch " << epoch;
// Shuffle the training data.
vector<UInt> index( dataset.training_labels.size() );
index.assign(dataset.training_labels.cbegin(), dataset.training_labels.cend());
Random().shuffle( index.begin(), index.end() );

for(const auto idx : index) { // index = order of label (shuffeled)
// Get the input & label
const auto image = dataset.training_images.at(idx);
const UInt label = dataset.training_labels.at(idx);

// Compute & Train
input.setDense( image );
sp.compute(input, true, columns);
ClassifierResult result;
clsr.compute(sp.getIterationNum(), columns.getFlatSparse(),
/* bucketIdxList */ {label},
/* actValueList */ {(Real)label},
/* category */ true,
/* learn */ true,
/* infer */ false,
&result);
if( verbosity && (++i % 1000 == 0) ) cout << "." << flush;
}
if( verbosity ) cout << endl;
}
cout << "epoch ended" << endl;
cout << inputStats << endl;
cout << columnStats << endl;
}

void test() {
// Test
Real score = 0;
UInt n_samples = 0;
if(verbosity)
cout << "Testing for " << dataset.test_labels.size() << " cycles ..." << endl;
for(UInt i = 0; i < dataset.test_labels.size(); i++) {
// Get the input & label
const auto image = dataset.test_images.at(i);
const UInt label = dataset.test_labels.at(i);

// Compute
input.setDense( image );
sp.compute(input, false, columns);
sp.stripUnlearnedColumns(columns);
ClassifierResult result;
clsr.compute(sp.getIterationNum(), columns.getFlatSparse(),
/* bucketIdxList */ {},
/* actValueList */ {},
/* category */ true,
/* learn */ false,
/* infer */ true,
&result);
// Check results
for(auto iter : result) {
if( iter.first == 0 ) {
const auto *pdf = iter.second;
const auto max = std::max_element(pdf->cbegin(), pdf->cend());
const UInt cls = max - pdf->cbegin();
if(cls == label)
score += 1;
n_samples += 1;
}
}
if( verbosity && i % 1000 == 0 ) cout << "." << flush;
}
if( verbosity ) cout << endl;
cout << "Score: " << 100.0 * score / n_samples << "% " << endl;
}

}; // End class MNIST
} // End namespace examples

int main(int argc, char **argv) {
examples::MNIST m;
m.setup();
m.train();
m.test();

return 0;
}

6 changes: 4 additions & 2 deletions src/nupic/algorithms/Connections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ void Connections::raisePermanencesToThreshold(
if( segmentThreshold == 0 ) //no synapses requested to be connected, done.
return;

NTA_ASSERT(segment < segments_.size()) << "Accessing segment out of bounds.";
auto &segData = segments_[segment];
if( segData.numConnected >= segmentThreshold ) //the segment already satisfies the requirement, done.
return;
Expand All @@ -496,7 +497,7 @@ void Connections::raisePermanencesToThreshold(
// permance by such that it becomes a connected synapse.
// After that there will be at least N synapses connected.

auto minPermSynPtr = synapses.begin() + threshold - 1;
auto minPermSynPtr = synapses.begin() + threshold - 1; //threshold is ensured to be >=1 by condition at very beginning if(thresh == 0)...
// Do a partial sort, it's faster than a full sort. Only minPermSynPtr is in
// its final sorted position.
const auto permanencesGreater = [&](const Synapse &A, const Synapse &B)
Expand All @@ -515,8 +516,9 @@ void Connections::raisePermanencesToThreshold(

void Connections::bumpSegment(const Segment segment, const Permanence delta) {
const vector<Synapse> &synapses = synapsesForSegment(segment);
for( const auto &syn : synapses )
for( const auto &syn : synapses ) {
updateSynapsePermanence(syn, synapses_[syn].permanence + delta);
}
}


Expand Down
Loading

0 comments on commit b39ffa9

Please sign in to comment.