[MXNET-331] Single machine All Reduce Topology-aware Communication (Updated) #11591
Changes from 1 commit
9678143
d5e51d6
0708dbc
5590920
4f8f58b
908534a
25cbbdc
310ee4d
9cce8ea
4d2790d
b5b42bc
6327ceb
8694fe7
c6cd67a
7466c4d
153ec0b
7c61b6c
cc935a2
ba60aaa
972e9c0
6627dcf
4de89a7
317c66b
c364fd3
3241d71
bd926bf
0e1a704
47b0b63
781a7fe
a29f284
b310ab4
abcb10e
24b9c62
7d0da7b
18c1700
c65a620
a70b1b8
263a4cb
628ba6e
b3f3235
a0e1366
63fd14e
6c0bff8
9f5c24a
9cc24d0
691d5ac
67b0db0
c8ebb87
5f7da5e
16b8fb4
```diff
@@ -83,9 +83,14 @@ class CommDeviceTree : public CommDevice {
       }
     }

-  // src is sliced shape
-  // copy_buf not sliced
-  // merged not sliced
+  /**
+   * \brief Reduce src to tree_merge_buf_
+   * \param key is the id of the gradient we are doing Reduce on
+   * \param src is the array of values located on different GPUs
+   * \param root is the id of the GPU we want to send result of reduce to
+   * \param merged_row is the id of the slice we are taking
+   * \param priority the priority of the operation
+   */
   const NDArray& ReduceInner(int key, const std::vector<NDArray>& src, int root,
                              int merged_row, int priority) {
     std::vector<std::vector<NDArray>> reduce(devs_.size());
```
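For context on how these parameters interact: `merged_row` exists because large gradients are split into slices, and each slice is reduced down its own tree. The sketch below is a hypothetical driver loop, not the PR's code; `ReduceInnerSketch`, the per-slice root choice, and all values are illustrative assumptions.

```cpp
#include <cstdio>
#include <vector>

// Stand-in for mxnet::NDArray, for illustration only.
struct NDArray { int device_id = 0; };

// Dummy ReduceInner that just pretends the reduced result lands on `root`.
const NDArray& ReduceInnerSketch(int key, const std::vector<NDArray>& src,
                                 int root, int merged_row, int priority) {
  std::printf("reduce key=%d slice=%d -> root gpu %d (priority %d)\n",
              key, merged_row, root, priority);
  return src[root];
}

int main() {
  std::vector<NDArray> src(4);  // one copy of the gradient per GPU
  int num_slices = 4;           // a big gradient is split into slices
  for (int slice = 0; slice < num_slices; ++slice) {
    // One tree per slice, rooted at a different GPU, so the reduction
    // work is spread across all GPUs rather than serialized on one.
    int root = slice % static_cast<int>(src.size());
    ReduceInnerSketch(/*key=*/0, src, root, /*merged_row=*/slice, /*priority=*/0);
  }
}
```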
```diff
@@ -98,8 +103,8 @@ class CommDeviceTree : public CommDevice {
       if (stype == kDefaultStorage) {
         // Copy everything into buf.merged for each gpu
         for (size_t i = 0; i < src.size(); ++i) {
-          int start = scan_[root][depth_ ];
-          int end   = scan_[root][depth_+1];
+          int start = scan_[root][depth_];
+          int end = scan_[root][depth_+1];

           for (int j = start; j < end; ++j) {
             int topo_id = topology[j];
```

Review thread on the `// Copy everything into buf.merged for each gpu` line:

- Why is copying to buf.merged required? Can we avoid this copy?
- We can avoid this copy, but it makes the code harder to read. Since the source and destination belong to the same GPU, the gain is minimal (10 samples/s, i.e. 720 samples/s to 730 samples/s).
- @ctcyang just asking: how did you test the throughput (as you mentioned, 730 samples/s)?
- I tested the throughput difference on a similar intra-GPU copy. Testing VGG-16 on an older commit, I got the following. The geomean difference across these batch sizes suggests that getting rid of one CopyFromTo makes it 2.2% faster. If you two think these 2 optimizations--(i) eliminate the copy from buf.merged in Reduce, (ii) eliminate the copy from buf.merged to…
- Thanks!
- That makes sense to me, because the p2.16x topology is fully connected using GPUDirect PCI-E. In fully connected networks, tree will never outperform device mode. Tree's advantage on p3.16x is that device mode is forced to use 4 PCI-E links (not even GPUDirect, so it must go through CPU/QPI), but tree exclusively uses NVLink, so it avoids CPU/QPI latency.
- Actually, p2.x16 has a worse communication topology compared to p3.x16.
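As an aside on the 2.2% figure: a geomean speedup over per-batch-size throughput ratios can be computed as below. The ratios in this sketch are made-up placeholders, not the measurements referenced in the thread.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-batch-size throughput ratios (no-copy / with-copy);
  // placeholder values only.
  std::vector<double> speedup = {1.014, 1.022, 1.030};
  double log_sum = 0.0;
  for (double s : speedup) log_sum += std::log(s);
  // Geometric mean = exp(mean of logs).
  double geomean = std::exp(log_sum / speedup.size());
  std::printf("geomean speedup: %.1f%%\n", (geomean - 1.0) * 100.0);
}
```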
```diff
@@ -113,13 +118,13 @@ class CommDeviceTree : public CommDevice {

       for (int level = depth_; level > 0; --level) {
         int start = scan_[root][level ];
-        int end   = scan_[root][level+1];
+        int end = scan_[root][level+1];

         unsigned is_dest = 0;
-        int dest_id  = 0;
+        int dest_id = 0;
         for (int j = start; j < end; ++j) {
           int topo_id = topology[j];
-          dest_id  = (is_dest == 0) ? topo_id : dest_id;
+          dest_id = (is_dest == 0) ? topo_id : dest_id;

           TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key];
           TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key];
```
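The recurring `scan_[root][level]` / `scan_[root][level+1]` pairs read like CSR-style offsets: `topology` appears to store the tree for a given root flattened level by level, with `scan_` marking where each level begins. A self-contained illustration of that indexing, with the concrete encoding assumed rather than taken from the PR:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // A binary tree over 4 GPUs rooted at GPU 0, flattened level by level:
  // level 0: {0}, level 1: {0, 2}, level 2: {0, 1, 2, 3}.
  std::vector<int> topology = {0, 0, 2, 0, 1, 2, 3};
  // scan[level] is the offset where that level starts; scan[level+1] ends it.
  std::vector<int> scan = {0, 1, 3, 7};
  int depth = 2;

  // Walk from the leaves up toward the root, as the reduce loop does.
  for (int level = depth; level > 0; --level) {
    int start = scan[level];
    int end = scan[level+1];
    for (int j = start; j < end; ++j)
      std::printf("level %d: gpu %d\n", level, topology[j]);
  }
}
```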
```diff
@@ -141,7 +146,7 @@ class CommDeviceTree : public CommDevice {
         }

         start = scan_[root][level-1];
-        end   = scan_[root][level ];
+        end = scan_[root][level];
         for (int i = start; i < end; ++i) {
           int gpu_id = topology[i];
```
```diff
@@ -158,7 +163,7 @@ class CommDeviceTree : public CommDevice {
           }
         }
       } else {
-        LOG(WARNING) << "Only dense input supported for now";
+        LOG(FATAL) << "Only dense input supported for now";
       }

       int topo_id = topology[0];
```
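Changing `LOG(WARNING)` to `LOG(FATAL)` is a behavioral fix, not a cleanup: with dmlc-style logging a warning lets execution continue into code that assumes dense storage, while a fatal log stops it (in MXNet builds, dmlc's fatal logging is typically configured to throw). A plain-C++ stand-in for the intended control flow, not the actual dmlc macros:

```cpp
#include <cstdio>
#include <stdexcept>
#include <string>

// Stand-ins for the two logging behaviors, for illustration only.
void WarnOnly(const std::string& msg) {
  std::fprintf(stderr, "warning: %s\n", msg.c_str());  // log and keep going
}
void FatalError(const std::string& msg) {
  throw std::runtime_error(msg);  // execution cannot continue past here
}

void ReduceSketch(bool dense_input) {
  if (dense_input) {
    // ... tree reduce ...
  } else {
    // Before: WarnOnly("Only dense input supported for now");
    // After: fail loudly instead of proceeding with wrong results.
    FatalError("Only dense input supported for now");
  }
}

int main() { ReduceSketch(true); }
```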
```diff
@@ -231,7 +236,7 @@ class CommDeviceTree : public CommDevice {
       }

       // Copy from list of small NDArrays to one big NDArray, which is returned
-      int gpu_id  = 0;
+      int gpu_id = 0;
       return src[gpu_id];
     } else {
       // sparse reduce
```
```diff
@@ -252,13 +257,13 @@ class CommDeviceTree : public CommDevice {

       for (int level = 1; level <= depth_; ++level) {
         int start = scan_[root][level];
-        int end   = scan_[root][level+1];
+        int end = scan_[root][level+1];

         unsigned is_src = 0;
-        int src_id  = 0;
+        int src_id = 0;
         for (int j = start; j < end; ++j) {
           int topo_id = topology[j];
-          src_id  = (is_src == 0) ? topo_id : src_id;
+          src_id = (is_src == 0) ? topo_id : src_id;

           if (is_src && src_id != topo_id) {
             CopyFromTo(temp[src_id], dst[topo_id], priority);
```
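This broadcast loop mirrors the reduce loop, sweeping from the root down: the first GPU in each level's range acts as the source and the rest receive copies. A sketch of the copy schedule it produces, reusing the same assumed `topology`/`scan` encoding as above; the `is_src` update step is an assumption, since this hunk does not show it:

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> topology = {0, 0, 2, 0, 1, 2, 3};  // tree rooted at GPU 0
  std::vector<int> scan = {0, 1, 3, 7};               // level offsets
  int depth = 2;

  // Walk from the root down toward the leaves.
  for (int level = 1; level <= depth; ++level) {
    int start = scan[level];
    int end = scan[level+1];

    unsigned is_src = 0;
    int src_id = 0;
    for (int j = start; j < end; ++j) {
      int topo_id = topology[j];
      src_id = (is_src == 0) ? topo_id : src_id;
      if (is_src && src_id != topo_id)
        std::printf("copy gpu %d -> gpu %d\n", src_id, topo_id);
      is_src = !is_src;  // assumed binary-tree toggle, not shown in the hunk
    }
  }
  // Prints: copy gpu 0 -> gpu 2, copy gpu 0 -> gpu 1, copy gpu 2 -> gpu 3.
}
```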
```diff
@@ -392,8 +397,8 @@ class CommDeviceTree : public CommDevice {
       else
         key_dist[shape.Size()]++;

-      int start = scan_[0][depth_ ];
-      int end   = scan_[0][depth_+1];
+      int start = scan_[0][depth_];
+      int end = scan_[0][depth_+1];

       // In order to generalize to any number of GPUs, we use strategy of having
       // found the mapping from 0, 1, ..., n_gpus to dev_id i.e.
```
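The truncated comment is about decoupling tree ranks from device ids: the tree and `scan_` structures are built over dense ranks 0..n_gpus-1, and a lookup maps each rank back to the real dev_id, so non-contiguous GPU lists work. A hypothetical illustration using the '1,5,3,7' device list mentioned later in review:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // GPUs given as --gpus 1,5,3,7: dense tree rank -> actual CUDA device id.
  std::vector<int> rank_to_dev = {1, 5, 3, 7};
  // The tree/scan structures operate on ranks 0..3; only when a copy is
  // actually issued is a rank translated to its device id.
  for (int rank = 0; rank < static_cast<int>(rank_to_dev.size()); ++rank)
    std::printf("tree rank %d -> dev_id %d\n", rank, rank_to_dev[rank]);
}
```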
```diff
@@ -484,10 +489,10 @@ class CommDeviceTree : public CommDevice {
   std::vector<Context> devs_;

-  int  max_dev_;
-  int  depth_;
-  int  gpuarray_bound_;
-  bool backtrack_;
+  /// \brief Highest numbered device
+  int max_dev_;
+  int depth_;
+  int gpuarray_bound_;
+  bool backtrack_;
   float link_usage_penalty_;

   /// \brief constant for maximum size of recv buffer per GPU
```

Review comment on `int max_dev_;`:

- I tried to see where this variable is used, to ensure that cases when GPUs '1,5,3,7' are given work. But it looks like this variable is not used? Please remove it then.
Review comment on `ReduceInner` (addressed by the doc comment in the first hunk):

- Could you add documentation for this function? e.g. what's `merged_row`?