
1) Rename to make the solution more general:
dist_sync_mpi -> dist_sync_allreduce
mpi_collectives -> collectives
MPI_Wrapper -> COLL_Wrapper

2) Add a test for dist_sync_allreduce

3) Fix style issues reported by lint
zhouhaiy committed May 15, 2018
1 parent 4c0b842 commit fcc0338
Showing 13 changed files with 323 additions and 245 deletions.
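A minimal sketch of how the renamed ``dist_sync_allreduce`` kvstore is meant to be used, based on the APIs touched in the diff below. It assumes a build with USE_DIST_KVSTORE = 1 and USE_ALLREDUCE_DIST_KVSTORE = 1, a launch under an MPI launcher such as mpirun -n 2 python train.py, and that ``rank`` reports the MPI rank for this kvstore type; shapes and key names are placeholders, and the snippet is illustrative only, not part of the changed files.

# Illustrative sketch, not part of this commit.
import mxnet as mx

kv = mx.kv.create('dist_sync_allreduce')     # no parameter server is launched

weight = mx.nd.ones((2, 3)) * (kv.rank + 1)  # differs per rank at first
grad = mx.nd.ones((2, 3))
reduced = mx.nd.zeros((2, 3))

# Synchronize the starting point: every rank adopts rank 0's weight.
kv.broadcast('weight', [weight], root_rank=0)

# Aggregate gradients: push/pull are replaced by a single pushpull (allreduce).
kv.pushpull('weight', grad, reduced)
print('rank %d summed gradient:' % kv.rank, reduced.asnumpy())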
24 changes: 12 additions & 12 deletions Makefile
@@ -342,16 +342,16 @@ endif
# for kvstore with type dist_sync_mpi
PROTOBUF_DIR=$(ROOTDIR)/deps
PROTOC=$(PROTOBUF_DIR)/bin/protoc
MPI_COLL_PATH=$(ROOTDIR)/src/kvstore/mpi_collectives
PROTO_GEN_FILE=src/kvstore/mpi_collectives/src/mpi_message.pb.cc
COLL_PATH=$(ROOTDIR)/src/kvstore/collectives
PROTO_GEN_FILE=src/kvstore/collectives/src/mpi_message.pb.cc
DEF_MPI_PATH=$(ROOTDIR)/3rdparty/mpich
ifeq ($(USE_DIST_KVSTORE), 1)
ifeq ($(USE_MPI_DIST_KVSTORE), 1)
ifeq ($(USE_ALLREDUCE_DIST_KVSTORE), 1)
ifeq ($(MPI_ROOT),)
# Default mpi
MPI_ROOT := $(shell ./prepare_mpi.sh $(DEF_MPI_PATH))
endif
CFLAGS += -DMXNET_USE_MPI_DIST_KVSTORE=1 -I$(MPI_ROOT)/include -I$(PROTOBUF_DIR)/include -I$(MPI_COLL_PATH)/include -I$(MPI_COLL_PATH)/src
CFLAGS += -DMXNET_USE_ALLREDUCE_DIST_KVSTORE=1 -I$(MPI_ROOT)/include -I$(PROTOBUF_DIR)/include -I$(COLL_PATH)/include -I$(COLL_PATH)/src
LDFLAGS += -L$(MPI_ROOT)/lib -Wl,-rpath=$(MPI_ROOT)/lib -lmpi
LDFLAGS += -L$(PROTOBUF_DIR)/lib -Wl,-rpath=$(PROTOBUF_DIR)/lib -lprotobuf
endif
@@ -362,11 +362,11 @@

all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages

MPI_SRC = $(wildcard src/kvstore/mpi_collectives/src/*.cc)
MPI_SRC += $(PROTO_GEN_FILE)
MPI_OBJ = $(patsubst %.cc, build/%.o, $(MPI_SRC))
ALLREDUCE_SRC = $(wildcard src/kvstore/collectives/src/*.cc)
ALLREDUCE_SRC += $(PROTO_GEN_FILE)
ALLREDUCE_OBJ = $(patsubst %.cc, build/%.o, $(ALLREDUCE_SRC))

SRC_FILTER = $(MPI_SRC)
SRC_FILTER = $(ALLREDUCE_SRC)
ORIGSRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
SRC = $(filter-out $(SRC_FILTER), $(ORIGSRC))
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
@@ -450,8 +450,8 @@ else
endif

ifeq ($(USE_DIST_KVSTORE), 1)
ifeq ($(USE_MPI_DIST_KVSTORE), 1)
ALL_DEP += $(MPI_OBJ)
ifeq ($(USE_ALLREDUCE_DIST_KVSTORE), 1)
ALL_DEP += $(ALLREDUCE_OBJ)
endif
endif

@@ -487,7 +487,7 @@ build/plugin/%.o: plugin/%.cc
@mkdir -p $(@D)
$(CXX) -std=c++11 -c $(CFLAGS) -MMD -Isrc/operator -c $< -o $@

build/src/kvstore/mpi_collectives/src/%.o: $(MPI_COLL_PATH)/src/%.cc $(PROTO_GEN_FILE)
build/src/kvstore/collectives/src/%.o: $(COLL_PATH)/src/%.cc $(PROTO_GEN_FILE)
@mkdir -p $(@D)
$(CXX) -std=c++11 -c $(CFLAGS) -MMD -c $< -o $@

@@ -515,7 +515,7 @@ PSLITE:
$(MAKE) CXX=$(CXX) DEPS_PATH=$(DEPS_PATH) -C $(PS_PATH) ps

$(PROTO_GEN_FILE): PSLITE
$(PROTOC) --cpp_out=$(MPI_COLL_PATH)/src --proto_path=$(MPI_COLL_PATH)/src $(MPI_COLL_PATH)/src/mpi_message.proto
$(PROTOC) --cpp_out=$(COLL_PATH)/src --proto_path=$(COLL_PATH)/src $(COLL_PATH)/src/mpi_message.proto

$(DMLC_CORE)/libdmlc.a: DMLCCORE

4 changes: 2 additions & 2 deletions make/config.mk
@@ -152,8 +152,8 @@ USE_F16C =
# whether or not to enable multi-machine supporting
USE_DIST_KVSTORE = 0

# whether or not to enable kvstore with type dist_sync_mpi
USE_MPI_DIST_KVSTORE = 0
# whether or not to enable kvstore with type dist_sync_allreduce
USE_ALLREDUCE_DIST_KVSTORE = 0

# mpi library root directory, mpi_collectives will depend
# upon $(MPI_ROOT)/include $(MPI_ROOT)/lib, user need to
4 changes: 2 additions & 2 deletions python/mxnet/gluon/trainer.py
@@ -119,7 +119,7 @@ def _init_kvstore(self):
# optimizer preferably needs to be set before init for multiprecision
for i, param in enumerate(self._params):
param_arrays = param.list_data()
if 'mpi' not in kvstore.type:
if 'allreduce' not in kvstore.type:
kvstore.init(i, param_arrays[0])
kvstore.pull(i, param_arrays, priority=-i)
else:
@@ -191,7 +191,7 @@ def step(self, batch_size, ignore_stale_grad=False):
%(param.name, str(data.context)))

if self._kvstore:
if 'mpi' not in self._kvstore.type:
if 'allreduce' not in self._kvstore.type:
self._kvstore.push(i, param.list_grad(), priority=-i)
if self._update_on_kvstore:
self._kvstore.pull(i, param.list_data(), priority=-i)
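For orientation, a minimal Gluon sketch that exercises the branches above, assuming the same allreduce build and MPI launch: parameter initialization goes through broadcast instead of init/pull, and step() aggregates gradients with pushpull instead of push. The tiny network and random data are placeholders, not part of the commit.

# Illustrative sketch, not part of this commit.
import mxnet as mx
from mxnet import autograd, gluon

kv = mx.kv.create('dist_sync_allreduce')
net = gluon.nn.Dense(1)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1}, kvstore=kv)

x = mx.nd.random.uniform(shape=(4, 2))
y = mx.nd.random.uniform(shape=(4, 1))
loss_fn = gluon.loss.L2Loss()

with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(batch_size=4)   # gradients are summed across ranks via pushpull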
32 changes: 16 additions & 16 deletions python/mxnet/kvstore.py
@@ -166,7 +166,7 @@ def push(self, key, value, priority=0):
There is no synchronization between workers.
One can use ``_barrier()`` to sync all workers.
Note: This api is not supported for kvstore with type dist_sync_mpi.
Note: This api is not supported for kvstore with type dist_sync_allreduce.
Use :py:meth:`pushpull` instead.
Parameters
@@ -229,7 +229,7 @@ def push(self, key, value, priority=0):
>>> print b
<RowSparseNDArray 2x3 @cpu(0)>
"""
if self.type != 'dist_sync_mpi':
if self.type != 'dist_sync_allreduce':
ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
if use_str_keys:
check_call(_LIB.MXKVStorePushEx(
@@ -256,7 +256,7 @@ def pull(self, key, out=None, priority=0):
For `RowSparseNDArray` values, this call is ignored,
please use ``row_sparse_pull`` instead.
Note: This api is not supported for kvstore with type dist_sync_mpi.
Note: This api is not supported for kvstore with type dist_sync_allreduce.
Use :py:meth:`pushpull` instead.
Parameters
@@ -305,7 +305,7 @@ def pull(self, key, out=None, priority=0):
[ 2. 2. 2.]]
"""
assert(out is not None)
if self.type != 'dist_sync_mpi':
if self.type != 'dist_sync_allreduce':
ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
if use_str_keys:
check_call(_LIB.MXKVStorePullEx(
@@ -324,7 +324,7 @@ def pushpull(self, key, ins, outs, priority=0):
thread. The rank 0 node will collect allreduce request info from all nodes and ensure
that the allreduce execution order is the same across all nodes.
Note: This api is only supported for kvstore with type dist_sync_mpi
Note: This api is only supported for kvstore with type dist_sync_allreduce
Parameters
----------
@@ -360,7 +360,7 @@ def pushpull(self, key, ins, outs, priority=0):
[[ 2. 2. 2.]
[ 2. 2. 2.]]
"""
if self.type == 'dist_sync_mpi':
if self.type == 'dist_sync_allreduce':
ckeys, cinvals, use_str_keys = _ctype_key_value(key, ins)
ckeys, coutvals, use_str_keys = _ctype_key_value(key, outs)
if use_str_keys:
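A minimal sketch of ``pushpull`` itself, mirroring the two-worker example in the docstring above: each rank contributes ``ins``, the background thread carries out the allreduce (with rank 0 keeping the execution order consistent), and the element-wise sum over all ranks lands in ``outs``. Assumes an mpirun -n 2 launch of the allreduce build; not part of the commit.

# Illustrative sketch, not part of this commit; assumes `mpirun -n 2 ...`.
import mxnet as mx

kv = mx.kv.create('dist_sync_allreduce')
ins = mx.nd.ones((2, 3))
outs = mx.nd.zeros((2, 3))
kv.pushpull('3', ins, outs, priority=0)
print(outs.asnumpy())        # with 2 ranks: every element equals 2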
@@ -381,7 +381,7 @@ def broadcast(self, key, values, root_rank, priority=0):
This function returns immediately after sending a broadcast request to the mpi background
thread. In mpi background thread, it will invoke MPI_Bcast in every node.
Note: This api is only supported for kvstore with type dist_sync_mpi
Note: This api is only supported for kvstore with type dist_sync_allreduce
Parameters
----------
@@ -408,7 +408,7 @@ def broadcast(self, key, values, root_rank, priority=0):
>>> [[ 2. 2. 2.]
[ 2. 2. 2.]]
"""
if self.type == 'dist_sync_mpi':
if self.type == 'dist_sync_allreduce':
ckeys, cinvals, use_str_keys = _ctype_key_value(key, values)
if use_str_keys:
check_call(_LIB.MXKVStoreBroadcastEx(
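A matching sketch for ``broadcast``: the values held by ``root_rank`` overwrite every other rank's copy (MPI_Bcast underneath), which is how initial parameters are synchronized before training. Assumes an MPI launch of the allreduce build; not part of the commit.

# Illustrative sketch, not part of this commit; assumes an MPI launch.
import mxnet as mx

kv = mx.kv.create('dist_sync_allreduce')
val = mx.nd.full((2, 3), kv.rank)       # differs per rank before the call
kv.broadcast('w', [val], root_rank=0)   # rank 0's zeros replace every copy
print(val.asnumpy())                    # all zeros on every rank afterwards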
@@ -432,7 +432,7 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
The returned values are guaranteed to be the latest values in the store.
Note: This api is not supported for kvstore with type dist_sync_mpi
Note: This api is not supported for kvstore with type dist_sync_allreduce
Parameters
----------
@@ -477,7 +477,7 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
"""
assert(out is not None)
assert(row_ids is not None)
if self.type == 'dist_sync_mpi':
if self.type == 'dist_sync_allreduce':
raise Exception("This api is not supported for kvstore with type %s"%self.type)
if isinstance(row_ids, NDArray):
row_ids = [row_ids]
@@ -553,7 +553,7 @@ def set_gradient_compression(self, compression_params):
Other keys in this dictionary are optional and specific to the type
of gradient compression.
"""
if ('device' in self.type) or ('dist' in self.type) and ('mpi' not in self.type): # pylint: disable=unsupported-membership-test
if ('device' in self.type) or ('dist' in self.type) and ('allreduce' not in self.type): # pylint: disable=unsupported-membership-test
ckeys, cvals = _ctype_dict(compression_params)
check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
mx_uint(len(compression_params)),
@@ -568,7 +568,7 @@ def set_optimizer(self, optimizer):
If using multiple machines and this operation is invoked from a worker node,
it will serialize the optimizer with pickle and send it to all servers.
The function returns after all servers have been updated.
In kvstore with dist_sync_mpi, this api only updates the local optimizer
In kvstore with dist_sync_allreduce, this api only updates the local optimizer
same as single machine.
Parameters
@@ -745,7 +745,7 @@ def _send_command_to_servers(self, head, body):
body : str
the body of the command.
"""
if self.type == 'dist_sync_mpi':
if self.type == 'dist_sync_allreduce':
raise Exception("This api is not supported for kvstore with type %s"%self.type)
else:
check_call(_LIB.MXKVStoreSendCommmandToServers(
@@ -777,13 +777,13 @@ def create(name='local'):
No two updates happen on the same weight at the same time. However, the order is not
guaranteed.
``dist_sync_mpi``: Behaves similarly to dist_sync but with some major difference.
With ``dist_sync_mpi``, no parameter server configured, replace push and pull apis with
``dist_sync_allreduce``: Behaves similarly to dist_sync but with one major difference.
With ``dist_sync_allreduce``, no parameter server is configured; the push and pull apis are replaced by
pushpull.
Parameters
----------
name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async'}
name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async', 'dist_sync_allreduce'}
The type of KVStore.
Returns
-------
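Because no parameter server is configured, each process locates itself through the kvstore; below is a sketch of sharding data by rank, assuming ``rank`` and ``num_workers`` report the MPI rank and world size for this kvstore type (not shown in this diff). Not part of the commit.

# Illustrative sketch, not part of this commit.
import mxnet as mx

kv = mx.kv.create('dist_sync_allreduce')
X = mx.nd.arange(100).reshape((50, 2))

per_rank = X.shape[0] // kv.num_workers   # assumes num_workers == MPI world size
start = kv.rank * per_rank
shard = X[start:start + per_rank]         # contiguous slice owned by this rank
print('rank %d/%d owns %d rows' % (kv.rank, kv.num_workers, shard.shape[0]))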
19 changes: 8 additions & 11 deletions python/mxnet/model.py
@@ -73,7 +73,7 @@ def _create_kvstore(kvstore, num_device, arg_params):
kv = None
elif isinstance(kvstore, kvs.KVStore):
kv = kvstore
if kv.type == 'dist_sync_mpi':
if kv.type == 'dist_sync_allreduce':
update_on_kvstore = False
elif isinstance(kvstore, str):
# create kvstore using the string type
@@ -88,7 +88,7 @@ def _create_kvstore(kvstore, num_device, arg_params):
arg_params.values())
if max_size > 1024 * 1024 * 16:
update_on_kvstore = False
if kvstore == 'dist_sync_mpi':
if kvstore == 'dist_sync_allreduce':
update_on_kvstore = False
else:
raise TypeError('kvstore must be KVStore, str or None')
@@ -102,7 +102,7 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_o
"""Initialize kvstore"""
for idx, param_on_devs in enumerate(param_arrays):
name = param_names[idx]
if 'mpi' not in kvstore.type:
if 'allreduce' not in kvstore.type:
kvstore.init(name, arg_params[name])
else:
kvstore.broadcast(name, param_on_devs, 0, priority=-idx)
@@ -137,13 +137,10 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
if grad_list[0] is None:
continue
name = param_names[index]
if 'mpi' not in kvstore.type:
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the weights
kvstore.pull(name, arg_list, priority=-index)
else:
kvstore.pushpull(name, grad_list, grad_list, priority=-index)
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the weights
kvstore.pull(name, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None, param_names=None):
@@ -155,7 +152,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
index = i
if kvstore:
name = param_names[index]
if 'mpi' not in kvstore.type:
if 'allreduce' not in kvstore.type:
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the sum gradients, to the same locations.
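To show where the model.py helpers above come into play, a minimal Module ``fit`` sketch, assuming the allreduce build and an MPI launch: ``_create_kvstore`` forces ``update_on_kvstore`` off for this type, ``_initialize_kvstore`` broadcasts the initial parameters, and the update helpers aggregate gradients through the kvstore (the allreduce branch is truncated in this view) before the local optimizer step. The toy symbol and data are placeholders; not part of the commit.

# Illustrative sketch, not part of this commit; assumes an MPI launch of a
# build with USE_ALLREDUCE_DIST_KVSTORE=1.
import mxnet as mx

kv = mx.kv.create('dist_sync_allreduce')

data = mx.io.NDArrayIter(mx.nd.random.uniform(shape=(32, 10)),
                         mx.nd.array([0, 1] * 16), batch_size=8)

x = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(x, num_hidden=2)
net = mx.sym.SoftmaxOutput(fc, name='softmax')

mod = mx.mod.Module(net)
mod.fit(data, num_epoch=1, optimizer='sgd',
        optimizer_params={'learning_rate': 0.1}, kvstore=kv)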
src/kvstore/collectives/include/coll_util.h (renamed from src/kvstore/mpi_collectives/include/mpi_util.h)
@@ -21,21 +21,21 @@
* Copyright (c) 2018 by Contributors
*/

#ifndef MXNET_MPI_COLLECTIVES_INCLUDE_MPI_UTIL_H_
#define MXNET_MPI_COLLECTIVES_INCLUDE_MPI_UTIL_H_
#ifndef MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_UTIL_H_
#define MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_UTIL_H_

#if MXNET_USE_MPI_DIST_KVSTORE
#if MXNET_USE_ALLREDUCE_DIST_KVSTORE


#include <stdio.h>
#include <vector>

#define MPI_UTIL_DEBUG_ON 0
#define COLL_UTIL_DEBUG_ON 0

#if MPI_UTIL_DEBUG_ON
#define MXMPI_DEBUG(rank, fmt, args...) printf("rank[%d]:" fmt, rank, ## args)
#if COLL_UTIL_DEBUG_ON
#define MXCOLL_DEBUG(rank, fmt, args...) printf("rank[%d]:" fmt, rank, ## args)
#else
#define MXMPI_DEBUG(fmt, args...)
#define MXCOLL_DEBUG(fmt, args...)
#endif

/****************************************************
@@ -65,4 +65,4 @@ size_t countIDX(const std::vector<T> &vec,
}

#endif
#endif // MXNET_MPI_COLLECTIVES_INCLUDE_MPI_UTIL_H_
#endif // MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_UTIL_H_
src/kvstore/collectives/include/coll_wrapper.h (renamed from src/kvstore/mpi_collectives/include/mpi_wrapper.h)
@@ -21,10 +21,10 @@
* Copyright (c) 2018 by Contributors
*/

#ifndef MXNET_MPI_COLLECTIVES_INCLUDE_MPI_WRAPPER_H_
#define MXNET_MPI_COLLECTIVES_INCLUDE_MPI_WRAPPER_H_
#ifndef MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_WRAPPER_H_
#define MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_WRAPPER_H_

#if MXNET_USE_MPI_DIST_KVSTORE
#if MXNET_USE_ALLREDUCE_DIST_KVSTORE

#include <mpi.h>

@@ -51,7 +51,7 @@ MPI_Datatype MPI_Data_Type_Cast<double>(void) {
}

template <class xpu, class DType>
struct MPI_Wrapper {
struct COLL_Wrapper {
static int Broadcast(mxnet::NDArray *input_array,
int root_rank) {
return 0; }
@@ -63,7 +63,7 @@ struct MPI_Wrapper {

// CPU Implementation
template <class DType>
struct MPI_Wrapper<mxnet::cpu, DType> {
struct COLL_Wrapper<mxnet::cpu, DType> {
static int Broadcast(mxnet::NDArray *input_array,
int root_rank) {
DType *buf = reinterpret_cast<DType *>(input_array->data().dptr<DType>());
@@ -94,21 +94,21 @@ struct MPI_Wrapper<mxnet::cpu, DType> {

// GPU Implementation
template <class DType>
struct MPI_Wrapper<mxnet::gpu, DType> {
struct COLL_Wrapper<mxnet::gpu, DType> {
static int Broadcast(mxnet::NDArray *input_array,
int root_rank) {
// TODO(zhouhaiy): implement gpu broadcast
LOG(FATAL) << "MPI For GPU version has not been implemented.";
LOG(FATAL) << "Collective For GPU version has not been implemented.";
return -1;
}

static int AllReduce(mxnet::NDArray *input_array,
mxnet::NDArray *output_array) {
// TODO(zhouhaiy): implement gpu all reduce
LOG(FATAL) << "MPI For GPU version has not been implemented.";
LOG(FATAL) << "Collective For GPU version has not been implemented.";
return -1;
}
};

#endif
#endif // MXNET_MPI_COLLECTIVES_INCLUDE_MPI_WRAPPER_H_
#endif // MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_WRAPPER_H_
The remaining changed files are not shown in this view.
