NCCL logical op support S1 to B #4772

Merged · 13 commits · Apr 29, 2021
14 changes: 14 additions & 0 deletions oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
@@ -189,6 +189,20 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp,
.Build()
.op_conf();
return true;
} else if ((src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel())
&& (src_sbp.split_parallel().axis() > 0)
&& (logical_blob_desc.shape().At(src_sbp.split_parallel().axis()) % parallel_num
== 0)) {
// S(1)->B : AllGather Noncontinuous
*ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-S2B-" + NewUniqueId())
.Op("_nccl_logical_all_gather_noncontinuous")
.Input("in", lbn)
.Output("out")
.Attr<int64_t>("in_split_axis", src_sbp.split_parallel().axis())
.ScopeSymbolId(scope_symbol_id)
.Build()
.op_conf();
return true;
} else if ((src_sbp.has_split_parallel() && dst_sbp.has_split_parallel())
&& (src_sbp.split_parallel().axis() != dst_sbp.split_parallel().axis())
&& (logical_blob_desc.shape().At(src_sbp.split_parallel().axis()) % parallel_num == 0)
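Annotation (not part of the PR): the new branch covers S(axis > 0) -> B; S(0) -> B is presumably left to the existing contiguous all-gather branch earlier in TryBuildNcclBy1DHierarchy, and the new branch only fires when the split dimension divides evenly across the ranks. A minimal standalone sketch of that predicate, with hypothetical numbers:

// Illustrative only -- restates the S(axis>0)->B condition outside the pass.
#include <cassert>
#include <cstdint>

bool UseAllGatherNoncontinuous(bool dst_is_broadcast, int64_t split_axis,
                               int64_t dim_at_split_axis, int64_t parallel_num) {
  return dst_is_broadcast && split_axis > 0 && dim_at_split_axis % parallel_num == 0;
}

int main() {
  // logical shape (4, 6), src = S(1), dst = B, 2 ranks -> branch taken.
  assert(UseAllGatherNoncontinuous(true, 1, 6, 2));
  // src = S(0) -> falls through to the contiguous all-gather path instead.
  assert(!UseAllGatherNoncontinuous(true, 0, 6, 2));
  return 0;
}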
8 changes: 4 additions & 4 deletions oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
@@ -144,7 +144,7 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKernel {
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const int64_t dtype_size = GetSizeOfDataType(in->data_type());
- int64_t data_size = GetCudaAlignedSize(in->shape().elem_cnt() * dtype_size);
+ int64_t data_size = GetCudaAlignedSize(out->shape().elem_cnt() * dtype_size);
void* unpack_from_ptr = tmp_buffer->mut_dptr();
CHECK_EQ(tmp_buffer->shape().elem_cnt(), data_size);

@@ -188,9 +188,9 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKernel {
};

size_t Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {
- const user_op::TensorDesc* in_tensor = ctx->TensorDesc4ArgNameAndIndex("in", 0);
- return GetCudaAlignedSize(in_tensor->shape().elem_cnt()
-                           * GetSizeOfDataType(in_tensor->data_type()));
+ const user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
+ return GetCudaAlignedSize(out_tensor->shape().elem_cnt()
+                           * GetSizeOfDataType(out_tensor->data_type()));
}

template<typename T>
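Annotation (not PR text): both hunks above size the temporary buffer from the output tensor rather than the input. The buffer holds the raw ncclAllGather result, which is as large as the gathered output (parallel_num copies of the per-rank input), so sizing it from in under-allocates. With hypothetical numbers: a per-rank float32 slice of shape (4, 3) gathered across 2 ranks needs scratch for 4 * 6 = 24 floats (96 bytes before CUDA alignment), not 4 * 3 = 12 floats (48 bytes).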
83 changes: 83 additions & 0 deletions oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -129,6 +129,74 @@ class NcclLogicalAllGatherKernel final : public user_op::OpKernel {
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

template<typename T>
class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel {
public:
NcclLogicalAllGatherNoncontinuous() = default;
~NcclLogicalAllGatherNoncontinuous() override = default;

std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
return std::make_shared<NcclLogicalKernelCommState>(ctx);
}

private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
auto* nccl_comm = dynamic_cast<NcclLogicalKernelCommState*>(state);
CHECK(nccl_comm != nullptr);
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const int64_t dtype_size = GetSizeOfDataType(in->data_type());
int64_t data_size = GetCudaAlignedSize(out->shape().elem_cnt() * dtype_size);
void* unpack_from_ptr = tmp_buffer->mut_dptr();
CHECK_EQ(tmp_buffer->shape().elem_cnt(), data_size);

CHECK_EQ(in->data_type(), out->data_type());
const int64_t num_ranks = ctx->parallel_ctx().parallel_num();
const int64_t in_split_axis = ctx->Attr<int64_t>("in_split_axis");

DimVector logical_shape_dim_vec;
in->shape().ToDimVector(&logical_shape_dim_vec);
logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;

// NOTE(chengcheng): Do AllGather
CHECK_EQ(in->shape().elem_cnt() * num_ranks, out->shape().elem_cnt());
OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape().elem_cnt(),
GetNcclDataType(in->data_type()), nccl_comm->comm(),
ctx->device_ctx()->cuda_stream()));

CHECK_GT(in_split_axis, 0);
// NOTE(chengcheng): Do unpack.
DimVector unpack_from_dim_vec = logical_shape_dim_vec;
CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);
unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;
unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);
const Shape unpack_from_shape(unpack_from_dim_vec);
DimVector transpose_out_dim_vec;
std::vector<int32_t> perm;
FOR_RANGE(int64_t, i, 1, unpack_from_shape.NumAxes()) {
perm.push_back(i);
transpose_out_dim_vec.push_back(unpack_from_shape.At(i));
}
perm.insert(perm.begin() + in_split_axis, 0);
transpose_out_dim_vec.insert(transpose_out_dim_vec.begin() + in_split_axis,
unpack_from_shape.At(0));
const Shape transpose_out_shape(transpose_out_dim_vec);
NewKernelUtil<DeviceType::kGPU>::Transpose(
ctx->device_ctx(), unpack_from_shape.NumAxes(), unpack_from_shape, transpose_out_shape,
perm, unpack_from_shape.elem_cnt(), reinterpret_cast<const T*>(unpack_from_ptr),
out->mut_dptr<T>());
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

size_t InferAllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {
const user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
return GetCudaAlignedSize(out_tensor->shape().elem_cnt()
* GetSizeOfDataType(out_tensor->data_type()));
}

template<typename T>
class NcclLogicalS2SKernel final : public user_op::OpKernel {
public:
@@ -278,6 +346,21 @@ REGISTER_USER_KERNEL("_nccl_logical_all_gather")
.SetCreateFn<NcclLogicalAllGatherKernel>()
.SetIsMatchedHob(user_op::HobDeviceTag() == "gpu");

#define REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(dtype) \
REGISTER_USER_KERNEL("_nccl_logical_all_gather_noncontinuous") \
.SetCreateFn<NcclLogicalAllGatherNoncontinuous<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") \
& (user_op::HobDataType("in", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("out", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn(InferAllGatherNoncontinuousKernelTmpBufferSize);

REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(double)
REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float16)

#define REGISTER_S2S_KERNEL(dtype) \
REGISTER_USER_KERNEL("_nccl_logical_s2s") \
.SetCreateFn<NcclLogicalS2SKernel<dtype>>() \
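Annotation (not PR text): NcclLogicalAllGatherNoncontinuous first all-gathers the per-rank S(in_split_axis) slices into the temporary buffer, which is equivalent to a tensor with an extra leading rank axis, and then transposes that rank axis to position in_split_axis so the output becomes the contiguous broadcast tensor. A self-contained host-side sketch of that gather + unpack sequence, using hypothetical sizes (logical shape (4, 6), in_split_axis = 1, 2 ranks):

// Illustrative only: mimics the AllGather + unpack transpose on the host.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_ranks = 2, rows = 4, cols_per_rank = 3;
  const int64_t cols = num_ranks * cols_per_rank;  // full split dimension: 6

  // Per-rank S(1) slices: rank r owns columns [r * 3, r * 3 + 3) of the logical tensor.
  std::vector<std::vector<int>> slices(num_ranks, std::vector<int>(rows * cols_per_rank));
  for (int64_t r = 0; r < num_ranks; ++r) {
    for (int64_t i = 0; i < rows; ++i) {
      for (int64_t j = 0; j < cols_per_rank; ++j) {
        slices[r][i * cols_per_rank + j] = static_cast<int>(i * cols + r * cols_per_rank + j);
      }
    }
  }

  // "AllGather": concatenate the rank slices -> viewed as shape (num_ranks, rows, cols_per_rank).
  std::vector<int> gathered;
  for (const auto& s : slices) { gathered.insert(gathered.end(), s.begin(), s.end()); }

  // Unpack: transpose (num_ranks, rows, cols_per_rank) with perm (1, 0, 2)
  // -> (rows, num_ranks, cols_per_rank), i.e. the contiguous (rows, cols) tensor.
  std::vector<int> out(rows * cols);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t r = 0; r < num_ranks; ++r) {
      for (int64_t j = 0; j < cols_per_rank; ++j) {
        out[(i * num_ranks + r) * cols_per_rank + j] = gathered[(r * rows + i) * cols_per_rank + j];
      }
    }
  }

  for (int v : out) { std::cout << v << ' '; }  // prints 0 1 2 ... 23
  std::cout << std::endl;
  return 0;
}

Run on the CPU this prints 0 through 23 in order, i.e. every rank ends up with the full logical tensor, mirroring what the CUDA Transpose call in the kernel produces.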
40 changes: 40 additions & 0 deletions oneflow/user/ops/nccl_logical_ops.cpp
@@ -127,6 +127,46 @@ REGISTER_USER_OP("_nccl_logical_all_gather")
return Maybe<void>::Ok();
});

REGISTER_USER_OP("_nccl_logical_all_gather_noncontinuous")
.Input("in")
.Output("out")
.Attr<int64_t>("in_split_axis", -1)
.SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->Shape4ArgNameAndIndex("out", 0) = *ctx->Shape4ArgNameAndIndex("in", 0);
*ctx->IsDynamic4ArgNameAndIndex("out", 0) = *ctx->IsDynamic4ArgNameAndIndex("in", 0);
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->Dtype4ArgNameAndIndex("out", 0) = *ctx->Dtype4ArgNameAndIndex("in", 0);
return Maybe<void>::Ok();
})
.SetParallelDistributionInferFn([](user_op::InferParallelDistributionFnContext* ctx)
-> Maybe<void> {
const ParallelDistribution& in_dis_hint =
ctx->ParallelDistributionHint4InputArgNameAndIndex("in", 0);
CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);
const int64_t in_split_axis = ctx->user_op_conf().attr<int64_t>("in_split_axis");
CHECK_GE_OR_RETURN(in_split_axis, 1);
for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {
CHECK_OR_RETURN(sbp_hint.has_split_parallel());
CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis);
}

ParallelDistribution* in_distribution = ctx->ParallelDistribution4ArgNameAndIndex("in", 0);
ParallelDistribution* out_distribution = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
in_distribution->clear_sbp_parallel();
out_distribution->clear_sbp_parallel();

// S(1)->(B)
const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);
for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis);
out_distribution->add_sbp_parallel()->mutable_broadcast_parallel();
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP("_nccl_logical_s2s")
.Input("in")
.Output("out")
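Annotation (not PR text): the ParallelDistribution infer function requires every input SBP hint to be S(in_split_axis) with in_split_axis >= 1, then writes S(in_split_axis) for the input and broadcast for the output on every axis of the parallel hierarchy. For a hypothetical 1D hierarchy of 4 devices with in_split_axis = 1 the inferred signature is in: S(1) -> out: B; for a (2, 2) hierarchy it is in: (S(1), S(1)) -> out: (B, B), matching the "S(1)->(B)" comment in the loop.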