Skip to content

Commit

Permalink
handle ctrl msg from other rank (#4491)
Browse files Browse the repository at this point in the history
* handle ctrl msg from other rank

* static way

* use check

* add todo

* add CHECK

* add DumpToConsumedRegstDescId2Addr function

* fix comment

* handle returned_regst_num

* optimize code

* rename regst_desc_id2regst_desc_addr_

* fix segmentation fault

* remove returned_regst_num

* rename arg and function

* add info in plan

* use name producer_task_id

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
clackhan and oneflow-ci-bot authored Mar 25, 2021
1 parent 3d84630 commit 4380494
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 21 deletions.
54 changes: 36 additions & 18 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,13 @@ int64_t Actor::GetPieceId4NaiveCurReadableDataRegst() const {
int64_t init_val = -2;
int64_t pid = init_val;
auto FirstFoundOnly = [&pid, init_val](int64_t) { return pid == init_val; };
naive_consumed_rs_.ForChosenFrontRegst(FirstFoundOnly, [&pid](Regst* regst) {
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { pid = regst->piece_id(); }
});
naive_consumed_rs_.ForChosenFrontRegst(
FirstFoundOnly, [&pid](int64_t regst_desc_id, Regst* regst) {
if (Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {
pid = regst->piece_id();
}
});
CHECK_GE(pid, 0);
return pid;
}
Expand All @@ -248,7 +252,8 @@ int64_t Actor::GetPieceId4NaiveOrInplaceCurReadableDataRegst() const {
int64_t init_val = -2;
int64_t pid = init_val;
auto FirstFoundOnly = [&pid, init_val](int64_t) { return pid == init_val; };
auto Select = [&pid](Regst* regst) {
auto Select = [&pid](int64_t regst_desc_id, Regst* regst) {
if (Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { pid = regst->piece_id(); }
};
naive_consumed_rs_.ForChosenFrontRegst(FirstFoundOnly, Select);
Expand All @@ -274,7 +279,8 @@ void Actor::SetReadableRegstInfo(const Regst* regst, ReadableRegstInfo* info) co
}

void Actor::ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)> func) const {
naive_consumed_rs_.ForEachFrontRegst([func](Regst* regst) {
naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) {
if (Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { func(regst); }
});
}
Expand Down Expand Up @@ -316,7 +322,17 @@ int Actor::HandlerNormal(const ActorMsg& msg) {
}
} else {
if (NormalTryProcessReadableMsgFromOtherMachine(msg) == false) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0);
// process ctrl msg from other rank
if (IsConsumedCtrlRegstDescId(msg.regst_desc_id())) {
Regst* regst = msg.regst();
CHECK(naive_consumed_rs_.HasRegstDescId(msg.regst_desc_id()));
CHECK(Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(msg.regst_desc_id()));
CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst, msg.regst_desc_id()));
const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(msg.regst_desc_id());
CHECK(rdeq.empty() == false);
} else {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0);
}
}
}
ActUntilFail();
Expand Down Expand Up @@ -374,7 +390,8 @@ void Actor::TryLogActEvent(const std::function<void()>& DoAct) const {
act_event->set_work_stream_id(GetGlobalWorkStreamId());
act_event->set_act_id(act_id_);
act_event->set_ready_time(GetCurTime());
naive_consumed_rs_.ForEachFrontRegst([&](const Regst* readable_regst) {
naive_consumed_rs_.ForEachFrontRegst([&](int64_t regst_desc_id, const Regst* readable_regst) {
if (Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }
ReadableRegstInfo* info = act_event->add_readable_regst_infos();
Actor::SetReadableRegstInfo(readable_regst, info);
});
Expand Down Expand Up @@ -465,16 +482,16 @@ void Actor::AsyncSendConsumedCtrlRegstMsgToProducer() {
};

tmp_regst_desc_id_vec_.clear();
naive_consumed_rs_.ForChosenRegstDeq(IsChosenRegstDescId, [&](const std::deque<Regst*>& reg_deq) {
CHECK(reg_deq.empty() == false);
Regst* regst = reg_deq.front();
CHECK(regst->regst_desc()->regst_desc_type().has_ctrl_regst_desc());
CHECK_GE(reg_deq.size(), 1);
// must access regst before sending it to producer
tmp_regst_desc_id_vec_.push_back(regst->regst_desc_id());
EnqueueAsyncMsg(
ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
});
naive_consumed_rs_.ForChosenRegstDeq(
IsChosenRegstDescId, [&](int64_t regst_desc_id, const std::deque<Regst*>& reg_deq) {
CHECK(reg_deq.empty() == false);
auto producer_task_id = Global<RegstMgr>::Get()->ProducerTaskId4RegstDescId(regst_desc_id);
Regst* regst = reg_deq.front();
CHECK_GE(reg_deq.size(), 1);
// must access regst before sending it to producer
tmp_regst_desc_id_vec_.push_back(regst_desc_id);
EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer_task_id, regst));
});
naive_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);
}

