Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev regst coloring #965

Merged
merged 52 commits into from
Jul 1, 2018
Merged

Dev regst coloring #965

merged 52 commits into from
Jul 1, 2018

Conversation

lixinqi
Copy link
Contributor

@lixinqi lixinqi commented Jun 26, 2018

No description provided.

lixinqi added 30 commits June 4, 2018 21:31
lixinqi and others added 4 commits June 22, 2018 17:39
* is head regst desc where sharing same mem

* fix typo

* memory sharing info

* interface: add ctrl regst for mem_sharing critical section

* find shared mem header to sink and add ctrl regst

* refine MakeSetterAddCtrlRegst

* move out of loop

* pass RegstDescProto* vector instead

* move loop out and extract BuildCtrlRegst

* use shared_ptr instead of unique_ptr

* find or create

* refine parameters

* directly connected tasks have no need to add guard ctrl regst

* modify because of mem_sharing_info changing from optional to required

* refine code

* fix bug

* rename

* extract TryConnectWithCtrlRegstDesc and InitCtrlRegstDesc

* rename

* rename

* FindOrCreate from improve to task_node

* rename TryConnectWithMemSafeGuardCtrlRegstDesc parameters

* check mem_sharing_info.enable_mem_sharing

* extract IsInRepeatedField

* at(0) to front()

* refine IsInRepeatedField

* use single template argument T

* find all shared mem regst sink comsumer task

* adjust semantic

* ref capture IsReachable

* task_ids captured by ref

* remove useless function
@lixinqi lixinqi requested a review from willzhang4a58 as a code owner June 26, 2018 08:33
@@ -8,6 +8,8 @@ namespace oneflow {

const int32_t kMaxRegisterNum = std::numeric_limits<int32_t>::max();

void InitCtrlRegstDesc(int64_t produced_task_id, RegstDescProto* ctrl_regst_proto);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

produced_task_id -> producer_task_id

@@ -146,4 +146,19 @@ bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) {
return true;
}

void InitCtrlRegstDesc(int64_t produced_task_id, RegstDescProto* ctrl_regst_proto) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

produced_task_id -> producer_task_id

Plan Improve(const Plan& naive_plan, const std::string& act_event_filepath);
Plan Improve(const AvailableMemDesc& amd, const Plan& naive_plan,
const std::string& act_event_filepath);
Plan ImproveMemSharedInfoOnly(const Plan& naive_plan) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么有的地方是mem shared info, 有的地方是mem sharing info

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,没有及时改过来

&& regst_desc.mem_sharing_info().enable_mem_sharing();
}

bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

比较绕,这个不用搞一个集合去判断吧,直接判断是否所有的consumer的 chain id 和 producer chain id 是否相等就可以了吧。

const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
HashMap<int64_t, std::list<const TaskProto*>> chain_id2task_proto;
for (const TaskProto& task : plan.task()) {
if (Global<IDMgr>::Get()->LocalWorkStreamId4TaskId(task.task_id()) != 0) { continue; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是为什么? 是只处理每个thread上的第一个stream上的task? 也就是不处理那种independent stream?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

independent stream上的内存共享开发人员手动设置,不在这推断

ForEachSameColoredChainRegstDescWithConsumer(graph, plan, [&](const RegstDescs& regst_descs) {
int32_t used_order_value = 0;
mem_sharing_info.set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId());
graph.SortByProducerTaskTopoOrder(regst_descs, [&](const RegstDescProto* regst_desc) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这里的逻辑可以简化:利用chain_id和order_in_graph那个值。

});
}

void PlanTaskGraph::AssertThereIsOnlyOneTopoOrder(const HashSet<int64_t>& task_ids) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在plan graph里任何时候不要自己做拓扑遍历,全部使用order in graph规定的顺序来做,自然保证永远都是一个顺序。

});
}

void PlanTaskGraph::ComputeLifetimeActorIds(const RegstDescProto* regst_desc,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里也不需要再拓扑遍历了。

});
}

void ForEachSameColoredChainRegstDescWithConsumer(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分代码callback嵌套的太多了,逻辑分散在各处,不容易读懂。

const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
HashMap<int64_t, std::list<const RegstDescProto*>> global_work_stream_id2regst_descs;
for (const auto& task : plan.task()) {
if (Global<IDMgr>::Get()->LocalWorkStreamId4TaskId(task.task_id()) != 0) { continue; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方也通过在TaskNode set regst sharable false来禁用共享。

for (const auto& pair : task_id2intersected_nodes) {
for (RegstLifetimeNode* src_node : pair.second) {
for (RegstLifetimeNode* dst_node : pair.second) {
if (src_node < dst_node) { src_node2dst_nodes[src_node].emplace(dst_node); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方根据指针大小来排列节点,运行多次的话结果不确定。可以按照RegstLifetimeNode->regst->producer task node 的order_in_graph来排序,这样每次运行结果都一样。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不好,因为src_node和dst_node可能在一个task上。
换成regst_desc_id就行

@yuanms2 yuanms2 merged commit a36c035 into master Jul 1, 2018
@lixinqi lixinqi deleted the dev_regst_coloring branch July 2, 2018 06:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants