-
Notifications
You must be signed in to change notification settings - Fork 842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev regst coloring #965
Dev regst coloring #965
Conversation
…IsOnlyOneTopoOrder
…puteSameChainActorIds
* is head regst desc where sharing same mem * fix typo * memory sharing info * interface: add ctrl regst for mem_sharing critical section * find shared mem header to sink and add ctrl regst * refine MakeSetterAddCtrlRegst * move out of loop * pass RegstDescProto* vector instead * move loop out and extract BuildCtrlRegst * use shared_ptr instead of unique_ptr * find or create * refine parameters * directly connected tasks have no need to add guard ctrl regst * modify because of mem_sharing_info changing from optional to required * refine code * fix bug * rename * extract TryConnectWithCtrlRegstDesc and InitCtrlRegstDesc * rename * rename * FindOrCreate from improve to task_node * rename TryConnectWithMemSafeGuardCtrlRegstDesc parameters * check mem_sharing_info.enable_mem_sharing * extract IsInRepeatedField * at(0) to front() * refine IsInRepeatedField * use single template argument T * find all shared mem regst sink comsumer task * adjust semantic * ref capture IsReachable * task_ids captured by ref * remove useless function
…nto dev_regst_coloring
@@ -8,6 +8,8 @@ namespace oneflow { | |||
|
|||
const int32_t kMaxRegisterNum = std::numeric_limits<int32_t>::max(); | |||
|
|||
void InitCtrlRegstDesc(int64_t produced_task_id, RegstDescProto* ctrl_regst_proto); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
produced_task_id -> producer_task_id
@@ -146,4 +146,19 @@ bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) { | |||
return true; | |||
} | |||
|
|||
void InitCtrlRegstDesc(int64_t produced_task_id, RegstDescProto* ctrl_regst_proto) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
produced_task_id -> producer_task_id
oneflow/core/job/improver.h
Outdated
Plan Improve(const Plan& naive_plan, const std::string& act_event_filepath); | ||
Plan Improve(const AvailableMemDesc& amd, const Plan& naive_plan, | ||
const std::string& act_event_filepath); | ||
Plan ImproveMemSharedInfoOnly(const Plan& naive_plan) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么有的地方是mem shared info, 有的地方是mem sharing info
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,没有及时改过来
oneflow/core/job/improver.cpp
Outdated
&& regst_desc.mem_sharing_info().enable_mem_sharing(); | ||
} | ||
|
||
bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
比较绕,这个不用搞一个集合去判断吧,直接判断是否所有的consumer的 chain id 和 producer chain id 是否相等就可以了吧。
oneflow/core/job/improver.cpp
Outdated
const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) { | ||
HashMap<int64_t, std::list<const TaskProto*>> chain_id2task_proto; | ||
for (const TaskProto& task : plan.task()) { | ||
if (Global<IDMgr>::Get()->LocalWorkStreamId4TaskId(task.task_id()) != 0) { continue; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是为什么? 是只处理每个thread上的第一个stream上的task? 也就是不处理那种independent stream?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
independent stream上的内存共享开发人员手动设置,不在这推断
oneflow/core/job/improver.cpp
Outdated
ForEachSameColoredChainRegstDescWithConsumer(graph, plan, [&](const RegstDescs& regst_descs) { | ||
int32_t used_order_value = 0; | ||
mem_sharing_info.set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId()); | ||
graph.SortByProducerTaskTopoOrder(regst_descs, [&](const RegstDescProto* regst_desc) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这里的逻辑可以简化:利用chain_id和order_in_graph那个值。
}); | ||
} | ||
|
||
void PlanTaskGraph::AssertThereIsOnlyOneTopoOrder(const HashSet<int64_t>& task_ids) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在plan graph里任何时候不要自己做拓扑遍历,全部使用order in graph规定的顺序来做,自然保证永远都是一个顺序。
}); | ||
} | ||
|
||
void PlanTaskGraph::ComputeLifetimeActorIds(const RegstDescProto* regst_desc, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里也不需要再拓扑遍历了。
}); | ||
} | ||
|
||
void ForEachSameColoredChainRegstDescWithConsumer( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分代码callback嵌套的太多了,逻辑分散在各处,不容易读懂。
oneflow/core/job/improver.cpp
Outdated
const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) { | ||
HashMap<int64_t, std::list<const RegstDescProto*>> global_work_stream_id2regst_descs; | ||
for (const auto& task : plan.task()) { | ||
if (Global<IDMgr>::Get()->LocalWorkStreamId4TaskId(task.task_id()) != 0) { continue; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方也通过在TaskNode set regst sharable false来禁用共享。
for (const auto& pair : task_id2intersected_nodes) { | ||
for (RegstLifetimeNode* src_node : pair.second) { | ||
for (RegstLifetimeNode* dst_node : pair.second) { | ||
if (src_node < dst_node) { src_node2dst_nodes[src_node].emplace(dst_node); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方根据指针大小来排列节点,运行多次的话结果不确定。可以按照RegstLifetimeNode->regst->producer task node 的order_in_graph来排序,这样每次运行结果都一样。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不好,因为src_node和dst_node可能在一个task上。
换成regst_desc_id就行
No description provided.