Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed pickling on SP #644

Merged
merged 5 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions bindings/py/cpp_src/bindings/algorithms/py_SpatialPooler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
PyBind11 bindings for SpatialPooler class
*/

#include <tuple>
#include <iostream>

#include <bindings/suppress_register.hpp> //include before pybind11.h
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -224,24 +226,32 @@ Argument wrapAround boolean value that determines whether or not inputs
py_SpatialPooler.def("getSynPermMax", &SpatialPooler::getSynPermMax);
py_SpatialPooler.def("getMinPctOverlapDutyCycles", &SpatialPooler::getMinPctOverlapDutyCycles);
py_SpatialPooler.def("setMinPctOverlapDutyCycles", &SpatialPooler::setMinPctOverlapDutyCycles);

// loadFromString
py_SpatialPooler.def("loadFromString", [](SpatialPooler& self, const py::bytes& inString)

// saving and loading from file
py_SpatialPooler.def("saveToFile",
[](SpatialPooler &self, const std::string& filename) {self.saveToFile(filename,SerializableFormat::BINARY); });

py_SpatialPooler.def("loadFromFile",
[](SpatialPooler &self, const std::string& filename) { return self.loadFromFile(filename,SerializableFormat::BINARY); });


// loadFromString, loads SP from a JSON encoded string produced by writeToString().
py_SpatialPooler.def("loadFromString", [](SpatialPooler& self, const std::string& inString)
{
std::stringstream inStream(inString.cast<std::string>());
self.load(inStream);
std::stringstream inStream(inString);
self.load(inStream, JSON);
});

// writeToString
// writeToString, save SP to a JSON encoded string usable by loadFromString()
py_SpatialPooler.def("writeToString", [](const SpatialPooler& self)
{
std::ostringstream os;
os.flags(ios::scientific);
os.precision(numeric_limits<double>::digits10 + 1);
os.precision(std::numeric_limits<double>::digits10 + 1);
os.precision(std::numeric_limits<float>::digits10 + 1);

self.save(os);
self.save(os, JSON);

return py::bytes( os.str() );
return os.str();
});

// compute
Expand Down Expand Up @@ -402,24 +412,38 @@ Argument output An SDR representing the winning columns after


// pickle

py_SpatialPooler.def(py::pickle(
[](const SpatialPooler& sp)
[](const SpatialPooler& sp) // __getstate__
{
std::stringstream ss;

sp.save(ss);


/* The values in stringstream are binary so pickle will get confused
* trying to treat it as utf8 if you just return ss.str().
* So we must treat it as py::bytes. Some characters could be null values.
*/
return py::bytes( ss.str() );
},
[](py::bytes &s)
[](py::bytes &s) // __setstate__
{
/* pybind11 will pass in the bytes array without conversion.
* so we should be able to just create a string to initalize the stringstream.
*/
std::stringstream ss( s.cast<std::string>() );
SpatialPooler sp;
sp.load(ss);

std::unique_ptr<SpatialPooler> sp(new SpatialPooler());
sp->load(ss);

/*
* The __setstate__ part of the py::pickle() is actually a py::init() with some options.
* So the return value can be the object returned by value, by pointer,
* or by container (meaning a unique_ptr). SP has a problem with the copy constructor
* and pointers have problems knowing who the owner is so lets use unique_ptr.
* See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html#custom-constructors
*/
return sp;
}));


}
} // namespace htm_ext
35 changes: 29 additions & 6 deletions bindings/py/cpp_src/bindings/algorithms/py_TemporalMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,32 @@ Argument anomalyMode (optional, default ANMode::RAW) selects mode for `TM.anomal
py::call_guard<py::scoped_ostream_redirect,
py::scoped_estream_redirect>());

// saving and loading from file
py_HTM.def("saveToFile",
[](TemporalMemory &self, const std::string& filename) {self.saveToFile(filename,SerializableFormat::BINARY); });

py_HTM.def("loadFromFile",
[](TemporalMemory &self, const std::string& filename) { return self.loadFromFile(filename,SerializableFormat::BINARY); });

// writeToString, save TM to a JSON encoded string usable by loadFromString()
py_HTM.def("writeToString", [](const TemporalMemory& self)
{
std::ostringstream os;
os.precision(std::numeric_limits<double>::digits10 + 1);
os.precision(std::numeric_limits<float>::digits10 + 1);

self.save(os, JSON);

return os.str();
});
// loadFromString, loads TM from a JSON encoded string produced by writeToString().
py_HTM.def("loadFromString", [](TemporalMemory& self, const std::string& inString)
{
std::stringstream inStream(inString);
self.load(inStream, JSON);
});


// pickle
// https://github.com/pybind/pybind11/issues/1061
py_HTM.def(py::pickle(
Expand All @@ -168,9 +194,6 @@ Argument anomalyMode (optional, default ANMode::RAW) selects mode for `TM.anomal
// __getstate__
std::ostringstream os;

os.flags(std::ios::scientific);
os.precision(std::numeric_limits<double>::digits10 + 1);

self.save(os);

