-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add nccl logical 1d P to S(i) #8361
Conversation
…add_reduce_scatter_noncontiguous
} else if (CanSplitAtDim(0) | ||
&& (src_sbp.has_partial_sum_parallel() && dst_sbp.has_split_parallel()) | ||
&& (dst_sbp.split_parallel().axis() > 0)) { | ||
// P->S(0) : ReduceScatter Noncontinuous |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
P->S(1)
REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int8_t) | ||
REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int32_t) | ||
REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int64_t) | ||
REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(float) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
要支持 bool 类型
transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); | ||
|
||
OF_NCCL_CHECK(ncclReduceScatter(tmp_buffer->dptr(), out->mut_dptr(), out->shape().elem_cnt(), | ||
GetNcclDataType(in->data_type()), ncclRedOp_t::ncclSum, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当是 bool 类型时,使用 ncclMax
….com/Oneflow-Inc/oneflow into dev_add_reduce_scatter_noncontiguous
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8361/ |
@@ -218,6 +218,20 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp, | |||
.Build() | |||
.op_conf(); | |||
return true; | |||
} else if (CanSplitAtDim(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里写错了吧? 应该是:
CanSplitAtDim(dst_sbp.split_parallel().axis())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
哦哦哦,是的,我改一下
….com/Oneflow-Inc/oneflow into dev_add_reduce_scatter_noncontiguous
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8361/ |
Speed stats:
|
fix: https://github.com/Oneflow-Inc/OneTeam/issues/1440