
Commit fe07d50

ctcyang authored and eric-haibin-lin committed
[MXNET-331] Single machine All Reduce Topology-aware Communication (Updated) (#11591)
* add multiroot all-reduce communication pattern
* fix bug with UpdateWeight
* fix bug where PCI-E links appeared in the weight matrix
* optimization to skip CopyFromTo in ReduceInner gains a bit of throughput
* remove unnecessary if statement
* Add tests
* add more tests, 6 tests left to add
* get rid of some dead code
* Add comments
* Add randomized tests for backtrack and kernighan-lin
* Fix Postprocess
* Add switch for first valid tree when num_gpus > 8, and for maximum weight when num_gpus <= 8
* Kernighan-Lin seems to find better trees
* get rid of printfs
* change defaults
* inherit from CommDevice instead of Comm
* Fix lint errors
* Add Python test using MXNET_KVSTORE_USETREE, fix CMake compilation problem, add header guard
* fix lint errors
* better header guard that works for tests
* get rid of unused variable warning
* retrigger jenkins
* resolve 2 comments
* address comment using Class to do test, get rid of extraneous test, use PCI-E as fallback for GPUs that are not linked by NVLink
* address comments
* fix a few bugs
* get rid of printfs
* get rid of print
* Comment out test for now
* fix 2 more bugs
* fix segfault
* change PrintVector, PrintTopo, PrintMatrix to LOG(INFO) instead of stdout
* Fix code alignment
* get rid of todo
* Make changes to env variable names to indicate they are TREE-related
* Add note saying when ARRAY_BOUND env var takes effect
1 parent 64d2e8b · commit fe07d50

File tree: 9 files changed (+2538 −77 lines)


docs/faq/env_var.md (+26)

```diff
@@ -83,6 +83,32 @@ export MXNET_GPU_WORKER_NTHREADS=3
   - The minimum size of a "big array".
   - When the array size is bigger than this threshold, MXNET_KVSTORE_REDUCTION_NTHREADS threads are used for reduction.
   - This parameter is also used as a load balancer in kvstore. It controls when to partition a single weight to all the servers. If the size of a single weight is less than MXNET_KVSTORE_BIGARRAY_BOUND, it is sent to a single randomly picked server; otherwise, it is partitioned across all the servers.
+
+* MXNET_KVSTORE_USETREE
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If true, MXNet tries to use tree reduction for Push and Pull communication.
+  - Otherwise, MXNet uses the default Push and Pull implementation.
+  - [Tree reduction technology](http://www.sysml.cc/doc/178.pdf) has been shown to be faster than the standard ```--kv-store device``` Push/Pull and ```--kv-store nccl``` Push/Pull for small batch sizes.
+
+* MXNET_KVSTORE_LOGTREE
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If true and MXNET_KVSTORE_USETREE is set to 1, MXNet logs the reduction trees that have been generated.
+
+* MXNET_KVSTORE_TREE_ARRAY_BOUND
+  - Values: Int ```(default=10000000)```
+  - The minimum size of a "big array".
+  - When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load-balance the big gradient being communicated in order to better saturate link bandwidth.
+  - Note: This environment variable only takes effect if the tree KVStore is being used (MXNET_KVSTORE_USETREE=1).
+
+* MXNET_KVSTORE_TREE_BACKTRACK
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If true and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use backtracking to generate the trees required for tree reduction.
+  - If false and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use the Kernighan-Lin heuristic to generate the trees required for tree reduction.
+
+* MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY
+  - Values: Float ```(default=0.7)```
+  - The multiplicative penalty applied to a link's weight once the link has been used, which discourages subsequent trees from routing over the same link.
+
 * MXNET_ENABLE_GPU_P2P
   - Values: 0(false) or 1(true) ```(default=1)```
   - If true, MXNet tries to use GPU peer-to-peer communication, if available on your device,
```
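For reference, a minimal sketch of turning tree reduction on from a training script (the variable names come from the documentation above; everything else is illustrative). Setting the variables in the shell before launching the process works just as well:

```python
import os

# Set before MXNet creates the kvstore; doing it before the mxnet
# import is the safest ordering.
os.environ['MXNET_KVSTORE_USETREE'] = '1'  # enable tree reduction
os.environ['MXNET_KVSTORE_LOGTREE'] = '1'  # log the generated trees

import mxnet as mx

# Tree reduction targets single-machine multi-GPU reduction,
# i.e. the 'device' kvstore.
kv = mx.kv.create('device')
```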

src/kvstore/comm.h (+68 −53)
```diff
@@ -474,6 +474,31 @@ class CommDevice : public Comm {
     }
   }
 
+  const NDArray& ReduceRowSparse(int key, const std::vector<NDArray>& src,
+                                 int priority) {
+    auto& buf = merge_buf_[key];
+    std::vector<NDArray> reduce(src.size());
+
+    const NDArrayStorageType stype = src[0].storage_type();
+    NDArray& buf_merged = buf.merged_buf(stype);
+    if (buf.copy_buf.empty()) {
+      // initialize buffer for copying during reduce
+      buf.copy_buf.resize(src.size());
+      for (size_t j = 0; j < src.size(); ++j) {
+        buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
+      }
+    }
+    CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
+        << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
+        << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
+    for (size_t i = 0; i < src.size(); ++i) {
+      CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+      reduce[i] = buf.copy_buf[i];
+    }
+    ElementwiseSum(reduce, &buf_merged, priority);
+    return buf_merged;
+  }
+
   const NDArray& Reduce(int key, const std::vector<NDArray>& src,
                         int priority) override {
     // when this reduce is called from kvstore_dist, gc is not set
```
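The new ReduceRowSparse helper serves pushes of row_sparse gradients on a device kvstore: each source array is copied into a staging buffer on the merge context and the copies are summed. A hedged sketch of the kind of workflow that lands in this branch (assumes at least two GPUs; the key name and shapes are illustrative):

```python
import mxnet as mx

kv = mx.kv.create('device')
shape = (100, 8)
kv.init('emb_weight', mx.nd.zeros(shape).tostype('row_sparse'))

# Each device pushes a row_sparse gradient; the kvstore stages the
# copies on the merge device and sums them elementwise.
grads = [mx.nd.ones(shape, ctx=mx.gpu(i)).tostype('row_sparse')
         for i in range(2)]
kv.push('emb_weight', grads)

# Pull back only the rows of interest; each entry is the sum (2.0).
out = mx.nd.sparse.zeros('row_sparse', shape, ctx=mx.gpu(0))
row_ids = mx.nd.array([0, 1, 2], ctx=mx.gpu(0))
kv.row_sparse_pull('emb_weight', out=out, row_ids=row_ids)
```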
```diff
@@ -490,13 +515,14 @@
 
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
 
     const NDArrayStorageType stype = src[0].storage_type();
     NDArray& buf_merged = buf.merged_buf(stype);
     // normal dense reduce
     if (stype == kDefaultStorage) {
       CopyFromTo(src[0], &buf_merged, priority);
+
+      std::vector<NDArray> reduce(src.size());
       reduce[0] = buf_merged;
 
       if (buf.copy_buf.empty()) {
@@ -514,24 +540,11 @@
         CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
         reduce[i+1] = buf.copy_buf[i];
       }
+      ElementwiseSum(reduce, &buf_merged, priority);
     } else {
       // sparse reduce
-      if (buf.copy_buf.empty()) {
-        // initialize buffer for copying during reduce
-        buf.copy_buf.resize(src.size());
-        for (size_t j = 0; j < src.size(); ++j) {
-          buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
-        }
-      }
-      CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
-          << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
-          << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
-      for (size_t i = 0; i < src.size(); ++i) {
-        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
-        reduce[i] = buf.copy_buf[i];
-      }
+      buf_merged = ReduceRowSparse(key, src, priority);
     }
-    ElementwiseSum(reduce, &buf_merged, priority);
     return buf_merged;
   }
```
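In the dense branch above, the first source is copied straight into the merge buffer, the remaining sources go into per-source copy buffers, and ElementwiseSum (the same reduction semantics as mx.nd.add_n) folds everything together; the only behavioral change in this hunk is that the sparse case now delegates to ReduceRowSparse. A quick sketch of the observable dense behavior (illustrative key and shape; assumes two GPUs):

```python
import mxnet as mx

kv = mx.kv.create('device')
shape = (2, 3)
kv.init(3, mx.nd.zeros(shape))

# Dense push from each device; Reduce() sums into the merge buffer.
kv.push(3, [mx.nd.ones(shape, ctx=mx.gpu(i)) for i in range(2)])

out = mx.nd.zeros(shape)
kv.pull(3, out=out)
print(out.asnumpy())  # all 2.0: the elementwise sum across devices
```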

```diff
@@ -659,6 +672,42 @@
     }
   }
 
+  using KeyAttrs = std::tuple<int, TShape, int>;
+  // try to allocate buff on device evenly
+  void InitMergeBuffer(const std::vector<Context>& devs) {
+    std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
+              const KeyAttrs& a, const KeyAttrs& b) {
+      return std::get<1>(a).Size() > std::get<1>(b).Size();
+    });
+
+    std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
+    for (auto d : devs) {
+      ctx_info[d.dev_id] = std::make_pair(d, 0);
+    }
+
+    for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
+      const int key = std::get<0>(sorted_key_attrs_[i]);
+      const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
+      const int type = std::get<2>(sorted_key_attrs_[i]);
+      auto& buf = merge_buf_[key];
+      Context ctx;
+      size_t min_size = std::numeric_limits<size_t>::max();
+      for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
+        size_t size = it->second.second;
+        if (size <= min_size) {
+          ctx = it->second.first;
+          min_size = size;
+        }
+      }
+      // Delayed allocation - as the dense merged buffer might not be used at all if push()
+      // only sees sparse arrays
+      bool delay_alloc = true;
+      buf.merged = NDArray(shape, ctx, delay_alloc, type);
+      ctx_info[ctx.dev_id].second += shape.Size();
+    }
+    inited_ = true;
+  }
+
  private:
   void EnableP2P(const std::vector<Context>& devs) {
 #if MXNET_USE_CUDA
```
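InitMergeBuffer is a greedy largest-first placement: keys are sorted by array size in descending order, and each merge buffer is placed on the device with the smallest total size assigned so far. The standalone sketch below reproduces just that policy with hypothetical parameter names and sizes:

```python
# Hypothetical gradient sizes (in elements), keyed by parameter name.
sizes = {'fc1_weight': 4_000_000, 'fc2_weight': 250_000,
         'conv0_weight': 9_000, 'fc1_bias': 1_000}
devices = ['gpu0', 'gpu1']
load = {d: 0 for d in devices}  # total size placed on each device
placement = {}

# Largest arrays first, each to the currently least-loaded device.
for key, size in sorted(sizes.items(), key=lambda kv: kv[1], reverse=True):
    dev = min(devices, key=lambda d: load[d])
    placement[key] = dev
    load[dev] += size

print(placement)
# {'fc1_weight': 'gpu0', 'fc2_weight': 'gpu1',
#  'conv0_weight': 'gpu1', 'fc1_bias': 'gpu1'}
```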
```diff
@@ -702,43 +751,6 @@
 #endif
   }
 
-  using KeyAttrs = std::tuple<int, TShape, int>;
-  // try to allocate buff on device evenly
-  void InitMergeBuffer(const std::vector<Context>& devs) {
-    std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
-              const KeyAttrs& a, const KeyAttrs& b) {
-      return std::get<1>(a).Size() > std::get<1>(b).Size();
-    });
-
-    std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
-    for (auto d : devs) {
-      ctx_info[d.dev_id] = std::make_pair(d, 0);
-    }
-
-    for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
-      const int key = std::get<0>(sorted_key_attrs_[i]);
-      const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
-      const int type = std::get<2>(sorted_key_attrs_[i]);
-      auto& buf = merge_buf_[key];
-      Context ctx;
-      size_t min_size = std::numeric_limits<size_t>::max();
-      for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
-        size_t size = it->second.second;
-        if (size <= min_size) {
-          ctx = it->second.first;
-          min_size = size;
-        }
-      }
-      // Delayed allocation - as the dense merged buffer might not be used at all if push()
-      // only sees sparse arrays
-      bool delay_alloc = true;
-      buf.merged = NDArray(shape, ctx, delay_alloc, type);
-      ctx_info[ctx.dev_id].second += shape.Size();
-    }
-    inited_ = true;
-  }
-
-  std::vector<KeyAttrs> sorted_key_attrs_;
   /// \brief temporal space for pushing and pulling
   struct BufferEntry {
     /// \brief the dense merged value for reduce and broadcast operations
```
```diff
@@ -773,7 +785,10 @@
     NDArray sparse_merged;
   };
   std::unordered_map<int, BufferEntry> merge_buf_;
+
+ public:
   bool inited_;
+  std::vector<KeyAttrs> sorted_key_attrs_;
 };
 
 } // namespace kvstore
```