return py::bytes(os.str());
Expand All @@ -185,10 +208,10 @@ Argument anomalyMode (optional, default ANMode::RAW) selects mode for `TM.anomal

std::stringstream is( str.cast<std::string>() );

HTM_t htm;
htm.load(is);
std::unique_ptr<TemporalMemory> tm(new TemporalMemory());
tm->load(is);

return htm;
return tm;
}
));

Expand Down
60 changes: 53 additions & 7 deletions bindings/py/tests/algorithms/spatial_pooler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import unittest
import pytest
import sys
import tempfile
import os

from htm.bindings.sdr import SDR
from htm.algorithms import SpatialPooler as SP
Expand Down Expand Up @@ -166,30 +168,74 @@ def testGetConnectedCountsUint32(self):

@pytest.mark.skipif(sys.version_info < (3, 6), reason="Fails for python2 with segmentation fault")
def testNupicSpatialPoolerPickling(self):
"""Test pickling / unpickling of NuPIC SpatialPooler."""
"""Test pickling / unpickling of HTM SpatialPooler."""
inputs = SDR( 100 ).randomize( .05 )
active = SDR( 100 )
sp = SP( inputs.dimensions, active.dimensions, stimulusThreshold = 1 )

# Simple test: make sure that dumping / loading works...
sp = SP()
pickledSp = pickle.dumps(sp)
for _ in range(10):
sp.compute( inputs, True, active )

if sys.version_info[0] >= 3:
proto = 3
else:
proto = 2

# Simple test: make sure that dumping / loading works...
pickledSp = pickle.dumps(sp, proto)
sp2 = pickle.loads(pickledSp)
self.assertEqual(str(sp), str(sp2), "Simple SpatialPooler pickle/unpickle failed.")


# or using File I/O
f = tempfile.TemporaryFile() # simulates opening a file ('wb')
pickle.dump(sp,f, proto)
f.seek(0)
sp3 = pickle.load(f)
#print(str(sp3))
f.close();
self.assertEqual(str(sp), str(sp3), "File I/O SpatialPooler pickle/unpickle failed.")

self.assertEqual(sp.getNumColumns(), sp2.getNumColumns(),
"Simple NuPIC SpatialPooler pickle/unpickle failed.")


def testNupicSpatialPoolerSavingToString(self):
"""Test writing to and reading from NuPIC SpatialPooler."""
inputs = SDR( 100 ).randomize( .05 )
active = SDR( 100 )
sp = SP( inputs.dimensions, active.dimensions, stimulusThreshold = 1 )

for _ in range(10):
sp.compute( inputs, True, active )

# Simple test: make sure that writing/reading works...
sp = SP()
s = sp.writeToString()

sp2 = SP(columnDimensions=[32, 32])
sp2.loadFromString(s)

self.assertEqual(sp.getNumColumns(), sp2.getNumColumns(),
"NuPIC SpatialPooler write to/read from string failed.")
self.assertEqual(str(sp), str(sp2),
"HTM SpatialPooler write to/read from string failed.")

def testSpatialPoolerSerialization(self):
# Test serializing with saveToFile() and loadFromFile()
inputs = SDR( 100 ).randomize( .05 )
active = SDR( 100 )
sp = SP( inputs.dimensions, active.dimensions, stimulusThreshold = 1 )

for _ in range(10):
sp.compute( inputs, True, active )

#print(str(sp))

# The SP now has some data in it, try serialization.
file = "spatial_pooler_test_save2.bin"
sp.saveToFile(file)
sp3 = SP()
sp3.loadFromFile(file)
self.assertEqual(str(sp), str(sp3), "HTM SpatialPooler serialization (using saveToFile/loadFromFile) failed.")
os.remove(file)


if __name__ == "__main__":
Expand Down
44 changes: 41 additions & 3 deletions bindings/py/tests/algorithms/temporal_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest
import pickle
import sys
import os

from htm.bindings.sdr import SDR
from htm.algorithms import TemporalMemory as TM
Expand All @@ -38,15 +39,52 @@ def testCompute(self):
self.assertTrue( active.getSum() > 0 )


@pytest.mark.skipif(sys.version_info < (3, 6), reason="Fails for python2 with segmentation fault")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this still valid? (we don't test on py2, but someone still may run the repo on python2)

def testNupicTemporalMemoryPickling(self):
"""Test pickling / unpickling of NuPIC TemporalMemory."""

# Simple test: make sure that dumping / loading works...
tm = TM(columnDimensions=(16,))
inputs = SDR( 100 ).randomize( .05 )
tm = TM( inputs.dimensions)
for _ in range(10):
tm.compute( inputs, True)

pickledTm = pickle.dumps(tm)
pickledTm = pickle.dumps(tm, 2)
tm2 = pickle.loads(pickledTm)

self.assertEqual(tm.numberOfCells(), tm2.numberOfCells(),
"Simple NuPIC TemporalMemory pickle/unpickle failed.")


@pytest.mark.skip(reason="Fails with rapidjson internal assertion -- indicates a bad serialization")
def testNupicTemporalMemorySavingToString(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so, this is a new test, that fails - we think exposing an internal JSON/pybind bug. It's fine to skip in that case.
Can you file an issue with pybind11? Is there a newer release we could try?

"""Test writing to and reading from TemporalMemory."""
inputs = SDR( 100 ).randomize( .05 )
tm = TM( inputs.dimensions)
for _ in range(10):
tm.compute( inputs, True)

# Simple test: make sure that writing/reading works...
s = tm.writeToString()

tm2 = TM()
tm2.loadFromString(s)

self.assertEqual(str(tm), str(tm),
"TemporalMemory write to/read from string failed.")

def testNupicTemporalMemorySerialization(self):
# Test serializing with each type of interface.
inputs = SDR( 100 ).randomize( .05 )
tm = TM( inputs.dimensions)
for _ in range(10):
tm.compute( inputs, True)

#print(str(tm))

# The TM now has some data in it, try serialization.
file = "temporalMemory_test_save2.bin"
tm.saveToFile(file)
tm3 = TM()
tm3.loadFromFile(file)
self.assertEqual(str(tm), str(tm3), "TemporalMemory serialization (using saveToFile/loadFromFile) failed.")
os.remove(file)