diff --git a/src/nupic/experimental/GridUniqueness.cpp b/src/nupic/experimental/GridUniqueness.cpp index aad820fd..e4e45be6 100644 --- a/src/nupic/experimental/GridUniqueness.cpp +++ b/src/nupic/experimental/GridUniqueness.cpp @@ -53,6 +53,68 @@ namespace bg = boost::geometry; BOOST_GEOMETRY_REGISTER_BOOST_TUPLE_CS(cs::cartesian) static std::atomic g_quitting(false); +static std::atomic g_computeBinSidelengthShouldContinue(true); + +static std::atomic g_computeUniqueHypercubeCounter(0); +static std::atomic g_computeBinsSideLengthCounter(0); +static std::atomic g_captureInterruptsCounter(0); + +static void (*g_prevHandler)(int) = nullptr; + + +// Custom interrupt processing is particularly necessary in Jupyter notebooks. +void processInterrupt(int sig) +{ + if (g_computeUniqueHypercubeCounter > 0) + { + g_quitting = true; + } + + if (g_computeBinsSideLengthCounter > 0) + { + g_computeBinSidelengthShouldContinue = false; + } +} + +class CaptureInterruptsRAII +{ +public: + CaptureInterruptsRAII() + { + if (g_captureInterruptsCounter++ == 0) + { + g_prevHandler = signal(SIGINT, processInterrupt); + } + } + + ~CaptureInterruptsRAII() + { + if (--g_captureInterruptsCounter == 0) + { + signal(SIGINT, g_prevHandler); + g_prevHandler = nullptr; + } + } +}; + +class IncrementRAII +{ +public: + IncrementRAII(std::atomic* counter) + : counter_(counter) + { + ++(*counter_); + } + + ~IncrementRAII() + { + --(*counter_); + } + +private: + std::atomic* counter_; +}; + template struct SquareMatrix2D { @@ -1345,14 +1407,13 @@ nupic::experimental::grid_uniqueness::computeGridUniquenessHypercube( { typedef std::chrono::steady_clock Clock; - // Manually handle interrupts so that they're handled when running in a - // Jupyter notebook, and to make the threads return cleanly. - struct sigaction sigIntHandler; - sigIntHandler.sa_handler = - [](int s) { g_quitting = true; }; - sigemptyset(&sigIntHandler.sa_mask); - sigIntHandler.sa_flags = 0; - sigaction(SIGINT, &sigIntHandler, NULL); + if (g_computeUniqueHypercubeCounter == 0) + { + // Recover from any previous interrupts. + g_quitting = false; + } + IncrementRAII incrementCounter(&g_computeUniqueHypercubeCounter); + CaptureInterruptsRAII captureInterrupts; NTA_CHECK(domainToPlaneByModule.size() == latticeBasisByModule.size()) << "The two arrays of matrices must be the same length (one per module) " @@ -1565,9 +1626,6 @@ nupic::experimental::grid_uniqueness::computeGridUniquenessHypercube( if (g_quitting) { - // The process might not be ending, the caller (e.g. the Python shell) is - // likely to catch this exception and continue, so prepare to run again. - g_quitting = false; NTA_THROW << "Caught interrupt signal"; } @@ -1747,12 +1805,12 @@ bool findGridCodeZero_noModulo( const vector& x0, const vector& dims, double readoutResolution, + std::atomic& shouldContinue, vector* pointWithGridCodeZero = nullptr) { // Avoid doing any allocations in each recursion. vector x0Copy(x0); vector dimsCopy(dims); - std::atomic shouldContinue(true); vector defaultPointBuffer; @@ -1790,27 +1848,29 @@ bool findGridCodeZero_noModulo( bool findGridCodeZeroAtRadius( double radius, const vector>>& domainToPlaneByModule, - double readoutResolution) + double readoutResolution, + std::atomic& shouldContinue) { const size_t numDims = domainToPlaneByModule[0][0].size(); for (size_t iDim = 0; iDim < numDims; ++iDim) { // Test the hyperplanes formed by setting this dimension to r and -r. - vector x0(numDims, -radius); vector dims(numDims, 2*radius); dims[iDim] = 0; if (findGridCodeZero_noModulo(domainToPlaneByModule, - x0, dims, readoutResolution)) + x0, dims, readoutResolution, + shouldContinue)) { return true; } x0[iDim] = radius; if (findGridCodeZero_noModulo(domainToPlaneByModule, - x0, dims, readoutResolution)) + x0, dims, readoutResolution, + shouldContinue)) { return true; } @@ -1826,15 +1886,24 @@ nupic::experimental::grid_uniqueness::computeBinSidelength( double resultPrecision, double upperBound) { + if (g_computeBinsSideLengthCounter == 0) + { + // Recover from any previous interrupts. + g_computeBinSidelengthShouldContinue = true; + } + IncrementRAII incrementCounter(&g_computeBinsSideLengthCounter); + CaptureInterruptsRAII captureInterrupts; + double tested = 0; double radius = 0.5; while (findGridCodeZeroAtRadius(radius, domainToPlaneByModule, - readoutResolution)) + readoutResolution, + g_computeBinSidelengthShouldContinue)) { tested = radius; - radius += 0.5; + radius *= 2; if (radius > upperBound) { @@ -1849,13 +1918,14 @@ nupic::experimental::grid_uniqueness::computeBinSidelength( double dec = (radius - tested) / 2; // The possible error is equal to dec*2. - while (dec*2 > resultPrecision2) + while (g_computeBinSidelengthShouldContinue && dec*2 > resultPrecision2) { const double testRadius = radius - dec; if (!findGridCodeZeroAtRadius(testRadius, domainToPlaneByModule, - readoutResolution)) + readoutResolution, + g_computeBinSidelengthShouldContinue)) { radius = testRadius; } @@ -1863,5 +1933,10 @@ nupic::experimental::grid_uniqueness::computeBinSidelength( dec /= 2; } + if (!g_computeBinSidelengthShouldContinue) + { + NTA_THROW << "Caught interrupt signal"; + } + return 2 * radius; }