NCCL logical op support S1 to B (#4772)
* NcclLogicalOpAllGatherNoncontinuous S1 to B

* delete useless file

* fix user op attr err

* fix tmp_buffer

* Update oneflow/user/kernels/nccl_logical_kernels.cpp

* fix data size check in 2d s1-b

Co-authored-by: guo-ran <360112263@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Former-commit-id: 92337ef
4 people authored Apr 29, 2021
1 parent 87e83a6 commit ce04eff
Showing 4 changed files with 141 additions and 4 deletions.
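
For orientation before the per-file diffs: the new S(1)->B path AllGathers the per-rank column slices (which lands them rank-major in a temp buffer) and then transposes to restore the logical layout. A standalone sketch of that data movement — not part of the commit, using only std containers and hypothetical names:

#include <cstdio>
#include <vector>

int main() {
  const int rows = 4, logical_cols = 6, num_ranks = 2;
  const int cols = logical_cols / num_ranks;  // per-rank S(1) slice is (4, 3)
  // Rank r holds columns [r*cols, (r+1)*cols) of the logical (4, 6) tensor.
  std::vector<std::vector<int>> slices(num_ranks, std::vector<int>(rows * cols));
  for (int r = 0; r < num_ranks; ++r)
    for (int i = 0; i < rows; ++i)
      for (int j = 0; j < cols; ++j) slices[r][i * cols + j] = i * logical_cols + r * cols + j;
  // "AllGather": concatenate slices rank-major -> buffer of shape (num_ranks, rows, cols).
  std::vector<int> gathered;
  for (const auto& s : slices) gathered.insert(gathered.end(), s.begin(), s.end());
  // Unpack: transpose (num_ranks, rows, cols) with perm {1, 0, 2} -> (rows, num_ranks, cols),
  // which flattens to exactly the logical row-major (rows, logical_cols) tensor.
  std::vector<int> out(rows * logical_cols);
  for (int r = 0; r < num_ranks; ++r)
    for (int i = 0; i < rows; ++i)
      for (int j = 0; j < cols; ++j)
        out[i * num_ranks * cols + r * cols + j] = gathered[(r * rows + i) * cols + j];
  for (int k = 0; k < rows * logical_cols; ++k)
    if (out[k] != k) { std::printf("mismatch at %d\n", k); return 1; }
  std::printf("S(1)->B unpack OK: out is the logical (4, 6) tensor\n");
  return 0;
}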
14 changes: 14 additions & 0 deletions oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
@@ -189,6 +189,20 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp,
.Build()
.op_conf();
return true;
} else if ((src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel())
&& (src_sbp.split_parallel().axis() > 0)
&& (logical_blob_desc.shape().At(src_sbp.split_parallel().axis()) % parallel_num
== 0)) {
// S(1)->B : AllGather Noncontinuous
*ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-S2B-" + NewUniqueId())
.Op("_nccl_logical_all_gather_noncontinuous")
.Input("in", lbn)
.Output("out")
.Attr<int64_t>("in_split_axis", src_sbp.split_parallel().axis())
.ScopeSymbolId(scope_symbol_id)
.Build()
.op_conf();
return true;
} else if ((src_sbp.has_split_parallel() && dst_sbp.has_split_parallel())
&& (src_sbp.split_parallel().axis() != dst_sbp.split_parallel().axis())
&& (logical_blob_desc.shape().At(src_sbp.split_parallel().axis()) % parallel_num == 0)
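
As a reading aid, the guard on the new branch can be restated as a standalone predicate (a sketch with simplified types, not OneFlow's actual signatures): the source is S(axis) with axis > 0 (axis 0 is already covered by the plain AllGather branch above), the destination is B, and the split dimension divides evenly across the ranks.

#include <cstdint>
#include <vector>

bool EligibleForAllGatherNoncontinuous(bool src_is_split, int64_t src_split_axis,
                                       bool dst_is_broadcast,
                                       const std::vector<int64_t>& logical_shape,
                                       int64_t parallel_num) {
  return src_is_split && dst_is_broadcast
         && src_split_axis > 0  // S(0)->B takes the contiguous AllGather path instead
         && logical_shape.at(src_split_axis) % parallel_num == 0;
}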
8 changes: 4 additions & 4 deletions oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
@@ -144,7 +144,7 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const int64_t dtype_size = GetSizeOfDataType(in->data_type());
-int64_t data_size = GetCudaAlignedSize(in->shape().elem_cnt() * dtype_size);
+int64_t data_size = GetCudaAlignedSize(out->shape().elem_cnt() * dtype_size);
void* unpack_from_ptr = tmp_buffer->mut_dptr();
CHECK_EQ(tmp_buffer->shape().elem_cnt(), data_size);

@@ -188,9 +188,9 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern
};

size_t Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {
-const user_op::TensorDesc* in_tensor = ctx->TensorDesc4ArgNameAndIndex("in", 0);
-return GetCudaAlignedSize(in_tensor->shape().elem_cnt()
-* GetSizeOfDataType(in_tensor->data_type()));
+const user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
+return GetCudaAlignedSize(out_tensor->shape().elem_cnt()
+* GetSizeOfDataType(out_tensor->data_type()));
}

template<typename T>
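
Both hunks fix the same sizing bug: the temp buffer receives the raw ncclAllGather result, which holds num_ranks times the local input's elements — exactly the element count of the broadcast output — so sizing it from the input undercounted by a factor of num_ranks. A minimal sketch of the corrected rule (kCudaAlign is an illustrative stand-in for whatever GetCudaAlignedSize uses):

#include <cstdint>

constexpr int64_t kCudaAlign = 512;  // assumption: illustrative alignment only

int64_t AlignedTmpBytes(int64_t out_elem_cnt, int64_t dtype_size) {
  const int64_t bytes = out_elem_cnt * dtype_size;            // size of the gathered result
  return (bytes + kCudaAlign - 1) / kCudaAlign * kCudaAlign;  // round up to alignment
}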
83 changes: 83 additions & 0 deletions oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -129,6 +129,74 @@ class NcclLogicalAllGatherKernel final : public user_op::OpKernel {
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

template<typename T>
class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel {
public:
NcclLogicalAllGatherNoncontinuous() = default;
~NcclLogicalAllGatherNoncontinuous() override = default;

std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
return std::make_shared<NcclLogicalKernelCommState>(ctx);
}

private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
auto* nccl_comm = dynamic_cast<NcclLogicalKernelCommState*>(state);
CHECK(nccl_comm != nullptr);
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const int64_t dtype_size = GetSizeOfDataType(in->data_type());
int64_t data_size = GetCudaAlignedSize(out->shape().elem_cnt() * dtype_size);
void* unpack_from_ptr = tmp_buffer->mut_dptr();
CHECK_EQ(tmp_buffer->shape().elem_cnt(), data_size);

CHECK_EQ(in->data_type(), out->data_type());
const int64_t num_ranks = ctx->parallel_ctx().parallel_num();
const int64_t in_split_axis = ctx->Attr<int64_t>("in_split_axis");

DimVector logical_shape_dim_vec;
in->shape().ToDimVector(&logical_shape_dim_vec);
logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;

// NOTE(chengcheng): Do AllGather
CHECK_EQ(in->shape().elem_cnt() * num_ranks, out->shape().elem_cnt());
OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape().elem_cnt(),
GetNcclDataType(in->data_type()), nccl_comm->comm(),
ctx->device_ctx()->cuda_stream()));

CHECK_GT(in_split_axis, 0);
// NOTE(chengcheng): Do unpack.
DimVector unpack_from_dim_vec = logical_shape_dim_vec;
CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);
unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;
unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);
const Shape unpack_from_shape(unpack_from_dim_vec);
DimVector transpose_out_dim_vec;
std::vector<int32_t> perm;
FOR_RANGE(int64_t, i, 1, unpack_from_shape.NumAxes()) {
perm.push_back(i);
transpose_out_dim_vec.push_back(unpack_from_shape.At(i));
}
perm.insert(perm.begin() + in_split_axis, 0);
transpose_out_dim_vec.insert(transpose_out_dim_vec.begin() + in_split_axis,
unpack_from_shape.At(0));
const Shape transpose_out_shape(transpose_out_dim_vec);
NewKernelUtil<DeviceType::kGPU>::Transpose(
ctx->device_ctx(), unpack_from_shape.NumAxes(), unpack_from_shape, transpose_out_shape,
perm, unpack_from_shape.elem_cnt(), reinterpret_cast<const T*>(unpack_from_ptr),
out->mut_dptr<T>());
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

size_t InferAllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {
const user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
return GetCudaAlignedSize(out_tensor->shape().elem_cnt()
* GetSizeOfDataType(out_tensor->data_type()));
}

template<typename T>
class NcclLogicalS2SKernel final : public user_op::OpKernel {
public:
@@ -278,6 +346,21 @@ REGISTER_USER_KERNEL("_nccl_logical_all_gather")
.SetCreateFn<NcclLogicalAllGatherKernel>()
.SetIsMatchedHob(user_op::HobDeviceTag() == "gpu");

#define REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(dtype) \
REGISTER_USER_KERNEL("_nccl_logical_all_gather_noncontinuous") \
.SetCreateFn<NcclLogicalAllGatherNoncontinuous<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") \
& (user_op::HobDataType("in", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("out", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn(InferAllGatherNoncontinuousKernelTmpBufferSize);

REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(double)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float16)

#define REGISTER_S2S_KERNEL(dtype) \
REGISTER_USER_KERNEL("_nccl_logical_s2s") \
.SetCreateFn<NcclLogicalS2SKernel<dtype>>() \
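
The unpack step is the heart of the kernel above: the gathered buffer has shape (num_ranks, d0, ..., split_dim/num_ranks, ...), and moving the leading rank axis to position in_split_axis makes the memory layout match the broadcast output. A standalone sketch of the perm/shape bookkeeping (std types only; it mirrors, but is not, the kernel code):

#include <cstdio>
#include <vector>

int main() {
  const long long num_ranks = 4, in_split_axis = 2;
  // Gathered shape for a logical (8, 16, 20) tensor split on axis 2 across 4 ranks:
  std::vector<long long> unpack_from = {num_ranks, 8, 16, 5};
  std::vector<int> perm;
  std::vector<long long> out_dims;
  for (size_t i = 1; i < unpack_from.size(); ++i) {  // skip the leading rank axis
    perm.push_back(static_cast<int>(i));
    out_dims.push_back(unpack_from[i]);
  }
  // Splice the rank axis (old index 0) back in at in_split_axis.
  perm.insert(perm.begin() + in_split_axis, 0);
  out_dims.insert(out_dims.begin() + in_split_axis, unpack_from[0]);
  for (int p : perm) std::printf("%d ", p);  // prints: 1 2 0 3
  std::printf("\n");
  for (long long d : out_dims) std::printf("%lld ", d);  // prints: 8 16 4 5
  std::printf("\n");  // axes 2 and 3 (4 x 5) then merge back into the logical 20
  return 0;
}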
40 changes: 40 additions & 0 deletions oneflow/user/ops/nccl_logical_ops.cpp
@@ -127,6 +127,46 @@ REGISTER_USER_OP("_nccl_logical_all_gather")
return Maybe<void>::Ok();
});

REGISTER_USER_OP("_nccl_logical_all_gather_noncontinuous")
.Input("in")
.Output("out")
.Attr<int64_t>("in_split_axis", -1)
.SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->Shape4ArgNameAndIndex("out", 0) = *ctx->Shape4ArgNameAndIndex("in", 0);
*ctx->IsDynamic4ArgNameAndIndex("out", 0) = *ctx->IsDynamic4ArgNameAndIndex("in", 0);
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->Dtype4ArgNameAndIndex("out", 0) = *ctx->Dtype4ArgNameAndIndex("in", 0);
return Maybe<void>::Ok();
})
.SetParallelDistributionInferFn([](user_op::InferParallelDistributionFnContext* ctx)
-> Maybe<void> {
const ParallelDistribution& in_dis_hint =
ctx->ParallelDistributionHint4InputArgNameAndIndex("in", 0);
CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);
const int64_t in_split_axis = ctx->user_op_conf().attr<int64_t>("in_split_axis");
CHECK_GE_OR_RETURN(in_split_axis, 1);
for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {
CHECK_OR_RETURN(sbp_hint.has_split_parallel());
CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis);
}

ParallelDistribution* in_distribution = ctx->ParallelDistribution4ArgNameAndIndex("in", 0);
ParallelDistribution* out_distribution = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
in_distribution->clear_sbp_parallel();
out_distribution->clear_sbp_parallel();

// S(1)->(B)
const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);
for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis);
out_distribution->add_sbp_parallel()->mutable_broadcast_parallel();
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP("_nccl_logical_s2s")
.Input("in")
.Output("out")
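
To make the inference above concrete: it pins the input to S(in_split_axis) on every axis of the parallel hierarchy and the output to B on every axis, e.g. hierarchy (2, 2) with in_split_axis=1 yields in: (S(1), S(1)) and out: (B, B). A toy illustration with strings standing in for OneFlow's ParallelDistribution proto (hypothetical, illustration only):

#include <cstdio>
#include <string>
#include <vector>

int main() {
  const long long in_split_axis = 1;
  const std::vector<long long> hierarchy = {2, 2};  // e.g. 2 nodes x 2 GPUs each
  std::string in_sbp, out_sbp;
  for (size_t i = 0; i < hierarchy.size(); ++i) {
    if (i) { in_sbp += ", "; out_sbp += ", "; }
    in_sbp += "S(" + std::to_string(in_split_axis) + ")";
    out_sbp += "B";
  }
  std::printf("in: (%s)  out: (%s)\n", in_sbp.c_str(), out_sbp.c_str());
  // prints: in: (S(1), S(1))  out: (B, B)
  return 0;
}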
