Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace piece_id with comm_net_sequence_number #5731

Merged
merged 5 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions oneflow/core/actor/acc_compute_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class AccCompActor final : public CompActor {
std::function<void(DeviceCtx*, void* dst, const void* src, size_t)> cpy_func_;
int32_t acc_cnt_;
int32_t max_acc_cnt_;
int64_t next_piece_id_;
};

void AccCompActor::VirtualCompActorInit(const TaskProto& proto) {
Expand Down Expand Up @@ -62,7 +61,6 @@ void AccCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt) {
OF_SET_MSG_HANDLER(&AccCompActor::HandlerNormal);
acc_cnt_ = 0;
max_acc_cnt_ = max_acc_cnt;
next_piece_id_ = 0;
}

int64_t AccCompActor::ActNumForEachOutput(int64_t regst_desc_id) const {
Expand Down Expand Up @@ -97,12 +95,8 @@ void AccCompActor::Act() {

void AccCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
if (acc_cnt_ == max_acc_cnt_) {
HandleProducedNaiveDataRegstToConsumer([&](Regst* regst) {
regst->set_piece_id(next_piece_id_);
return true;
});
HandleProducedNaiveDataRegstToConsumer();
acc_cnt_ = 0;
next_piece_id_ += 1;
}
}

Expand Down
84 changes: 8 additions & 76 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,35 +235,6 @@ void Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) {
produced_regst2reading_cnt_.at(regst) += val;
}

// Looks up a piece id among the front regsts of naive_consumed_rs_ and returns it.
// `pid` starts at the sentinel -2 and is overwritten by the callback below; the final
// CHECK_GE fails if it is still negative, which includes the case where no regst
// qualified.
// NOTE(review): assumes ForChosenFrontRegst evaluates the predicate per regst desc id
// before invoking the callback, so that only the first match is recorded — confirm.
int64_t Actor::GetPieceId4NaiveCurReadableDataRegst() const {
int64_t init_val = -2;
int64_t pid = init_val;
// Predicate is true only while no piece id has been recorded yet.
auto FirstFoundOnly = [&pid, init_val](int64_t) { return pid == init_val; };
naive_consumed_rs_.ForChosenFrontRegst(
FirstFoundOnly, [&pid](int64_t regst_desc_id, Regst* regst) {
// Ignore regst descs for which RegstMgr reports a producer task id.
if (Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }
// Only regsts whose desc type is a data regst desc supply a piece id.
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {
pid = regst->piece_id();
}
});
CHECK_GE(pid, 0);
return pid;
}

// Looks up a piece id among the front regsts of naive_consumed_rs_, and falls back to
// inplace_consumed_rs_ when the naive slot left `pid` at its sentinel value (-2).
// The final CHECK_GE fails if `pid` is still negative after both passes.
// NOTE(review): assumes ForChosenFrontRegst evaluates the predicate per regst desc id
// before invoking the callback, so that only the first match is recorded — confirm.
int64_t Actor::GetPieceId4NaiveOrInplaceCurReadableDataRegst() const {
int64_t init_val = -2;
int64_t pid = init_val;
// Predicate is true only while no piece id has been recorded yet.
auto FirstFoundOnly = [&pid, init_val](int64_t) { return pid == init_val; };
// Records the piece id of a data regst, skipping regst descs for which RegstMgr
// reports a producer task id.
auto Select = [&pid](int64_t regst_desc_id, Regst* regst) {
if (Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { pid = regst->piece_id(); }
};
naive_consumed_rs_.ForChosenFrontRegst(FirstFoundOnly, Select);
// Nothing recorded from the naive slot: try the inplace slot with the same callbacks.
if (pid == init_val) { inplace_consumed_rs_.ForChosenFrontRegst(FirstFoundOnly, Select); }
CHECK_GE(pid, 0);
return pid;
}

void Actor::InitDeviceCtx(const ThreadCtx& thread_ctx) {
DeviceCtx* dev_ctx = NewObj<int, DeviceCtx, const ThreadCtx&>(GetDeviceType(), thread_ctx);
device_ctx_.reset(dev_ctx);
Expand Down Expand Up @@ -475,7 +446,7 @@ void Actor::AsyncSendNaiveConsumedRegstMsgToProducer() {
}

void Actor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() {
HandleConsumedNaiveDataRegstToProducer([](Regst* regst) { return true; });
HandleConsumedNaiveDataRegstToProducer();
}

void Actor::AsyncSendConsumedCtrlRegstMsgToProducer() {
Expand Down Expand Up @@ -505,20 +476,19 @@ void Actor::AsyncSendProducedCtrlRegstMsgToConsumer() {
tmp_regst_desc_id_vec_.clear();
naive_produced_rs_.ForChosenFrontRegst(IsChosenRegstDescId, [&](Regst* regst) {
CHECK(regst->regst_desc()->regst_desc_type().has_ctrl_regst_desc());
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, [](int64_t) { return true; });
int64_t real_consumer_cnt = HandleRegstToConsumer(regst);
if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.push_back(regst->regst_desc_id()); }
});
naive_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);
}

