From 318c6899031d024acb28a4937678b85f141f98d8 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Tue, 5 Jun 2018 22:57:51 -0700 Subject: [PATCH] fix shared_storage free (#11159) * fix shared_storage free * fix bracket * make local ref * cpplint * fix tests * fix tests --- python/mxnet/gluon/data/dataloader.py | 2 ++ src/storage/cpu_shared_storage_manager.h | 10 ++++++++-- tests/python/unittest/test_gluon_data.py | 18 ++++++++++++++++++ tests/python/unittest/test_ndarray.py | 1 - 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 151b49d457aa..29b9b81aca04 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -57,6 +57,8 @@ def rebuild_ndarray(pid, fd, shape, dtype): def reduce_ndarray(data): """Reduce ndarray to shared memory handle""" + # keep a local ref before duplicating fd + data = data.as_in_context(context.Context('cpu_shared', 0)) pid, fd, shape, dtype = data._to_shared_mem() if sys.version_info[0] == 2: fd = multiprocessing.reduction.reduce_handle(fd) diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h index 85c6a352afdb..a52d779d2318 100644 --- a/src/storage/cpu_shared_storage_manager.h +++ b/src/storage/cpu_shared_storage_manager.h @@ -174,8 +174,12 @@ void CPUSharedStorageManager::Alloc(Storage::Handle* handle) { } if (fid == -1) { - LOG(FATAL) << "Failed to open shared memory. shm_open failed with error " - << strerror(errno); + if (is_new) { + LOG(FATAL) << "Failed to open shared memory. shm_open failed with error " + << strerror(errno); + } else { + LOG(FATAL) << "Invalid file descriptor from shared array."; + } } if (is_new) CHECK_EQ(ftruncate(fid, size), 0); @@ -216,9 +220,11 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) { << strerror(errno); #ifdef __linux__ + if (handle.shared_id != -1) { CHECK_EQ(close(handle.shared_id), 0) << "Failed to close shared memory. close failed with error " << strerror(errno); + } #else if (count == 0) { auto filename = SharedHandleToString(handle.shared_pid, handle.shared_id); diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 93160aa0940c..751886b8e7f2 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -140,6 +140,16 @@ def __getitem__(self, idx): def __len__(self): return 50 + def batchify_list(self, data): + """ + return list of ndarray without stack/concat/pad + """ + if isinstance(data, (tuple, list)): + return list(data) + if isinstance(data, mx.nd.NDArray): + return [data] + return data + def batchify(self, data): """ Collate data into batch. Use shared memory for stacking. @@ -194,6 +204,14 @@ def batchify(self, data): print(data) print('{}:{}'.format(epoch, i)) + data = Dummy(True) + loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2) + for epoch in range(1): + for i, data in enumerate(loader): + if i % 100 == 0: + print(data) + print('{}:{}'.format(epoch, i)) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 496f80f927f6..a0604658ee14 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1304,7 +1304,6 @@ def test_norm(ctx=default_context()): assert arr1.shape == arr2.shape mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy()) - if __name__ == '__main__': import nose nose.runmodule()