new_X_to_B #5987
Conversation
@@ -33,6 +34,7 @@ Maybe<one::Tensor> EagerBoxingInterpreter::Interpret(const std::shared_ptr<one::
    Symbol<ParallelDesc> in_parallel_desc,
    Symbol<ParallelDesc> out_parallel_desc) const {
  JUST(CheckEagerBoxingDataType(input->dtype()->data_type()));
  DisableCheckConsistentTensorMetaScope disable_meta_check;
Disable the ConsistentTensorMeta check inside the interpreter.
    && (in_parallel_desc->device_type() == DeviceType::kGPU
        && out_parallel_desc->device_type() == DeviceType::kGPU)) {
Currently only the GPU version is supported.
Why is only the GPU version supported? Only the broadcast op is used here, and that op exists on CPU as well.
Maybe<int64_t> GetBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,
                                Symbol<ParallelDesc> dst_parallel_desc) {
Compute the root node of the broadcast from the src placement and the dst placement.
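The root-selection idea discussed above can be sketched as follows. This is an illustrative standalone sketch, not OneFlow's actual implementation: it assumes placements can be flattened to sorted rank lists, and it prefers a rank that appears in both placements (it already holds the data and also needs the result), falling back to the first source rank when the placements are disjoint.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical sketch of GetBroadcastRoot: choose a rank present in both
// the source and destination placements as the broadcast root; if the two
// placements are disjoint, any rank that holds the data (here the first
// source rank) can serve as root.
int64_t GetBroadcastRoot(const std::vector<int64_t>& src_ranks,
                         const std::vector<int64_t>& dst_ranks) {
  for (int64_t r : src_ranks) {
    if (std::find(dst_ranks.begin(), dst_ranks.end(), r) != dst_ranks.end()) {
      return r;  // rank that both holds the data and needs the result
    }
  }
  return src_ranks.front();  // disjoint placements: fall back to a data holder
}
```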
const auto& new_tag_in_parallel_desc =
    JUST(ReplaceDeviceType(in_parallel_desc, out_parallel_desc->device_type()));
When the device types differ, the first step (converting to a tensor whose sbp is broadcast) must account for the change of device tag.
std::shared_ptr<one::Tensor> local_tensor = JUST(broadcast_input->cur_rank_phy_tensor());
{
  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));
  if (out_parallel_id->has_value()) {
Only the ranks covered by out_parallel_desc execute the broadcast.
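The coverage check above (a rank participates only if the placement assigns it a parallel id) can be sketched like this. The helper name mirrors GetParallelId4CurrentProcessCtx, but the flat-rank-list representation here is an assumption for illustration:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Illustrative sketch: map a global process rank to its parallel id within
// a placement; nullopt means the rank is not covered by the placement and
// should skip the broadcast entirely.
std::optional<int64_t> ParallelIdForRank(const std::vector<int64_t>& placement_ranks,
                                         int64_t rank) {
  for (size_t i = 0; i < placement_ranks.size(); ++i) {
    if (placement_ranks[i] == rank) { return static_cast<int64_t>(i); }
  }
  return std::nullopt;  // rank outside the placement
}
```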
if (!new_in_parallel_id->has_value()) {
  std::string device_type = Device::Type4DeviceTag(new_tag_in_parallel_desc->device_tag());
  local_tensor = JUST(one::functional::Empty(*input->shape(), input->dtype(),
                                             JUST(Device::New(device_type))));
}
When the input tensor is not valid on the current rank, an empty tensor must be created; its main purpose is to let the output inference complete.
Symbol<ParallelDesc> broadcast_parallel_desc_cur_rank =
    JUST(MapAt(*broadcast_grop, GlobalProcessCtx::Rank()));
int64_t root =
    JUST(CachedGetBroadcastRoot(new_tag_in_parallel_desc, broadcast_parallel_desc_cur_rank));
Compute the placement and root needed by each broadcast group.
int64_t dev_id = GlobalProcessCtx::LocalRank(root);
int64_t parallel_id =
    CHECK_JUST(kernel_state->parallel_desc()->ParallelId4MachineDeviceId(root, dev_id));
In NCCL, the device rank inside the communicator must be computed from root.
Maybe<int64_t> CalBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,
                                Symbol<ParallelDesc> dst_parallel_desc) {
Compute the root node of the broadcast from src_parallel_desc and dst_parallel_desc.
@@ -99,15 +99,18 @@ class EagerNcclBroadcastKernel final : public user_op::OpKernel {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    int64_t root = ctx->Attr<int64_t>("root");
    int64_t dev_id = GlobalProcessCtx::LocalRank(root);
    int64_t parallel_id =
nccl_root
@@ -116,6 +117,13 @@ Maybe<EagerBoxingInterpreter> GetBoxingInterpreter(Symbol<cfg::NdSbp> in_nd_sbp,
        in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc));
    if (interpreter.IsOk()) { return JUST(interpreter); }
  }
  if (in_parallel_desc->parallel_num() != out_parallel_desc->parallel_num()
Move this logic into GetOneDimNcclCollectiveEagerBoxingInterpreter; that way you would support the CPU version here directly.
My wording here may not be quite right. CPU support may not be achievable immediately.
…new_X_to_B Conflicts: oneflow/core/framework/op_interpreter/boxing/eager_boxing_interpreter_mgr.cpp
…nto new_X_to_B
if (!out_parallel_id->has_value()) {
  std::string device_type = Device::Type4DeviceTag(in_parallel_desc->device_tag());
  local_tensor = JUST(one::functional::Empty(
      *JUST(GetPhysicalShape(*input->shape(), *in_nd_sbp, *in_parallel_desc, 0)), input->dtype(),
      JUST(Device::New(device_type))));
}
When the placement does not cover the current process, the local tensor is an empty tensor, so we must initialize a tensor for local_tensor to prevent an error when the local-to-consistent function at line 43 of this file executes the mirror copy op.
What role does this local tensor play? I see that parallel_id is always 0 when computing its shape.
What role does this local tensor play? I see that parallel_id is always 0 when computing its shape.

This guards against a bug in the ToConsistent call at line 43. When the tensor's placement does not cover the current rank, the fetched local tensor is an empty tensor. ToConsistent executes a copy op to convert the device tag, and that process's copy op fails if its input is an empty tensor, so local_tensor must be reassigned to make it a meaningful tensor.
Then why is parallel_id always 0 when computing its shape?
oneflow/core/framework/op_interpreter/boxing/eager_boxing_interpreter_mgr.cpp (outdated; resolved)
static constexpr auto* CheckSymXToB = DECORATE(&RawCheckSymXToB, ThreadLocal);

Maybe<one::UserOpExpr> EagerNcclAllReduce(Symbol<ParallelDesc> parallel_desc) {
oneflow/core/framework/op_interpreter/boxing/collective_boxing_interpreter.cpp contains identical code; can it be reused?
oneflow/core/framework/op_interpreter/boxing/collective_boxing_interpreter.cpp contains identical code; can it be reused?

All boxing interpreters will later be migrated to a registration-based form, at which point collective_boxing_interpreter.cpp will be deleted; it is kept as-is for now.
oneflow/core/framework/op_interpreter/boxing/asymmetric_broadcast.cpp (outdated; resolved)
oneflow/core/framework/op_interpreter/boxing/asymmetric_broadcast.cpp (outdated; resolved)
JUST(SymXToBBoxingFunction(tensor, in, broadcast_in_placed_nd_sbp));

const auto& AsymBoxingFunction =
    *JUST(GetBoxingFunction("asymmetric-x-to-b", broadcast_in_placed_nd_sbp, out));
Why are there two boxing functions, asymmetric-x-to-b and asym-x-to-b? Is one of the names a typo?
Why are there two boxing functions, asymmetric-x-to-b and asym-x-to-b? Is one of the names a typo?

It was a typo; the name is asymmetric-broadcast. Corrected.
No description provided.