int64_t Actor::HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor) {
int64_t Actor::HandleRegstToConsumer(Regst* regst) {
auto regst_reading_cnt_it = produced_regst2reading_cnt_.find(regst);
CHECK_EQ(regst_reading_cnt_it->second, 0);
regst->set_act_id(act_id_);

int64_t real_consumer_cnt = 0;
for (int64_t consumer : regst->consumers_actor_id()) {
if (!IsAllowedActor(consumer)) { continue; }
EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToConsumer(actor_id_, consumer, regst));
real_consumer_cnt += 1;
}
Expand Down Expand Up @@ -568,70 +538,32 @@ void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx) {
});
}

void Actor::HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor) {
void Actor::HandleProducedNaiveDataRegstToConsumer() {
tmp_regst_desc_id_vec_.clear();
naive_produced_rs_.ForEachFrontRegst([&](Regst* regst) {
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {
if (RegstPreProcess(regst) == false) { return; }
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, IsAllowedActor);
int64_t real_consumer_cnt = HandleRegstToConsumer(regst);
if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.push_back(regst->regst_desc_id()); }
}
});
naive_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);
}

// Forwarding overload: passes RegstPreProcess on to the two-argument overload together
// with an IsAllowedActor callback that always returns true.
void Actor::HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess) {
HandleProducedNaiveDataRegstToConsumer(RegstPreProcess, [](int64_t) { return true; });
}

// Forwarding overload: passes IsAllowedActor on to the two-argument overload together
// with a RegstPreProcess callback that always returns true.
void Actor::HandleProducedNaiveDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor) {
HandleProducedNaiveDataRegstToConsumer([](Regst*) { return true; }, IsAllowedActor);
}

// Argument-free overload: delegates to the RegstPreProcess overload with a callback
// that always returns true.
void Actor::HandleProducedNaiveDataRegstToConsumer() {
HandleProducedNaiveDataRegstToConsumer([](Regst*) { return true; });
}

void Actor::HandleProducedInplaceDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor) {
void Actor::HandleProducedInplaceDataRegstToConsumer() {
tmp_regst_desc_id_vec_.clear();
inplace_produced_rs_.ForEachFrontRegst([&](Regst* regst) {
CHECK(regst->regst_desc()->regst_desc_type().has_data_regst_desc());
if (RegstPreProcess(regst) == false) { return; }
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, IsAllowedActor);
int64_t real_consumer_cnt = HandleRegstToConsumer(regst);
if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.push_back(regst->regst_desc_id()); }
});
inplace_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);
}

// Forwarding overload: passes RegstPreProcess on to the two-argument overload together
// with an IsAllowedActor callback that always returns true.
void Actor::HandleProducedInplaceDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess) {
HandleProducedInplaceDataRegstToConsumer(RegstPreProcess, [](int64_t) { return true; });
}

// Forwarding overload: passes IsAllowedActor on to the two-argument overload together
// with a RegstPreProcess callback that always returns true.
void Actor::HandleProducedInplaceDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor) {
HandleProducedInplaceDataRegstToConsumer([](Regst*) { return true; }, IsAllowedActor);
}

// Argument-free overload: delegates to the RegstPreProcess overload with a callback
// that always returns true.
void Actor::HandleProducedInplaceDataRegstToConsumer() {
HandleProducedInplaceDataRegstToConsumer([](Regst*) { return true; });
}

// Single-argument overload: delegates to the two-argument overload with an
// IsAllowedActor callback that always returns true.
void Actor::AsyncSendRegstMsgToConsumer(Regst* regst) {
AsyncSendRegstMsgToConsumer(regst, [](int64_t) { return true; });
}

// Hands `regst` to HandleRegstToConsumer with the given IsAllowedActor callback.
// When the returned consumer count is positive, asks naive_produced_rs_ to pop the
// front regst for this regst's desc id (TryPopFrontRegst); otherwise nothing is popped.
void Actor::AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor) {
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, IsAllowedActor);
if (real_consumer_cnt > 0) { naive_produced_rs_.TryPopFrontRegst(regst->regst_desc_id()); }
}