Expand Down Expand Up @@ -609,7 +626,8 @@ void Actor::AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t

void Actor::HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst) {
tmp_regst_desc_id_vec_.clear();
naive_consumed_rs_.ForEachFrontRegst([&](Regst* regst) {
naive_consumed_rs_.ForEachFrontRegst([&](int64_t regst_desc_id, Regst* regst) {
if (IsConsumedCtrlRegstDescId(regst_desc_id)) { return; }
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {
if (IsAllowedRegst(regst) == false) { return; }
// must access regst before sending it to producer
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
msg.regst_wrapper_.comm_net_token = regst_raw_ptr->comm_net_token();
}
msg.regst_wrapper_.regst_status = regst_raw_ptr->status();
msg.regst_wrapper_.regst_status.regst_desc_id = regst_raw_ptr->regst_desc_id();
msg.regst_wrapper_.has_sole_empty_blob = IsSoleBlobAndDynamicEmpty(regst_raw_ptr);
return msg;
}
Expand Down
30 changes: 29 additions & 1 deletion oneflow/core/actor/register_slot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ const std::deque<Regst*>& RegstSlot::RegstDeq4RegstDescId(int64_t regst_desc_id)
}

// Convenience overload: queues `regst` under the regst_desc_id that the regst
// itself reports, by delegating to the two-argument overload and returning
// its result unchanged.
int RegstSlot::TryPushBackRegst(Regst* regst) {
  const auto own_regst_desc_id = regst->regst_desc_id();
  return TryPushBackRegst(regst, own_regst_desc_id);
}

int RegstSlot::TryPushBackRegst(Regst* regst, int64_t regst_desc_id) {
CHECK(is_inited_);
auto it = regst_desc_id2regsts_.find(regst->regst_desc_id());
auto it = regst_desc_id2regsts_.find(regst_desc_id);
if (it == regst_desc_id2regsts_.end()) { return -1; }
if (it->second.empty()) { available_regst_desc_cnt_ += 1; }
it->second.push_back(regst);
Expand Down Expand Up @@ -95,17 +99,41 @@ void RegstSlot::ForChosenFrontRegst(std::function<bool(int64_t)> IsChosenRegstDe
}
}

// Walks every (regst_desc_id, regst deque) entry of this slot and, for each id
// accepted by `IsChosenRegstDescId`, calls `Handler` with that id and the
// regst at the front of its deque.
void RegstSlot::ForChosenFrontRegst(
    std::function<bool(int64_t)> IsChosenRegstDescId,
    std::function<void(int64_t regst_desc_id, Regst*)> Handler) const {
  for (const auto& kv : regst_desc_id2regsts_) {
    if (!IsChosenRegstDescId(kv.first)) { continue; }
    // A chosen id is required to have at least one regst queued.
    CHECK(kv.second.empty() == false);
    Handler(kv.first, kv.second.front());
  }
}

// Calls `Handler` with the whole regst deque of every regst_desc_id accepted
// by `IsChosenRegstDescId`. No emptiness check is made here, so the handler
// may be given an empty deque.
void RegstSlot::ForChosenRegstDeq(std::function<bool(int64_t)> IsChosenRegstDescId,
                                  std::function<void(const std::deque<Regst*>&)> Handler) const {
  for (const auto& id_and_deq : regst_desc_id2regsts_) {
    if (!IsChosenRegstDescId(id_and_deq.first)) { continue; }
    Handler(id_and_deq.second);
  }
}

// Id-aware variant: for every regst_desc_id accepted by `IsChosenRegstDescId`,
// calls `Handler` with that id and its whole regst deque. No emptiness check
// is made here, so the handler may be given an empty deque.
void RegstSlot::ForChosenRegstDeq(
    std::function<bool(int64_t)> IsChosenRegstDescId,
    std::function<void(int64_t regst_desc_id, const std::deque<Regst*>&)> Handler) const {
  for (const auto& id_and_deq : regst_desc_id2regsts_) {
    if (!IsChosenRegstDescId(id_and_deq.first)) { continue; }
    Handler(id_and_deq.first, id_and_deq.second);
  }
}

// Applies `Handler` to the front regst of every regst_desc_id in this slot by
// running ForChosenFrontRegst with a predicate that accepts every id.
void RegstSlot::ForEachFrontRegst(std::function<void(Regst*)> Handler) const {
  const auto AcceptEveryRegstDescId = [](int64_t) { return true; };
  ForChosenFrontRegst(AcceptEveryRegstDescId, Handler);
}

// Id-aware variant: applies `Handler` to (regst_desc_id, front regst) for every
// regst_desc_id in this slot by running ForChosenFrontRegst with a predicate
// that accepts every id.
void RegstSlot::ForEachFrontRegst(
    std::function<void(int64_t regst_desc_id, Regst*)> Handler) const {
  const auto AcceptEveryRegstDescId = [](int64_t) { return true; };
  ForChosenFrontRegst(AcceptEveryRegstDescId, Handler);
}

// Applies `Handler` to the regst deque of every regst_desc_id in this slot by
// running ForChosenRegstDeq with a predicate that accepts every id.
void RegstSlot::ForEachRegstDeq(std::function<void(const std::deque<Regst*>&)> Handler) const {
  const auto AcceptEveryRegstDescId = [](int64_t) { return true; };
  ForChosenRegstDeq(AcceptEveryRegstDescId, Handler);
}
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/actor/register_slot.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,24 @@ class RegstSlot final {
bool HasRegstDescId(int64_t regst_desc_id) const;
const std::deque<Regst*>& RegstDeq4RegstDescId(int64_t regst_desc_id) const;
void ForEachFrontRegst(std::function<void(Regst*)>) const;
void ForEachFrontRegst(std::function<void(int64_t regst_desc_id, Regst*)>) const;
void ForEachRegstDeq(std::function<void(const std::deque<Regst*>&)>) const;
void ForChosenFrontRegst(std::function<bool(int64_t)>, std::function<void(Regst*)>) const;
void ForChosenFrontRegst(std::function<bool(int64_t)>,
std::function<void(int64_t regst_desc_id, Regst*)>) const;
void ForChosenRegstDeq(std::function<bool(int64_t)>,
std::function<void(const std::deque<Regst*>&)>) const;
void ForChosenRegstDeq(
std::function<bool(int64_t)>,
std::function<void(int64_t regst_desc_id, const std::deque<Regst*>&)>) const;

Regst* Front(int64_t regst_desc_id) const;
Regst* SoleFront() const;
Regst* FirstFront() const;

// 0: success, -1: cannot find regst_desc_id
int TryPushBackRegst(Regst* regst);
int TryPushBackRegst(Regst* regst, int64_t regst_desc_id);
int TryPopFrontRegst(int64_t regst_desc_id);

void PopFrontRegsts(const std::vector<int64_t>& regst_desc_ids);
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/graph/task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ void TaskNode::ToProto(TaskProto* task_proto) {
task_proto->mutable_task_set_info()->set_chain_id(chain_id_);
task_proto->mutable_task_set_info()->set_order_in_graph(order_in_graph_);
exec_gph_.ToExecSequence(parallel_ctx(), task_proto->mutable_exec_sequence());
auto produced_regst_proto = task_proto->mutable_produced_regst_desc();
auto* produced_regst_proto = task_proto->mutable_produced_regst_desc();
for (auto& pair : produced_regsts_) {
RegstDescProto regst_desc_proto;
pair.second->ToProto(&regst_desc_proto);
CHECK(produced_regst_proto->insert({pair.first, regst_desc_proto}).second);
}
auto consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id();
auto* consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id();
for (const auto& pair : consumed_regsts_) {
RegstDescIdSet regst_desc_ids;
for (const std::shared_ptr<RegstDesc>& regst : pair.second) {
Expand Down
23 changes: 23 additions & 0 deletions oneflow/core/job/oneflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ std::string cluster_thrd_ids_key(const std::string& plan_name) {

// Key under which a plan's NetTopo is pushed to / pulled from the control
// key-value store: "<plan_name>_net_topo".
std::string net_topo_key(const std::string& plan_name) {
  std::string key(plan_name);
  key += "_net_topo";
  return key;
}

// Key under which a plan's CtrlRegstDescInfo is pushed to / pulled from the
// control key-value store: "<plan_name>_ctrl_regst_desc_info_key".
std::string ctrl_regst_desc_info_key(const std::string& plan_name) {
  std::string key(plan_name);
  key += "_ctrl_regst_desc_info_key";
  return key;
}

// Key under which a plan's JobConfs are pushed to / pulled from the control
// key-value store: "<plan_name>_job_id2job_conf".
std::string job_id2job_conf(const std::string& plan_name) {
  std::string key(plan_name);
  key += "_job_id2job_conf";
  return key;
}

std::string GetCollectiveBoxingPlanKey(const std::string& plan_name) {
Expand Down Expand Up @@ -139,6 +143,8 @@ void PushPlan(const std::string& plan_name, const Plan& plan) {
}

Global<CtrlClient>::Get()->PushKV(net_topo_key(plan_name), plan.net_topo());
Global<CtrlClient>::Get()->PushKV(ctrl_regst_desc_info_key(plan_name),
plan.ctrl_regst_desc_info());
Global<CtrlClient>::Get()->PushKV(job_id2job_conf(plan_name), plan.job_confs());
Global<CtrlClient>::Get()->PushKV(GetCollectiveBoxingPlanKey(plan_name),
plan.collective_boxing_plan());
Expand All @@ -162,6 +168,9 @@ void PullPlan(const std::string& plan_name, Plan* plan) {
NetTopo net_topo;
Global<CtrlClient>::Get()->PullKV(net_topo_key(plan_name), &net_topo);
*(plan->mutable_net_topo()) = net_topo;
CtrlRegstDescInfo ctrl_regst_desc_info;
Global<CtrlClient>::Get()->PullKV(ctrl_regst_desc_info_key(plan_name), &ctrl_regst_desc_info);
*(plan->mutable_ctrl_regst_desc_info()) = ctrl_regst_desc_info;
JobConfs job_confs;
Global<CtrlClient>::Get()->PullKV(job_id2job_conf(plan_name), &job_confs);
*(plan->mutable_job_confs()) = job_confs;
Expand Down Expand Up @@ -350,6 +359,19 @@ void MergePlan(Plan* plan, const Plan& other) {
Compiler().GenNetTopo(plan);
}

// Fills plan->ctrl_regst_desc_info with one regst_desc_id -> producer_task_id
// entry for every produced regst desc, across all tasks in `plan`, whose
// regst_desc_type has a ctrl_regst_desc. Data regst descs are left out.
void DumpCtrlRegstInfoToPlan(Plan* plan) {
  auto* id2producer_task_id =
      plan->mutable_ctrl_regst_desc_info()->mutable_ctrl_regst_desc_id2producer_task_id();
  for (const TaskProto& task_proto : plan->task()) {
    for (const auto& entry : task_proto.produced_regst_desc()) {
      const auto& regst_desc = entry.second;
      if (!regst_desc.regst_desc_type().has_ctrl_regst_desc()) { continue; }
      // The result of insert() is deliberately not inspected, as in the original.
      id2producer_task_id->insert({regst_desc.regst_desc_id(), regst_desc.producer_task_id()});
    }
  }
}

RegstDescProto* GetSoleDataRegstDescProto(TaskProto* task) {
RegstDescProto* ret = nullptr;
for (auto& pair : *task->mutable_produced_regst_desc()) {
Expand Down Expand Up @@ -1094,6 +1116,7 @@ Maybe<void> CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan)
}
LinkMainPlan(plan, main_plan, identity_tick_op_names);
PlanUtil::CleanUselessMemBlockAndCheckValid(plan);
DumpCtrlRegstInfoToPlan(plan);
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
TeePersistentLogStream::Create("merged_plan")->Write(*plan);
PlanUtil::ToDotFile(*plan, "/dot/merged_plan.dot");
Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/job/plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ message CollectiveBoxingPlan {
map<int64, boxing.collective.RequestSet> job_id2request_set = 1;
}

// Plan-level lookup table for ctrl regsts.
message CtrlRegstDescInfo {
  // regst_desc_id of a ctrl regst desc -> producer_task_id of that regst desc.
  // Filled by DumpCtrlRegstInfoToPlan from every produced regst desc whose
  // regst_desc_type has a ctrl_regst_desc.
  map<int64, int64> ctrl_regst_desc_id2producer_task_id = 6;
}

message Plan {
  repeated TaskProto task = 1;
  required MemBlockAndChunkList block_chunk_list = 2;
  required NetTopo net_topo = 3;
  required JobConfs job_confs = 4;
  required CollectiveBoxingPlan collective_boxing_plan= 5;
  // Ctrl regst desc id -> producer task id table; pushed and pulled alongside
  // the other plan parts under ctrl_regst_desc_info_key(plan_name).
  // NOTE(review): declared `required`, so a serialized Plan without this field
  // would presumably be rejected as uninitialized — confirm no older plans are
  // parsed with this schema.
  required CtrlRegstDescInfo ctrl_regst_desc_info = 6;
}
14 changes: 14 additions & 0 deletions oneflow/core/register/register_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ RegstMgr::RegstMgr(const Plan& plan) {
CHECK(regst_desc_id2parallel_ctx_.emplace(regst_desc_id, task.parallel_ctx()).second);
}
}
for (const auto& pair : plan.ctrl_regst_desc_info().ctrl_regst_desc_id2producer_task_id()) {
CHECK(ctrl_regst_desc_id2producer_task_id_.emplace(pair.first, pair.second).second);
}
}

void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
Expand Down Expand Up @@ -222,6 +225,17 @@ bool RegstMgr::HasRegstDescId(int64_t regst_desc_id) const {
return regst_desc_id2rt_regst_desc_.find(regst_desc_id) != regst_desc_id2rt_regst_desc_.end();
}

// Looks up the producer task id recorded for `regst_desc_id` in the table
// loaded from the plan's ctrl_regst_desc_info. The id must be present: a
// missing id fails the CHECK. Use HasProducerTaskId4RegstDescId to probe first.
int64_t RegstMgr::ProducerTaskId4RegstDescId(int64_t regst_desc_id) const {
  const auto it = ctrl_regst_desc_id2producer_task_id_.find(regst_desc_id);
  CHECK(it != ctrl_regst_desc_id2producer_task_id_.end());
  const int64_t producer_task_id = it->second;
  return producer_task_id;
}

// Reports whether `regst_desc_id` has an entry in the ctrl regst desc id ->
// producer task id table (only ids listed in the plan's ctrl_regst_desc_info
// are stored there).
bool RegstMgr::HasProducerTaskId4RegstDescId(int64_t regst_desc_id) const {
  const auto found = ctrl_regst_desc_id2producer_task_id_.find(regst_desc_id);
  const bool is_recorded = (found != ctrl_regst_desc_id2producer_task_id_.end());
  return is_recorded;
}

// Returns the Blob registered for (`lbi`, `parallel_id`). Both lookups use
// .at(), so either key must already be present in its map.
Blob* RegstMgr::Blob4LbiAndParallelId(const LogicalBlobId& lbi, const int64_t parallel_id) {
  auto& parallel_id2blob = lbi2parallel_id2blob_.at(lbi);
  return parallel_id2blob.at(parallel_id);
}
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/register/register_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class RegstMgr final {
void NewRegsts(const RegstDescProto& regst_desc_proto, std::function<void(Regst*)> OneRegstDone);
const RtRegstDesc& RegstDesc4RegstDescId(int64_t regst_desc_id) const;
bool HasRegstDescId(int64_t regst_desc_id) const;
int64_t ProducerTaskId4RegstDescId(int64_t regst_desc_id) const;
bool HasProducerTaskId4RegstDescId(int64_t regst_desc_id) const;
Blob* Blob4LbiAndParallelId(const LogicalBlobId& lbi, const int64_t parallel_id);

private:
Expand All @@ -50,6 +52,7 @@ class RegstMgr final {
HashMap<LogicalBlobId, HashMap<int64_t, Blob*>> lbi2parallel_id2blob_;
HashMap<int64_t, char*> mem_block_id2ptr_;
HashMap<int64_t, ParallelContext> regst_desc_id2parallel_ctx_;
HashMap<int64_t, int64_t> ctrl_regst_desc_id2producer_task_id_;
std::mutex mutex_;
};

Expand Down
50 changes: 50 additions & 0 deletions oneflow/python/test/ops/test_two_node_boxing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
import oneflow.typing as oft
from typing import Tuple
import time


# Boxing test that places ops on two machines ("0:0" and "1:0"); the decorator
# skips it unless the 2n1d condition holds (presumably 2 nodes x 1 device —
# confirm against flow.unittest).
@flow.unittest.skip_unless_2n1d()
class TestTwoNodeBoxing(flow.unittest.TestCase):
    def test_two_node_boardcast(test_case):
        # Start from a clean session with debug mode switched on.
        flow.clear_default_session()
        flow.config.enable_debug_mode(True)
        # NOTE(review): 4 devices are requested although only "0:0" and "1:0"
        # are used below — confirm this is intended.
        flow.config.gpu_device_num(4)

        job_config = flow.FunctionConfig()
        job_config.default_data_type(flow.float)
        job_config.default_logical_view(flow.scope.consistent_view())

        @flow.global_function(function_config=job_config)
        def split_to_broadcast_job(input_blob: oft.Numpy.Placeholder((96, 96))):
            # Source: identity over the input, split along axis 0, on machine 0.
            with flow.scope.placement("gpu", "0:0"):
                split_input = input_blob.with_distribute(flow.distribute.split(0))
                src = flow.identity(split_input)
            # Destination: the same blob broadcast to both machines.
            with flow.scope.placement("gpu", ["0:0", "1:0"]):
                broadcast_src = src.with_distribute(flow.distribute.broadcast())
                dst = flow.identity(broadcast_src)
            return dst

        x = np.random.rand(96, 96).astype(np.float32)
        result = split_to_broadcast_job(x).get()
        # The boxing must hand the input back unchanged.
        outputs_match = np.array_equal(x, result.numpy())
        test_case.assertTrue(outputs_match)


# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 4380494

Please sign in to comment.