Refactor ccl allreduce #8760

Merged: 32 commits (merged Jul 31, 2022)

Commits (32):
c681984  refactor_ccl_allreduce (clackhan, Jul 27, 2022)
5269dc5  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Jul 27, 2022)
f2a9e90  reslove comment (clackhan, Jul 27, 2022)
b7ff8df  move collective_communication/ to oneflow/user/kernels/ (clackhan, Jul 28, 2022)
9802914  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Jul 28, 2022)
63e0c1d  fix static check error (clackhan, Jul 28, 2022)
e427a2f  fix static check error (clackhan, Jul 28, 2022)
e1866ee  Merge branch 'master' into refactor_ccl_allreduce (clackhan, Jul 28, 2022)
278b0f1  refine (clackhan, Jul 28, 2022)
0f26547  refine (clackhan, Jul 28, 2022)
45ab350  Merge branch 'master' into refactor_ccl_allreduce (clackhan, Jul 29, 2022)
3740d33  refine (clackhan, Jul 29, 2022)
cb60a93  Merge branch 'refactor_ccl_allreduce' of https://github.com/Oneflow-I… (clackhan, Jul 29, 2022)
405db46  use collective_communication namespace (clackhan, Jul 29, 2022)
d248307  add UserOpRegistryMgr::IsOpKernelRegistered (clackhan, Jul 29, 2022)
6c51505  Merge branch 'master' into refactor_ccl_allreduce (clackhan, Jul 29, 2022)
3c5e4fd  Merge branch 'master' into refactor_ccl_allreduce (hjchen2, Jul 29, 2022)
d8b6f1f  rename CommunicationContext and ccl (clackhan, Jul 30, 2022)
96ed40f  remove CollectiveCommunicationFactory (clackhan, Jul 30, 2022)
3380ccc  Merge branch 'refactor_ccl_allreduce' of https://github.com/Oneflow-I… (clackhan, Jul 30, 2022)
6c9ff49  Merge branch 'master' into refactor_ccl_allreduce (clackhan, Jul 30, 2022)
edeef95  refine (clackhan, Jul 30, 2022)
bb9c1cf  Merge branch 'refactor_ccl_allreduce' of https://github.com/Oneflow-I… (clackhan, Jul 30, 2022)
ca74cb7  Merge branch 'master' into refactor_ccl_allreduce (mergify[bot], Jul 30, 2022)
213807a  reslove comment and fix static check (clackhan, Jul 31, 2022)
ae9d677  Merge branch 'refactor_ccl_allreduce' of https://github.com/Oneflow-I… (clackhan, Jul 31, 2022)
fa8e738  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Jul 31, 2022)
a144816  minor fix (clackhan, Jul 31, 2022)
70541dc  fix static check error (clackhan, Jul 31, 2022)
f3785cd  Merge branch 'master' into refactor_ccl_allreduce (clackhan, Jul 31, 2022)
9f47231  Merge branch 'master' into refactor_ccl_allreduce (mergify[bot], Jul 31, 2022)
8ee7ad6  Merge branch 'master' into refactor_ccl_allreduce (mergify[bot], Jul 31, 2022)

Changes from all commits:

2 changes: 1 addition & 1 deletion oneflow/api/python/functional/dispatch_stateful_ops.cpp
@@ -529,7 +529,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchEagerNcclAllReduce",
m.add_functor("DispatchEagerCclAllReduce",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& parallel_conf, bool async_launch) -> Maybe<Tensor> {
MutableAttrMap attrs;
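
Note: the functor body is collapsed in the diff above. As a rough sketch only (the attribute names and exact body below are assumptions for illustration, not the hidden lines of this PR), such a dispatch functor typically fills a MutableAttrMap and forwards to OpInterpUtil::Dispatch, mirroring the surrounding functors in this file:

m.add_functor("DispatchEagerCclAllReduce",
              [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                 const std::string& parallel_conf, bool async_launch) -> Maybe<Tensor> {
                MutableAttrMap attrs;
                // "parallel_conf" / "async_launch" are assumed attr names for illustration.
                JUST(attrs.SetAttr<std::string>("parallel_conf", parallel_conf));
                JUST(attrs.SetAttr<bool>("async_launch", async_launch));
                return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
              });
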
4 changes: 2 additions & 2 deletions oneflow/api/python/functional/dispatch_stateful_ops.yaml
@@ -152,8 +152,8 @@
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate"
bind_python: True

- name: "dispatch_eager_nccl_all_reduce"
signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce"
- name: "dispatch_eager_ccl_all_reduce"
signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerCclAllReduce"
bind_python: True

- name: "dispatch_raw_reader"
54 changes: 48 additions & 6 deletions oneflow/core/boxing/ccl_boxing_function.cpp
@@ -13,16 +13,55 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/user_op_registry_manager.h"

namespace oneflow {

namespace {

class EagerBoxingKernelRegContext final : public user_op::KernelRegContext {
public:
explicit EagerBoxingKernelRegContext(DeviceType device_type) : device_type_(device_type) {}
~EagerBoxingKernelRegContext() = default;

DeviceType device_type() const override { return device_type_; }
const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
PRINT_BUG_PROMPT_AND_ABORT();
}
const std::vector<std::pair<std::string, int32_t>>& inputs() const override {
PRINT_BUG_PROMPT_AND_ABORT();
}
const std::vector<std::pair<std::string, int32_t>>& outputs() const override {
PRINT_BUG_PROMPT_AND_ABORT();
}

const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); }

const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
PRINT_BUG_PROMPT_AND_ABORT();
}

private:
DeviceType device_type_;
};

Maybe<bool> RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) {
EagerBoxingKernelRegContext reg_ctx(device_type);
return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx);
}

static constexpr auto* CheckCclKernelRegistered =
DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
@@ -33,8 +72,9 @@ Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));

CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN( // NOLINT
JUST(CheckCclKernelRegistered("eager_ccl_all_reduce", // NOLINT
in->placement()->device_type()))); // NOLINT
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
@@ -53,8 +93,9 @@ Maybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);

CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN( // NOLINT
JUST(CheckCclKernelRegistered("eager_nccl_reduce_scatter", // NOLINT
in->placement()->device_type()))); // NOLINT
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
@@ -74,8 +115,9 @@ Maybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);

CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN( // NOLINT
JUST(CheckCclKernelRegistered("eager_nccl_all_gather", // NOLINT
in->placement()->device_type()))); // NOLINT
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
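
Note: the substantive change in this file is that the boxing checks no longer hard-code DeviceType::kCPU / DeviceType::kCUDA; they ask the kernel registry whether a matching kernel exists for the placement's device type, and the answer is memoized by DECORATE(..., ThreadLocalCachedCopiable). The standalone sketch below illustrates only that caching idea; it is not OneFlow's DECORATE machinery, and every name in it is invented for illustration.

#include <map>
#include <string>
#include <utility>

// Stand-in for the expensive part (the registry scan done by
// RawCheckCclKernelRegistered above).
bool SlowRegistryScan(const std::string& op_type_name, int device_type) {
  return device_type == 0 || device_type == 1;  // dummy answer
}

// Thread-local memoization keyed by the argument tuple: each thread pays the
// scan cost once per (op_type_name, device_type), which is the effect the
// decorator gives the boxing checkers.
bool CachedKernelRegistered(const std::string& op_type_name, int device_type) {
  thread_local std::map<std::pair<std::string, int>, bool> cache;
  const auto key = std::make_pair(op_type_name, device_type);
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, SlowRegistryScan(op_type_name, device_type)).first;
  }
  return it->second;
}
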
107 changes: 0 additions & 107 deletions oneflow/core/ccl/ccl.cpp
@@ -63,113 +63,6 @@ void VecAdd(size_t size, T* out, const T* in0, const T* in1) {

} // namespace

template<typename T, ReduceType reduce_type>
struct DtypeAllReduce;

template<typename T>
struct DtypeAllReduce<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,
Symbol<ParallelDesc> parallel_desc) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }
return Maybe<void>::Ok();
}
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
BalancedSplitter bs(elem_cnt, parallel_num);
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
Optional<int64_t> parallel_id;
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = JUST(parallel_id); i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = &out[bs.At(send_part_id).begin()];
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
T* cur_out = &out[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, cur_out, cur_in, recv_ptr); }
}
for (int64_t i = 0, part_id = RingIncrease(JUST(parallel_id), parallel_num);
i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = &out[bs.At(send_part_id).begin()];
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = &out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
}
return Maybe<void>::Ok();
}
};

#define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call

DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeAllReduce, MAKE_ALL_REDUCE_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);

#undef MAKE_ALL_REDUCE_ENTRY

template<>
Maybe<void> AllReduce<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
return SwitchDtypeAllReduce(SwitchCase(dtype, reduce_type), in, out, elem_cnt, parallel_desc);
}

template<typename T, ReduceType reduce_type>
struct DtypeReduceScatter;

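
Note: the DtypeAllReduce<T, kSum> specialization removed here is the CPU fallback of a classic ring all-reduce: a reduce-scatter pass (each step sends one chunk to the next rank and accumulates the chunk received from the previous rank) followed by an all-gather pass that circulates the reduced chunks; per the commit list, the functionality is refactored into the new collective_communication code rather than dropped. The self-contained program below simulates the same data movement on plain arrays inside one process, ignoring OneFlow's transport layer; it is an illustration, not the removed code.

#include <cstdio>
#include <vector>

int main() {
  const int P = 4;          // number of simulated ranks
  const int chunk = 2;      // elements per chunk
  const int n = P * chunk;  // elements per rank
  std::vector<std::vector<int>> buf(P, std::vector<int>(n));
  for (int r = 0; r < P; ++r)
    for (int e = 0; e < n; ++e) buf[r][e] = r + 1;  // rank r contributes (r+1) everywhere

  // Reduce-scatter: at step s, rank r accumulates chunk (r - s - 1) mod P
  // received from its left neighbor; after P-1 steps rank r holds the fully
  // reduced chunk (r + 1) mod P.
  for (int s = 0; s < P - 1; ++s) {
    std::vector<std::vector<int>> prev = buf;  // snapshot plays the role of messages in flight
    for (int r = 0; r < P; ++r) {
      int src = (r - 1 + P) % P;
      int c = (r - s - 1 + 2 * P) % P;
      for (int e = 0; e < chunk; ++e) buf[r][c * chunk + e] += prev[src][c * chunk + e];
    }
  }
  // All-gather: at step s, rank r copies the fully reduced chunk (r - s) mod P
  // from its left neighbor, so every rank ends with the complete sum.
  for (int s = 0; s < P - 1; ++s) {
    std::vector<std::vector<int>> prev = buf;
    for (int r = 0; r < P; ++r) {
      int src = (r - 1 + P) % P;
      int c = (r - s + 2 * P) % P;
      for (int e = 0; e < chunk; ++e) buf[r][c * chunk + e] = prev[src][c * chunk + e];
    }
  }
  // Every element on every rank should equal 1 + 2 + ... + P = P * (P + 1) / 2.
  std::printf("rank 0, elem 0 = %d (expected %d)\n", buf[0][0], P * (P + 1) / 2);
  return 0;
}
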
5 changes: 0 additions & 5 deletions oneflow/core/ccl/ccl.h
@@ -44,11 +44,6 @@ enum ReduceType {
MAKE_TYPED_CTRV_SEQ(ReduceType, \
OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, CCL_REDUCE_TYPE_SEQ))

template<DeviceType device_type>
Maybe<void> AllReduce(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);

template<DeviceType device_type>
Maybe<void> ReduceScatter(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
24 changes: 24 additions & 0 deletions oneflow/core/framework/user_op_registry_manager.cpp
@@ -138,6 +138,30 @@ Maybe<const OpKernelRegistryResult*> UserOpRegistryMgr::GetOpKernelRegistryResul
return ret;
}

Maybe<bool> UserOpRegistryMgr::IsOpKernelRegistered(const std::string& op_type_name,
const KernelRegContext& ctx) {
auto it = op_kernel_reg_result_.find(op_type_name);
if (it == op_kernel_reg_result_.end()) { return false; }
const OpKernelRegistryResult* ret = nullptr;
for (const auto& reg_val : it->second) {
if (reg_val.is_matched_hob->get(ctx)) {
if (ret != nullptr) {
std::vector<std::string> debug_msgs;
for (const auto& local_reg_val : it->second) {
if (local_reg_val.is_matched_hob->get(ctx)) {
debug_msgs.emplace_back(local_reg_val.is_matched_hob->DebugStr(ctx));
}
}
return Error::MultipleOpKernelsMatchedError(debug_msgs)
<< "There are more than one kernels matching Current OperatorConf: " << op_type_name;
}
ret = &reg_val;
}
}
if (ret == nullptr) { return false; }
return true;
}

} // namespace user_op

} // namespace oneflow
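
Note: IsOpKernelRegistered walks the kernels registered under op_type_name and evaluates each is_matched_hob against the context. It returns false when nothing is registered or nothing matches, true when exactly one kernel matches, and an error only when several hobs match at once. A minimal caller sketch follows; the helper name is hypothetical, and a real caller outside ccl_boxing_function.cpp would define its own KernelRegContext, since the EagerBoxingKernelRegContext shown above lives in an anonymous namespace there.

// Hypothetical helper: a KernelRegContext that only carries a device type is
// enough for kernels whose hob matches on device.
Maybe<bool> HasEagerCclAllReduceKernel(DeviceType device_type) {
  EagerBoxingKernelRegContext reg_ctx(device_type);
  return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered("eager_ccl_all_reduce", reg_ctx);
}
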
1 change: 1 addition & 0 deletions oneflow/core/framework/user_op_registry_manager.h
@@ -48,6 +48,7 @@ class UserOpRegistryMgr final {
Maybe<void> Register(OpKernelRegistryResult result);
Maybe<const OpKernelRegistryResult*> GetOpKernelRegistryResult(const std::string& op_type_name,
const KernelRegContext& ctx);
Maybe<bool> IsOpKernelRegistered(const std::string& op_type_name, const KernelRegContext& ctx);

const HashMap<std::string, OpRegistryResult>& GetAllOpRegistryResults() {
return op_reg_result_;