Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RDSE: serialization for py bindings #608

Merged
merged 11 commits into from
Sep 4, 2019
34 changes: 34 additions & 0 deletions bindings/py/cpp_src/bindings/encoders/py_RDSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

#include <bindings/suppress_register.hpp> //include before pybind11.h
#include <pybind11/pybind11.h>
#include <pybind11/iostream.h>

#include <htm/encoders/RandomDistributedScalarEncoder.hpp>

namespace py = pybind11;

using namespace htm;
using namespace std;

namespace htm_ext
{
Expand Down Expand Up @@ -113,5 +115,37 @@ fields are filled in automatically.)");
self.encode(value, *sdr);
return sdr;
});


// Serialization
// loadFromString
py_RDSE.def("loadFromString", [](RDSE& self, const py::bytes& inString) {
std::stringstream inStream(inString.cast<std::string>());
self.load(inStream);
breznak marked this conversation as resolved.
Show resolved Hide resolved
});

// writeToString
py_RDSE.def("writeToString", [](const RDSE& self) {
std::ostringstream os;
os.flags(ios::scientific);
os.precision(numeric_limits<double>::digits10 + 1);
self.save(os);
return py::bytes( os.str() );
breznak marked this conversation as resolved.
Show resolved Hide resolved
});

// pickle
py_RDSE.def(py::pickle(
[](const RDSE& self) {
std::stringstream ss;
self.save(ss);
return py::bytes( ss.str() );
},
[](py::bytes &s) {
std::stringstream ss( s.cast<std::string>() );
RDSE self;
self.load(ss);
return self;
breznak marked this conversation as resolved.
Show resolved Hide resolved
}));

}
}
breznak marked this conversation as resolved.
Show resolved Hide resolved
44 changes: 41 additions & 3 deletions bindings/py/tests/encoders/rdse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,44 @@ def testSeed(self):
B = R.encode( 987654 )
assert( A != B )

@unittest.skip(reason="Known issue: https://github.com/htm-community/htm.core/issues/160")
def testPickle(self):
assert(False) # TODO: Unimplemented

def testPickle(self):
"""
The pickling is successfull if pickle serializes and de-serialize the
RDSE object.
Moreover, the de-serialized object shall give the same SDR than the
original encoder given the same scalar value to encode.
"""
rdse_params = RDSE_Parameters()
rdse_params.sparsity = 0.1
rdse_params.size = 100
rdse_params.resolution = 0.1
rdse_params.seed = 1997

rdse = RDSE(rdse_params)
filename = "RDSE_testPickle"

try:
with open(filename, "wb") as f:
pickle.dump(rdse, f)
except:
dump_success = False
else:
dump_success = True

assert(dump_success)

try:
with open(filename, "rb") as f:
rdse_loaded = pickle.load(f)
except:
read_success = False
else:
read_success = True

assert(read_success)
value_to_encode = 69003
SDR_original = rdse.encode(value_to_encode)
SDR_loaded = rdse_loaded.encode(value_to_encode)

assert(SDR_original == SDR_loaded)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add a test for saveToFile() and loadFromFile()
See end of spatial_pooler_test.py in branch sp_save.