Skip to content

Commit b39ffa9

Browse files
authored
Merge pull request #242 from htm-community/mnist-example
MNIST example
2 parents 9123f7c + d26299a commit b39ffa9

File tree

13 files changed

+317
-34
lines changed

13 files changed

+317
-34
lines changed

external/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,6 @@ if(match)
102102
endif()
103103

104104

105-
105+
#################
106+
# mnist data
107+
include(mnist_data.cmake)

external/bootstrap.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ set(EXTERNAL_INCLUDES
9797
${yaml-cpp_INCLUDE_DIRS}
9898
${Boost_INCLUDE_DIRS}
9999
${eigen_INCLUDE_DIRS}
100+
${mnist_INCLUDE_DIRS}
100101
${REPOSITORY_DIR}/external/common/include
101102
)
102103

external/mnist_data.cmake

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -----------------------------------------------------------------------------
2+
# Numenta Platform for Intelligent Computing (NuPIC)
3+
# Copyright (C) 2016, Numenta, Inc. Unless you have purchased from
4+
# Numenta, Inc. a separate commercial license for this software code, the
5+
# following terms and conditions apply:
6+
#
7+
# This program is free software: you can redistribute it and/or modify
8+
# it under the terms of the GNU Affero Public License version 3 as
9+
# published by the Free Software Foundation.
10+
#
11+
# This program is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
14+
# See the GNU Affero Public License for more details.
15+
#
16+
# You should have received a copy of the GNU Affero Public License
17+
# along with this program. If not, see http://www.gnu.org/licenses.
18+
#
19+
# http://numenta.org/licenses/
20+
# -----------------------------------------------------------------------------
21+
22+
# Fetch MNIST dataset from online archive
23+
#
24+
if(EXISTS ${REPOSITORY_DIR}/build/ThirdParty/share/mnist.zip)
25+
set(URL ${REPOSITORY_DIR}/build/ThirdParty/share/mnist.zip)
26+
else()
27+
set(URL "https://github.com/wichtounet/mnist/archive/master.zip")
28+
set(HASH "855cb8c60f84e2fc6bea08c4a9df9a3cbd6230bddc55def635a938665c512ffc")
29+
endif()
30+
31+
message(STATUS "obtaining MNIST data")
32+
include(DownloadProject/DownloadProject.cmake)
33+
download_project(PROJ mnist
34+
PREFIX ${EP_BASE}/mnist_data
35+
URL ${URL}
36+
URL_HASH SHA256=${HASH}
37+
UPDATE_DISCONNECTED 1
38+
# QUIET
39+
)
40+
41+
# No build. This is a data only package
42+
# But we do need to run its CMakeLists.txt to unpack the files.
43+
44+
add_subdirectory(${mnist_SOURCE_DIR}/example/ ${mnist_BINARY_DIR})
45+
FILE(APPEND "${EXPORT_FILE_NAME}" "mnist_INCLUDE_DIRS@@@${mnist_SOURCE_DIR}/include\n")
46+
FILE(APPEND "${EXPORT_FILE_NAME}" "mnist_SOURCE_DIR@@@${mnist_SOURCE_DIR}\n")
47+
48+
# includes will be found with #include <mnist/mnist_reader_less.hpp>
49+
# data will be found in folder ${mnist_SOURCE_DIR}

src/CMakeLists.txt

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,10 @@ set(utils_files
223223
)
224224

225225
set(examples_files
226-
examples/hotgym/Hotgym.cpp
226+
examples/hotgym/Hotgym.cpp # contains conflicting main()
227227
examples/hotgym/HelloSPTP.cpp
228228
examples/hotgym/HelloSPTP.hpp
229+
examples/mnist/MNIST_SP.cpp
229230
)
230231

