Skip to content

Commit

Permalink
fix shared_storage free (apache#11159)
Browse files Browse the repository at this point in the history
* fix shared_storage free

* fix bracket

* make local ref

* cpplint

* fix tests

* fix tests
  • Loading branch information
zhreshold authored and piiswrong committed Jun 6, 2018
1 parent 24804e8 commit 318c689
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def rebuild_ndarray(pid, fd, shape, dtype):

def reduce_ndarray(data):
"""Reduce ndarray to shared memory handle"""
# keep a local ref before duplicating fd
data = data.as_in_context(context.Context('cpu_shared', 0))
pid, fd, shape, dtype = data._to_shared_mem()
if sys.version_info[0] == 2:
fd = multiprocessing.reduction.reduce_handle(fd)
Expand Down
10 changes: 8 additions & 2 deletions src/storage/cpu_shared_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,12 @@ void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
}

if (fid == -1) {
LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
<< strerror(errno);
if (is_new) {
LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
<< strerror(errno);
} else {
LOG(FATAL) << "Invalid file descriptor from shared array.";
}
}

if (is_new) CHECK_EQ(ftruncate(fid, size), 0);
Expand Down Expand Up @@ -216,9 +220,11 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
<< strerror(errno);

#ifdef __linux__
if (handle.shared_id != -1) {
CHECK_EQ(close(handle.shared_id), 0)
<< "Failed to close shared memory. close failed with error "
<< strerror(errno);
}
#else
if (count == 0) {
auto filename = SharedHandleToString(handle.shared_pid, handle.shared_id);
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ def __getitem__(self, idx):
def __len__(self):
return 50

def batchify_list(self, data):
"""
return list of ndarray without stack/concat/pad
"""
if isinstance(data, (tuple, list)):
return list(data)
if isinstance(data, mx.nd.NDArray):
return [data]
return data

def batchify(self, data):
"""
Collate data into batch. Use shared memory for stacking.
Expand Down Expand Up @@ -194,6 +204,14 @@ def batchify(self, data):
print(data)
print('{}:{}'.format(epoch, i))

data = Dummy(True)
loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
for epoch in range(1):
for i, data in enumerate(loader):
if i % 100 == 0:
print(data)
print('{}:{}'.format(epoch, i))

# Entry point: run every test in this module via nose when executed directly.
# NOTE(review): indentation of the two lines under the guard was lost in this
# diff extract — in the real file they are indented inside the `if` block.
if __name__ == '__main__':
import nose
nose.runmodule()
1 change: 0 additions & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,6 @@ def test_norm(ctx=default_context()):
assert arr1.shape == arr2.shape
mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())


# Entry point: run every test in this module via nose when executed directly.
# NOTE(review): indentation of the two lines under the guard was lost in this
# diff extract — in the real file they are indented inside the `if` block.
if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 318c689

Please sign in to comment.