void Actor::HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst) {
void Actor::HandleConsumedNaiveDataRegstToProducer() {
tmp_regst_desc_id_vec_.clear();
naive_consumed_rs_.ForEachFrontRegst([&](int64_t regst_desc_id, Regst* regst) {
if (IsConsumedCtrlRegstDescId(regst_desc_id)) { return; }
if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {
if (IsAllowedRegst(regst) == false) { return; }
// must access regst before sending it to producer
tmp_regst_desc_id_vec_.push_back(regst->regst_desc_id());
EnqueueAsyncMsg(
Expand Down
16 changes: 2 additions & 14 deletions oneflow/core/actor/actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ class Actor {
int64_t ReadingCnt4ProducedRegst(Regst* regst) const;
void IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val);
void IncreaseTotalReadingCnt(int64_t val) { total_reading_cnt_ += val; }
int64_t GetPieceId4NaiveCurReadableDataRegst() const;
int64_t GetPieceId4NaiveOrInplaceCurReadableDataRegst() const;

// Msg Handler
void set_msg_handler(MsgHandler val) { msg_handler_ = val; }
Expand All @@ -107,20 +105,10 @@ class Actor {

// Util For Derived Actor to Send Msg
void EnqueueAsyncMsg(const ActorMsg&);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess);
void HandleProducedNaiveDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedNaiveDataRegstToConsumer();
void HandleProducedInplaceDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedInplaceDataRegstToConsumer(std::function<bool(Regst*)> RegstPreProcess);
void HandleProducedInplaceDataRegstToConsumer(std::function<bool(int64_t)> IsAllowedActor);
void HandleProducedInplaceDataRegstToConsumer();
void AsyncSendRegstMsgToConsumer(Regst* regst);
void AsyncSendRegstMsgToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);

void HandleConsumedNaiveDataRegstToProducer(std::function<bool(Regst*)> IsAllowedRegst);
void HandleConsumedNaiveDataRegstToProducer();
void AsyncSendRegstMsgToProducer(Regst*);
void AsyncSendRegstMsgToProducer(Regst*, int64_t producer);
void AsyncSendEORDMsgForAllProducedRegstDesc();
Expand All @@ -145,7 +133,7 @@ class Actor {
}
Regst* GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) const;
void ForEachProducedRegst(const std::function<void(Regst*)>&) const;
int64_t HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> IsAllowedActor);
int64_t HandleRegstToConsumer(Regst* regst);

protected:
bool IsConsumedCtrlRegstDescId(int64_t regst_desc_id) {
Expand Down
9 changes: 7 additions & 2 deletions oneflow/core/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,14 @@ int64_t ActorMsg::regst_desc_id() const {
}
}

int64_t ActorMsg::piece_id() const {
int64_t ActorMsg::comm_net_sequence_number() const {
CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
return regst_wrapper_.regst_status.piece_id;
return regst_wrapper_.comm_net_sequence_number;
}

// Stores `sequence_number` in regst_wrapper_.comm_net_sequence_number.
// Only valid on regst messages: the CHECK_EQ fails when msg_type_ is not
// ActorMsgType::kRegstMsg.
void ActorMsg::set_comm_net_sequence_number(int64_t sequence_number) {
CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
regst_wrapper_.comm_net_sequence_number = sequence_number;
}

int64_t ActorMsg::act_id() const {
Expand Down
4 changes: 3 additions & 1 deletion oneflow/core/actor/actor_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class ActorMsg final {
ActorCmd actor_cmd() const;
Regst* regst() const;
int64_t regst_desc_id() const;
int64_t piece_id() const;
int64_t act_id() const;
void* comm_net_token() const;
void set_comm_net_token(void* token);
Expand All @@ -62,6 +61,8 @@ class ActorMsg final {
uint8_t user_data_size() const;
const void* user_data() const;
bool IsDataRegstMsgToConsumer() const;
int64_t comm_net_sequence_number() const;
void set_comm_net_sequence_number(int64_t sequence_number);

// Serialize
template<typename StreamT>
Expand All @@ -77,6 +78,7 @@ class ActorMsg final {
struct RegstWrapper {
Regst* regst;
void* comm_net_token;
int64_t comm_net_sequence_number;
RegstStatus regst_status;
bool has_sole_empty_blob;
bool is_data_regst_to_consumer;
Expand Down
18 changes: 17 additions & 1 deletion oneflow/core/actor/actor_message_bus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,23 @@ void ActorMsgBus::SendMsg(const ActorMsg& msg) {
if (dst_machine_id == GlobalProcessCtx::Rank()) {
SendMsgWithoutCommNet(msg);
} else {
Global<CommNet>::Get()->SendActorMsg(dst_machine_id, msg);
if (msg.IsDataRegstMsgToConsumer()) {
int64_t comm_net_sequence;
{
std::unique_lock<std::mutex> lock(
regst_desc_id_dst_actor_id2comm_net_sequence_number_mutex_);
int64_t& comm_net_sequence_ref =
regst_desc_id_dst_actor_id2comm_net_sequence_number_[std::make_pair(
msg.regst_desc_id(), msg.dst_actor_id())];
comm_net_sequence = comm_net_sequence_ref;
comm_net_sequence_ref += 1;
}
ActorMsg new_msg = msg;
new_msg.set_comm_net_sequence_number(comm_net_sequence);
Global<CommNet>::Get()->SendActorMsg(dst_machine_id, new_msg);
} else {
Global<CommNet>::Get()->SendActorMsg(dst_machine_id, msg);
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/actor/actor_message_bus.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class ActorMsgBus final {
private:
friend class Global<ActorMsgBus>;
ActorMsgBus() = default;
HashMap<std::pair<int64_t, int64_t>, int64_t>
regst_desc_id_dst_actor_id2comm_net_sequence_number_;
std::mutex regst_desc_id_dst_actor_id2comm_net_sequence_number_mutex_;
};

} // namespace oneflow
Expand Down
9 changes: 1 addition & 8 deletions oneflow/core/actor/boxing_zeros_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class BoxingZerosActor : public NaiveActor {

// Runs NaiveActor::VirtualActorInit for this task, then resets this actor's own
// members: piece_id_ to 0 and out_inited_ to false.
void VirtualActorInit(const TaskProto& task_proto) override {
NaiveActor::VirtualActorInit(task_proto);
piece_id_ = 0;
out_inited_ = false;
}

Expand All @@ -38,15 +37,9 @@ class BoxingZerosActor : public NaiveActor {
}

void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override {
int64_t piece_id = piece_id_;
HandleProducedNaiveDataRegstToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id);
return true;
});
piece_id_ += 1;
HandleProducedNaiveDataRegstToConsumer();
}

int64_t piece_id_;
bool out_inited_;
};

Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/actor/case_compute_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ void CaseCompActor::AsyncSendCustomizedProducedRegstMsgToConsumer() {
if (case_status_.cmd != kCaseCmdHandleOutput) { return; }
const int64_t regst_desc_id = out_bn_id2regst_desc_id_.at(case_status_.cur_selected_id);
Regst* const regst = regst_desc_id2produced_rs_.at(regst_desc_id).Front(regst_desc_id);
CHECK_GT(HandleRegstToConsumer(regst, [](int64_t) { return true; }), 0);
CHECK_GT(HandleRegstToConsumer(regst), 0);
regst_desc_id2produced_rs_.at(regst_desc_id).PopFrontRegsts({regst_desc_id});
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/actor/case_compute_actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class CaseCompActor final : public CompActor {
bool IsInputOrOutputReady() const;
int64_t GetCurSelectId() const;

HashMap<int64_t, int64_t> regst_desc_id2piece_id_;
HashMap<int64_t, int64_t> out_bn_id2regst_desc_id_;
int64_t consumed_regst_desc_id_;
RegstSlot consumed_rs_;
Expand Down
9 changes: 1 addition & 8 deletions oneflow/core/actor/collective_boxing_generic_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,16 @@ class CollectiveBoxingGenericActor : public Actor {
// Calls AsyncLaunchKernel with the context returned by GenDefaultKernelCtx().
void Act() override { AsyncLaunchKernel(GenDefaultKernelCtx()); }

// Resets piece_id_ to 0 and registers HandlerNormal through OF_SET_MSG_HANDLER.
// The TaskProto argument is unnamed and not used.
void VirtualActorInit(const TaskProto&) override {
piece_id_ = 0;
OF_SET_MSG_HANDLER(&CollectiveBoxingGenericActor::HandlerNormal);
}

void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override {
HandleProducedNaiveDataRegstToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id_);
return true;
});
piece_id_ += 1;
HandleProducedNaiveDataRegstToConsumer();
}

// Resets the object returned by mut_device_ctx() to a newly allocated
// CollectiveBoxingDeviceCtx. `thread_ctx` is not read in this override.
void InitDeviceCtx(const ThreadCtx& thread_ctx) override {
mut_device_ctx().reset(new CollectiveBoxingDeviceCtx());
}

int64_t piece_id_ = 0;
};

REGISTER_ACTOR(TaskType::kCollectiveBoxingGeneric, CollectiveBoxingGenericActor);
Expand Down
Loading