forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCollectives.cpp
92 lines (71 loc) · 2.57 KB
/
Collectives.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include "Collectives.hpp"
#include "General.hpp"
#include <vector>
using namespace thd;
// Returns this process's rank as reported by the global data channel.
int THDGetRank() {
  const auto rank = dataChannel->getRank();
  return rank;
}
// Returns the process count reported by the global data channel.
int THDGetNumProcesses() {
  const auto process_count = dataChannel->getNumProcesses();
  return process_count;
}
// All-reduce of the tensor behind `desc` over `group` using `operation`.
// `desc` is dereferenced unconditionally and must be non-null.
void THDAllReduce(THDTensorDescriptor* desc, THDReduceOp operation, THDGroup group) {
  auto& tensor = *desc;
  dataChannel->allReduce(tensor, operation, group);
}
// Reduce of the tensor behind `desc` over `group` using `operation`.
// `dst_rank` is forwarded as the destination rank; presumably only that rank
// receives the reduced result (depends on the data channel's reduce()).
// `desc` must be non-null.
void THDReduce(THDTensorDescriptor* desc, THDReduceOp operation,
               int dst_rank, THDGroup group) {
  auto& tensor = *desc;
  dataChannel->reduce(tensor, operation, dst_rank, group);
}
// Broadcast of the tensor behind `desc` within `group`, with `src_rank`
// passed to the data channel as the source rank. `desc` must be non-null.
void THDBroadcast(THDTensorDescriptor* desc, int src_rank, THDGroup group) {
  auto& tensor = *desc;
  dataChannel->broadcast(tensor, src_rank, group);
}
// Starts a send of the tensor behind `desc` to `dst_rank` through the data
// channel's isend(). Returns the request handle that isend() produces, which
// can be passed to THDRequest_wait / THDRequest_isCompleted.
// NOTE(review): ownership/lifetime of the returned request is not visible
// in this file — confirm who releases it.
THDRequest* THDIsend(THDTensorDescriptor* desc, int dst_rank) {
  auto& tensor = *desc;
  THDRequest* pending = dataChannel->isend(tensor, dst_rank);
  return pending;
}
// Starts a receive from `src_rank` into the tensor behind `desc` through the
// data channel's ireceive(). Returns the request handle that ireceive()
// produces. The tensor presumably must stay alive until the request completes.
// NOTE(review): ownership of the returned request is not visible here.
THDRequest* THDIrecv(THDTensorDescriptor* desc, int src_rank) {
  auto& tensor = *desc;
  THDRequest* pending = dataChannel->ireceive(tensor, src_rank);
  return pending;
}
// Sends the tensor behind `desc` to `dst_rank` via the data channel's send().
void THDSend(THDTensorDescriptor* desc, int dst_rank) {
  auto& tensor = *desc;
  dataChannel->send(tensor, dst_rank);
}
// Receives into the tensor behind `desc` by calling receive() with no source
// rank. Per the name, this presumably accepts a message from any sender.
void THDRecvAnySource(THDTensorDescriptor* desc) {
  auto& tensor = *desc;
  dataChannel->receive(tensor);
}
// Receives into the tensor behind `desc` from `src_rank` via receive().
void THDRecv(THDTensorDescriptor* desc, int src_rank) {
  auto& tensor = *desc;
  dataChannel->receive(tensor, src_rank);
}
// All-gather over `group`: copies the C array `output[0..len)` of tensor
// pointers into a std::vector and hands it, plus the local `input` tensor,
// to the data channel's allGather().
// `output` must point to at least `len` entries; `input` must be non-null.
void THDAllGather(THDTensorDescriptor** output, size_t len,
                  THDTensorDescriptor* input, THDGroup group) {
  std::vector<Tensor*> gathered;
  gathered.assign(output, output + len);
  dataChannel->allGather(gathered, *input, group);
}
// Contributes `input` to a gather rooted at `dst_rank` within `group`.
// An empty output list is passed. Presumably only the root (see
// THDGatherRecv) supplies output buffers.
void THDGatherSend(THDTensorDescriptor* input, int dst_rank, THDGroup group) {
  std::vector<Tensor*> no_outputs{};
  dataChannel->gather(no_outputs, *input, dst_rank, group);
}
// Root side of a gather within `group`: the calling process's own rank is used
// as the gather root, `output[0..len)` supplies the destination tensors, and
// `input` is this process's own contribution.
void THDGatherRecv(THDTensorDescriptor** output, size_t len,
                   THDTensorDescriptor* input, THDGroup group) {
  std::vector<Tensor*> destinations;
  destinations.assign(output, output + len);
  const auto root = dataChannel->getRank();
  dataChannel->gather(destinations, *input, root, group);
}
// Root side of a scatter within `group`: the calling process's own rank is
// used as the scatter root, `input[0..len)` supplies the tensors to distribute,
// and `output` receives this process's own share.
void THDScatterSend(THDTensorDescriptor** input, size_t len,
                    THDTensorDescriptor* output, THDGroup group) {
  std::vector<Tensor*> sources;
  sources.assign(input, input + len);
  const auto root = dataChannel->getRank();
  dataChannel->scatter(sources, *output, root, group);
}
// Receives this process's share of a scatter rooted at `src_rank` into
// `output`. An empty input list is passed. Presumably only the root (see
// THDScatterSend) supplies source tensors.
void THDScatterRecv(THDTensorDescriptor* output, int src_rank, THDGroup group) {
  std::vector<Tensor*> no_inputs{};
  dataChannel->scatter(no_inputs, *output, src_rank, group);
}
// Calls the data channel's barrier() for `group`.
void THDBarrier(THDGroup group) {
  auto& channel = dataChannel;
  channel->barrier(group);
}
// Creates a new group from the C array `ranks[0..len)` by copying it into a
// std::vector and passing that to the data channel's newGroup().
// Returns the group handle that newGroup() produces.
THDGroup THDNewGroup(const int *ranks, size_t len) {
  std::vector<int> members;
  members.assign(ranks, ranks + len);
  return dataChannel->newGroup(members);
}
// Returns the result of request->isCompleted(); `request` must be non-null.
bool THDRequest_isCompleted(THDRequest* request) {
  const bool done = request->isCompleted();
  return done;
}
// Calls request->wait(). Per the name, this presumably blocks until the
// request finishes. `request` must be non-null.
void THDRequest_wait(THDRequest* request) {
  THDRequest& pending = *request;
  pending.wait();
}