231232
#set up file tabs in Visual Studio
@@ -261,7 +262,6 @@ add_library(${src_objlib} OBJECT
261262
${regions_files}
262263
${types_files}
263264
${utils_files}
264-
${examples_files}
265265
)
266266
# shared libraries need PIC
267267
target_compile_options( ${src_objlib} PUBLIC ${INTERNAL_CXX_FLAGS})
@@ -348,13 +348,8 @@ add_subdirectory(test)
348348
#
349349
## Setup benchmark_hotgym
350350
#
351-
source_group("examples" FILES
352-
examples/hotgym/Hotgym.cpp
353-
examples/hotgym/HelloSPTP.hpp
354-
)
355-
356351
set(src_executable_hotgym benchmark_hotgym)
357-
add_executable(${src_executable_hotgym} examples/hotgym/Hotgym.cpp examples/hotgym/HelloSPTP.hpp)
352+
add_executable(${src_executable_hotgym} examples/hotgym/Hotgym.cpp examples/hotgym/HelloSPTP.hpp examples/hotgym/HelloSPTP.cpp)
358353
if(MSVC)
359354
# for Windows, link with the static library
360355
target_link_libraries(${src_executable_hotgym}
@@ -381,21 +376,46 @@ add_custom_target(hotgym
381376
VERBATIM)
382377

383378

379+
#########################################################
380+
## MNIST Spatial Pooler Example
381+
#
382+
set(src_executable_mnistsp mnist_sp)
383+
add_executable(${src_executable_mnistsp} examples/mnist/MNIST_SP.cpp)
384+
target_link_libraries(${src_executable_mnistsp}
385+
${core_library}
386+
${COMMON_OS_LIBS}
387+
)
388+
target_compile_options(${src_executable_mnistsp} PUBLIC ${INTERNAL_CXX_FLAGS})
389+
target_compile_definitions(${src_executable_mnistsp} PRIVATE ${COMMON_COMPILER_DEFINITIONS})
390+
# Pass MNIST data directory to main.cpp
391+
target_compile_definitions(${src_executable_mnistsp} PRIVATE MNIST_DATA_LOCATION=${mnist_SOURCE_DIR})
392+
target_include_directories(${src_executable_mnistsp} PRIVATE
393+
${CORE_LIB_INCLUDES}
394+
${EXTERNAL_INCLUDES}
395+
)
396+
add_custom_target(mnist
397+
COMMAND ${src_executable_mnistsp}
398+
DEPENDS ${src_executable_mnistsp}
399+
COMMENT "Executing ${src_executable_mnistsp}"
400+
VERBATIM)
401+
402+
384403
############ INSTALL ######################################
385404
#
386405
# Install targets into CMAKE_INSTALL_PREFIX
387406
#
388407
install(TARGETS
389408
${core_library}
390409
${src_executable_hotgym}
391-
RUNTIME DESTINATION bin
410+
${src_executable_mnistsp}
411+
RUNTIME DESTINATION bin
392412
LIBRARY DESTINATION lib
393413
ARCHIVE DESTINATION lib)
394414

395415
if(MSVC)
396416
else()
397417
install(TARGETS
398-
${src_lib_shared}
418+
${src_lib_shared}
399419
RUNTIME DESTINATION bin
400420
LIBRARY DESTINATION lib
401421
ARCHIVE DESTINATION lib)
@@ -415,7 +435,6 @@ install(DIRECTORY "${REPOSITORY_DIR}/external/common/include/"
415435
MESSAGE_NEVER
416436
DESTINATION include)
417437

418-
419438

420439
#
421440
# `make package` results in

src/examples/hotgym/HelloSPTP.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ using TM = nupic::algorithms::temporal_memory::TemporalMemory;
5656
using nupic::algorithms::anomaly::Anomaly;
5757
using nupic::algorithms::anomaly::AnomalyMode;
5858

59+
5960
// work-load
6061
Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool useTP, bool useBackTM, bool useTM, const UInt COLS, const UInt DIM_INPUT, const UInt CELLS) {
6162
#ifndef NDEBUG

src/examples/mnist/MNIST_SP.cpp

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/* ---------------------------------------------------------------------
2+
* Copyright (C) 2018, David McDougall.
3+
*
4+
* This program is free software: you can redistribute it and/or modify
5+
* it under the terms of the GNU Affero Public License version 3 as
6+
* published by the Free Software Foundation.
7+
*
8+
* This program is distributed in the hope that it will be useful,
9+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
11+
* See the GNU Affero Public License for more details.
12+
*
13+
* You should have received a copy of the GNU Affero Public License
14+
* along with this program. If not, see http://www.gnu.org/licenses.
15+
* ----------------------------------------------------------------------
16+
*/
17+
18+
/**
19+
* Solving the MNIST dataset with Spatial Pooler.
20+
*
21+
* This consists of a simple black & white image encoder, a spatial pool, and an
22+
* SDR classifier. The task is to recognise images of hand written numbers 0-9.
23+
* This should score at least 95%.
24+
*/
25+
26+
#include <algorithm>
27+
#include <cstdint> //uint8_t
28+
#include <iostream>
29+
#include <vector>
30+
31+
#include <nupic/algorithms/SpatialPooler.hpp>
32+
#include <nupic/algorithms/SDRClassifier.hpp>
33+
#include <nupic/algorithms/ClassifierResult.hpp>
34+
#include <nupic/utils/SdrMetrics.hpp>
35+
36+
#include <mnist/mnist_reader.hpp> // MNIST data itself + read methods, namespace mnist::
37+
38+
namespace examples {
39+
40+
using namespace std;
41+
using namespace nupic;
42+
43+
using nupic::algorithms::spatial_pooler::SpatialPooler;
44+
using nupic::algorithms::sdr_classifier::SDRClassifier;
45+
using nupic::algorithms::cla_classifier::ClassifierResult;
46+
47+
class MNIST {
48+
49+
private:
50+
SpatialPooler sp;
51+
SDR input;
52+
SDR columns;
53+
SDRClassifier clsr;
54+
mnist::MNIST_dataset<std::vector, std::vector<uint8_t>, uint8_t> dataset;
55+
56+
public:
57+
UInt verbosity = 1;
58+
const UInt train_dataset_iterations = 1u;
59+
60+
61+
void setup() {
62+
63+
input.initialize({28, 28});
64+
sp.initialize(
65+
/* inputDimensions */ input.dimensions,
66+
/* columnDimensions */ {28, 28}, //mostly affects speed, to some threshold accuracy only marginally
67+
/* potentialRadius */ 5u,
68+
/* potentialPct */ 0.5f,
69+
/* globalInhibition */ false,
70+
/* localAreaDensity */ 0.20f, //% active bits, //quite important variable (speed x accuracy)
71+
/* numActiveColumnsPerInhArea */ -1,
72+
/* stimulusThreshold */ 6u,
73+
/* synPermInactiveDec */ 0.005f,
74+
/* synPermActiveInc */ 0.01f,
75+
/* synPermConnected */ 0.4f,
76+
/* minPctOverlapDutyCycles */ 0.001f,
77+
/* dutyCyclePeriod */ 1402,
78+
/* boostStrength */ 2.5f, //boosting does help
79+
/* seed */ 93u,
80+
/* spVerbosity */ 1u,
81+
/* wrapAround */ false); //wrap is false for this problem
82+
83+
columns.initialize({sp.getNumColumns()});
84+
85+
clsr.initialize(
86+
/* steps */ {0},
87+
/* alpha */ .001,
88+
/* actValueAlpha */ .3,
89+
verbosity);
90+
91+
dataset = mnist::read_dataset<std::vector, std::vector, uint8_t, uint8_t>(string("../ThirdParty/mnist_data/mnist-src/")); //from CMake
92+
}
93+
94+
void train() {
95+
// Train
96+
97+
if(verbosity)
98+
cout << "Training for " << (train_dataset_iterations * dataset.training_labels.size())
99+
<< " cycles ..." << endl;
100+
size_t i = 0;
101+
102+
SDR_Metrics inputStats(input, 1402);
103+
SDR_Metrics columnStats(columns, 1402);
104+
105+
for(auto epoch = 0u; epoch < train_dataset_iterations; epoch++) {
106+
NTA_INFO << "epoch " << epoch;
107+
// Shuffle the training data.
108+
vector<UInt> index( dataset.training_labels.size() );
109+
index.assign(dataset.training_labels.cbegin(), dataset.training_labels.cend());
110+
Random().shuffle( index.begin(), index.end() );
111+
112+
for(const auto idx : index) { // index = order of label (shuffeled)
113+
// Get the input & label
114+
const auto image = dataset.training_images.at(idx);
115+
const UInt label = dataset.training_labels.at(idx);
116+
117+
// Compute & Train
118+
input.setDense( image );
119+
sp.compute(input, true, columns);
120+
ClassifierResult result;
121+
clsr.compute(sp.getIterationNum(), columns.getFlatSparse(),
122+
/* bucketIdxList */ {label},
123+
/* actValueList */ {(Real)label},
124+
/* category */ true,
125+
/* learn */ true,
126+
/* infer */ false,
127+
&result);
128+
if( verbosity && (++i % 1000 == 0) ) cout << "." << flush;
129+
}
130+
if( verbosity ) cout << endl;
131+
}
132+
cout << "epoch ended" << endl;
133+
cout << inputStats << endl;
134+
cout << columnStats << endl;
135+
}
136+
137+
void test() {
138+
// Test
139+
Real score = 0;
140+
UInt n_samples = 0;
141+
if(verbosity)
142+
cout << "Testing for " << dataset.test_labels.size() << " cycles ..." << endl;
143+
for(UInt i = 0; i < dataset.test_labels.size(); i++) {
144+
// Get the input & label
145+
const auto image = dataset.test_images.at(i);
146+
const UInt label = dataset.test_labels.at(i);
147+
148+
// Compute
149+
input.setDense( image );
150+
sp.compute(input, false, columns);
151+
sp.stripUnlearnedColumns(columns);
152+
ClassifierResult result;
153+
clsr.compute(sp.getIterationNum(), columns.getFlatSparse(),
154+
/* bucketIdxList */ {},
155+
/* actValueList */ {},
156+
/* category */ true,
157+
/* learn */ false,
158+
/* infer */ true,
159+
&result);
160+
// Check results
161+
for(auto iter : result) {
162+
if( iter.first == 0 ) {
163+
const auto *pdf = iter.second;
164+
const auto max = std::max_element(pdf->cbegin(), pdf->cend());
165+
const UInt cls = max - pdf->cbegin();
166+
if(cls == label)
167+
score += 1;
168+
n_samples += 1;
169+
}
170+
}
171+
if( verbosity && i % 1000 == 0 ) cout << "." << flush;
172+
}
173+
if( verbosity ) cout << endl;
174+
cout << "Score: " << 100.0 * score / n_samples << "% " << endl;
175+
}
176+
177+
}; // End class MNIST
178+
} // End namespace examples
179+
180+
int main(int argc, char **argv) {
181+
examples::MNIST m;
182+
m.setup();
183+
m.train();
184+
m.test();
185+
186+
return 0;
187+
}
188+

src/nupic/algorithms/Connections.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ void Connections::raisePermanencesToThreshold(
472472
if( segmentThreshold == 0 ) //no synapses requested to be connected, done.
473473
return;
474474

475+
NTA_ASSERT(segment < segments_.size()) << "Accessing segment out of bounds.";
475476
auto &segData = segments_[segment];
476477
if( segData.numConnected >= segmentThreshold ) //the segment already satisfies the requirement, done.
477478
return;
@@ -496,7 +497,7 @@ void Connections::raisePermanencesToThreshold(
496497
// permance by such that it becomes a connected synapse.
497498
// After that there will be at least N synapses connected.
498499

499-
auto minPermSynPtr = synapses.begin() + threshold - 1;
500+
auto minPermSynPtr = synapses.begin() + threshold - 1; //threshold is ensured to be >=1 by condition at very beginning if(thresh == 0)...
500501
// Do a partial sort, it's faster than a full sort. Only minPermSynPtr is in
501502
// its final sorted position.
502503
const auto permanencesGreater = [&](const Synapse &A, const Synapse &B)
@@ -515,8 +516,9 @@ void Connections::raisePermanencesToThreshold(
515516

516517
void Connections::bumpSegment(const Segment segment, const Permanence delta) {
517518
const vector<Synapse> &synapses = synapsesForSegment(segment);
518-
for( const auto &syn : synapses )
519+
for( const auto &syn : synapses ) {
519520
updateSynapsePermanence(syn, synapses_[syn].permanence + delta);
521+
}
520522
}
521523

522524

0 commit comments

Comments
 (0)