This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-331] Single machine All Reduce Topology-aware Communication (Updated) #11591

Merged
50 commits merged on Jul 24, 2018
Commits (50); the diff below shows the changes from 1 commit.
9678143  add multiroot all-reduce communication pattern (Jun 4, 2018)
d5e51d6  fix bug with UpdateWeight (Jun 4, 2018)
0708dbc  fix PCI-E links appearing in weight matrix bug (Jun 4, 2018)
5590920  optimization to skip CopyFromTo in ReduceInner gains a bit of throughput (Jun 4, 2018)
4f8f58b  remove unnecessary if statement (Jun 5, 2018)
908534a  Add tests (Jun 15, 2018)
25cbbdc  add more tests, 6 tests left to add (Jun 16, 2018)
310ee4d  get rid of some dead code (Jun 16, 2018)
9cce8ea  Add comments (Jun 18, 2018)
4d2790d  Add randomized tests for backtrack and kernighan-lin (Jun 18, 2018)
b5b42bc  Fix Postprocess (Jun 18, 2018)
6327ceb  Add switch for first valid tree when num_gpus > 8, and for maximum we… (Jun 18, 2018)
8694fe7  Kernighan-Lin seems to find better trees (Jun 18, 2018)
c6cd67a  get rid of printfs (Jun 20, 2018)
7466c4d  change defaults (Jun 21, 2018)
153ec0b  Merge branch 'feature_multirootv9' of https://github.com/ctcyang/incu… (Jun 21, 2018)
7c61b6c  Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (Jun 21, 2018)
cc935a2  inherit from CommDevice instead of Comm (Jun 22, 2018)
ba60aaa  Fix lint errors (Jun 22, 2018)
972e9c0  Add Python test using MXNET_KVSTORE_USETREE, fix CMake compilation pr… (Jun 27, 2018)
6627dcf  fix lint errors (Jun 27, 2018)
4de89a7  better header guard that works for tests (Jun 27, 2018)
317c66b  get rid of unused variable warning (Jun 27, 2018)
c364fd3  retrigger jenkins (Jun 28, 2018)
3241d71  resolve 2 comments (Jun 29, 2018)
bd926bf  address comment using Class to do test, get rid of extraneous test, u… (Jul 2, 2018)
0e1a704  resolve merge conflicts (Jul 2, 2018)
47b0b63  Merge remote-tracking branch 'apache/master' into feature_multirootv9 (Jul 5, 2018)
781a7fe  Merge remote-tracking branch 'apache/master' into feature_multirootv9… (Jul 6, 2018)
a29f284  address comments (Jul 13, 2018)
b310ab4  Merge branch 'feature_multirootv9merge2' into feature_multirootv9merge (Jul 13, 2018)
abcb10e  Merge remote-tracking branch 'apache/master' into feature_multirootv9… (Jul 13, 2018)
24b9c62  Merge remote-tracking branch 'apache/master' into feature_multirootv9… (Jul 20, 2018)
7d0da7b  Merge remote-tracking branch 'apache/master' into feature_multirootv9… (Jul 20, 2018)
18c1700  fix a few bugs (Jul 21, 2018)
c65a620  get rid of printfs (Jul 21, 2018)
a70b1b8  Merge branch 'feature_multirootv9merge3' into feature_multirootv9 (Jul 21, 2018)
263a4cb  Merge remote-tracking branch 'apache/master' into feature_multirootv9 (Jul 21, 2018)
628ba6e  get rid of print (Jul 21, 2018)
b3f3235  Merge branch 'feature_multirootv9' into feature_multirootv9merge (Jul 21, 2018)
a0e1366  Comment out test for now (Jul 23, 2018)
63fd14e  fix 2 more bugs (Jul 23, 2018)
6c0bff8  Merge branch 'feature_multirootv9merge3' into feature_multirootv9merge (Jul 23, 2018)
9f5c24a  fix segfault (Jul 23, 2018)
9cc24d0  change PrintVector, PrintTopo, PrintMatrix to LOG(INFO) instead of st… (Jul 24, 2018)
691d5ac  Merge branch 'feature_multiv9merge4' into feature_multirootv9merge (Jul 24, 2018)
67b0db0  Fix code alignment (Jul 24, 2018)
c8ebb87  get rid of todo (Jul 24, 2018)
5f7da5e  Make changes to env variable names to indicate they are TREE-related (Jul 24, 2018)
16b8fb4  Add note saying when ARRAY_BOUND env var takes effect (Jul 24, 2018)
address comments
Carl Yang committed Jul 13, 2018
commit a29f284e1f3dac34ff32af4aeb091b799332e910
45 changes: 25 additions & 20 deletions src/kvstore/comm_tree.h
@@ -83,9 +83,14 @@ class CommDeviceTree : public CommDevice {
}
}

- // src is sliced shape
- // copy_buf not sliced
- // merged not sliced
+ /**
+  * \brief Reduce src to tree_merge_buf_
+  * \param key is the id of the gradient we are doing Reduce on
+  * \param src is the array of values located on different GPUs
+  * \param root is the id of the GPU we want to send result of reduce to
+  * \param merged_row is the id of the slice we are taking
+  * \param priority the priority of the operation
+  */
const NDArray& ReduceInner(int key, const std::vector<NDArray>& src, int root,
Member:
Could you add documentation for this function? e.g. what's merged_row?

int merged_row, int priority) {
std::vector<std::vector<NDArray>> reduce(devs_.size());
@@ -98,8 +103,8 @@ class CommDeviceTree : public CommDevice {
if (stype == kDefaultStorage) {
// Copy everything into buf.merged for each gpu
Member:
Why is copying to buf.merged required? Can we avoid this copy?

Contributor Author (ctcyang):
We can avoid this copy, but it makes the code harder to read. Since the source and destination belong to the same GPU, the gain is minimal (10 samples/s i.e. 720 samples/s to 730 samples/s).
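A minimal sketch of the trade-off described in this reply, assuming only the public MXNet Python NDArray API (mx.nd.add_n stands in for the C++ ElementwiseSum path; the context and shapes are illustrative):

import mxnet as mx

ctx = mx.cpu()  # the real code operates on buffers that live on the same GPU
src = [mx.nd.ones((4, 4), ctx=ctx) * i for i in range(3)]

# (a) copy each source slice into a merge buffer first, then reduce:
#     one extra intra-device copy, but simpler bookkeeping (the path kept here)
buf_merged = [s.copyto(ctx) for s in src]
reduced_a = mx.nd.add_n(*buf_merged)

# (b) reduce directly from src, skipping the copy: same result,
#     roughly the 720 -> 730 samples/s difference quoted above
reduced_b = mx.nd.add_n(*src)

assert (reduced_a == reduced_b).asnumpy().all()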

Reviewer:
@ctcyang just asking: how did you test the throughput (as you mentioned, 730 samples/s)?

ctcyang (Contributor, Author), Jul 13, 2018:
I tested the throughput difference on a similar intra-GPU CopyFromTo by commenting out Line 306 in comm_tree.h. This ruins the correctness of the output, but it gives an idea of how much savings can be gotten. Back then, I was considering using a combined PushPull API, so that the dst array is known at the time of push. This saves one CopyFromTo during Broadcast, from the temporary buffer buf.merged to the dst array.

Testing VGG-16 on an older commit, I got the following. The geomean difference across these batch sizes suggests that getting rid of one CopyFromTo makes it 2.2% faster. If you two think these two optimizations are worthwhile, (i) eliminating the copy from buf.merged in reduce, and (ii) eliminating the copy from buf.merged to dst in broadcast by using the PushPull interface, I can add them as a separate PR after this one is accepted, because the PushPull interface depends on Intel's C API addition (#10696).


v6: Push, Pull interface - One more intra-GPU CopyFromTo than v7
v7: PushPull combined AllReduce interface
BS: Batch size per GPU (8 GPUs total)

Throughput (samples/s)

fp32
BS | v6   | v7
----------------
4  | 711  | 745
8  | 999  | 1035
16 | 1449 | 1478
32 | 1638 | 1672
64 | 1695 | 1739

fp16
BS | v6   | v7
----------------
8  | 1552 | 1599
16 | 2127 | 2163
32 | 2916 | 2910
64 | 2720 | 2775
128| 2518 | 2532
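For reference, the 2.2% geomean figure quoted above can be recomputed from this table with a few lines of Python (v7/v6 speedups across both precisions):

from math import prod

fp32 = [(711, 745), (999, 1035), (1449, 1478), (1638, 1672), (1695, 1739)]
fp16 = [(1552, 1599), (2127, 2163), (2916, 2910), (2720, 2775), (2518, 2532)]

ratios = [v7 / v6 for v6, v7 in fp32 + fp16]
geomean = prod(ratios) ** (1.0 / len(ratios))
print('geomean speedup: %.1f%%' % ((geomean - 1) * 100))  # prints ~2.2%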

congxie1108, Jul 14, 2018:
Thanks!
Have you tried using tools/bandwidth/measure.py to test the throughput?
According to my experiments (using ResNet), tree mode outperforms device mode on p3.16x.
However, on p2.16x, the performance of tree mode seems worse than device mode.
Could you double-check that?

ctcyang (Contributor, Author), Jul 14, 2018:
That makes sense to me, because the p2.16x topology is fully connected using GPUDirect PCI-E. In fully connected networks, tree will never outperform device mode.

Tree's advantage on p3.16x is that device mode is forced to use 4 PCI-E links (not even GPUDirect, so it must go through CPU/QPI), but tree exclusively uses NVLink, so it avoids the CPU/QPI latency.

Reply:
Actually, p2.16x has a worse communication topology than p3.16x.
For p2.16x, the topology is very special. The communication net is composed of 2 fully connected subnets: the first 9 GPUs and the last 7 GPUs. Unlike p3.16x, there are no connections between the 2 components.
In that case, I expect that tree mode should still outperform device mode. However, surprisingly, device mode is better. Actually, the performance is similar to local mode on p2.16x, which seems weird to me.
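For anyone trying to reproduce these comparisons, a rough sketch of how the tree path is enabled (the MXNET_KVSTORE_USETREE switch and the device kvstore come from this PR; the GPU count and shapes below are placeholders):

import os
os.environ['MXNET_KVSTORE_USETREE'] = '1'  # unset or '0' falls back to plain device mode

import mxnet as mx

kv = mx.kv.create('device')   # tree reduction plugs into the device kvstore
shape = (1024, 1024)
ngpus = 4                     # placeholder; use however many GPUs are present
grads = [mx.nd.ones(shape, ctx=mx.gpu(i)) for i in range(ngpus)]

kv.init(0, mx.nd.zeros(shape, ctx=mx.gpu(0)))
kv.push(0, grads)             # gradients reduced along the NVLink/PCI-E trees
out = [mx.nd.zeros(shape, ctx=mx.gpu(i)) for i in range(ngpus)]
kv.pull(0, out=out)           # every GPU receives the aggregated result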

for (size_t i = 0; i < src.size(); ++i) {
- int start = scan_[root][depth_ ];
- int end = scan_[root][depth_+1];
+ int start = scan_[root][depth_];
+ int end = scan_[root][depth_+1];

for (int j = start; j < end; ++j) {
int topo_id = topology[j];
@@ -113,13 +118,13 @@ class CommDeviceTree : public CommDevice {

for (int level = depth_; level > 0; --level) {
int start = scan_[root][level ];
- int end = scan_[root][level+1];
+ int end = scan_[root][level+1];

unsigned is_dest = 0;
- int dest_id = 0;
+ int dest_id = 0;
for (int j = start; j < end; ++j) {
int topo_id = topology[j];
- dest_id = (is_dest == 0) ? topo_id : dest_id;
+ dest_id = (is_dest == 0) ? topo_id : dest_id;

TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key];
TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key];
@@ -141,7 +146,7 @@ class CommDeviceTree : public CommDevice {
}

start = scan_[root][level-1];
- end = scan_[root][level ];
+ end = scan_[root][level];
for (int i = start; i < end; ++i) {
int gpu_id = topology[i];

@@ -158,7 +163,7 @@ class CommDeviceTree : public CommDevice {
}
}
} else {
- LOG(WARNING) << "Only dense input supported for now";
+ LOG(FATAL) << "Only dense input supported for now";
}

int topo_id = topology[0];
@@ -231,7 +236,7 @@ class CommDeviceTree : public CommDevice {
}

// Copy from list of small NDArrays to one big NDArray, which is returned
- int gpu_id = 0;
+ int gpu_id = 0;
return src[gpu_id];
} else {
// sparse reduce
@@ -252,13 +257,13 @@ class CommDeviceTree : public CommDevice {

for (int level = 1; level <= depth_; ++level) {
int start = scan_[root][level];
- int end = scan_[root][level+1];
+ int end = scan_[root][level+1];

unsigned is_src = 0;
- int src_id = 0;
+ int src_id = 0;
for (int j = start; j < end; ++j) {
int topo_id = topology[j];
- src_id = (is_src == 0) ? topo_id : src_id;
+ src_id = (is_src == 0) ? topo_id : src_id;

if (is_src && src_id != topo_id) {
CopyFromTo(temp[src_id], dst[topo_id], priority);
@@ -392,8 +397,8 @@ class CommDeviceTree : public CommDevice {
else
key_dist[shape.Size()]++;

- int start = scan_[0][depth_ ];
- int end = scan_[0][depth_+1];
+ int start = scan_[0][depth_];
+ int end = scan_[0][depth_+1];

// In order to generalize to any number of GPUs, we use strategy of having
// found the mapping from 0, 1, ..., n_gpus to dev_id i.e.
@@ -484,10 +489,10 @@ class CommDeviceTree : public CommDevice {
std::vector<Context> devs_;

/// \brief Highest numbered device
- int max_dev_;
- int depth_;
- int gpuarray_bound_;
- bool backtrack_;
+ int max_dev_;
Member:
I tried to see where this variable is used, to check that cases where GPUs '1,5,3,7' are given still work. But it looks like this variable is not used? Please remove it then.

+ int depth_;
+ int gpuarray_bound_;
+ bool backtrack_;
float link_usage_penalty_;

/// \brief constant for maximum size of recv buffer per GPU