Skip to content

Commit

Permalink
Sequantial instruction (#4521)
Browse files Browse the repository at this point in the history
* add sequantial callback instruction

* add a test_case for sequential instruction type

* refactor RunLogicalInstruction/RunPhysicalInstruction

* refactor RunLogicalInstruction/RunPhysicalInstruction

* refactor front sequential instruction

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
lixinqi and oneflow-ci-bot authored Mar 26, 2021
1 parent 82b6846 commit bb4a310
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 48 deletions.
1 change: 1 addition & 0 deletions oneflow/core/vm/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ ObjectMsgPtr<InstructionMsg> InstructionMsg::MakeInferInstrMsg() const {
auto* stream_type_id = infer_instr_msg->mut_instr_type_id()->mut_stream_type_id();
CHECK_EQ(stream_type_id->interpret_type(), InterpretType::kCompute);
stream_type_id->CopyFrom(LookupInferStreamTypeId(*stream_type_id));
*infer_instr_msg->mutable_no_arg_callback() = no_arg_callback();
return infer_instr_msg;
}

Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/vm/instruction.msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ OBJECT_MSG_BEGIN(Instruction);

// links
OBJECT_MSG_DEFINE_LIST_LINK(instruction_link);
// `vm_stat_running_instruction_link` valid from instruction ready to instruction done
OBJECT_MSG_DEFINE_LIST_LINK(vm_stat_running_instruction_link);
OBJECT_MSG_DEFINE_LIST_LINK(pending_instruction_link);
OBJECT_MSG_DEFINE_LIST_LINK(front_seq_infer_instr_link);
OBJECT_MSG_DEFINE_LIST_LINK(front_seq_compute_instr_link);
Expand Down
79 changes: 57 additions & 22 deletions oneflow/core/vm/sequential_instruction_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "oneflow/core/common/util.h"
#include "oneflow/core/object_msg/flat_msg_view.h"
#include "oneflow/core/vm/control_stream_type.h"
#include "oneflow/core/vm/host_stream_type.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/vm/instruction.msg.h"
#include "oneflow/core/vm/instruction_operand.msg.h"
Expand All @@ -30,12 +31,7 @@ class RankFrontSeqCallbackInstructionType : public InstructionType {
RankFrontSeqCallbackInstructionType() = default;
virtual ~RankFrontSeqCallbackInstructionType() override = default;

using stream_type = ControlStreamType;

virtual bool IsFrontSequential() const { return true; }

void Infer(Instruction*) const override { UNIMPLEMENTED(); }
void Compute(Instruction*) const override { UNIMPLEMENTED(); }
bool IsFrontSequential() const override { return true; }

protected:
// clang-format off
Expand All @@ -44,9 +40,9 @@ class RankFrontSeqCallbackInstructionType : public InstructionType {
FLAT_MSG_VIEW_END(RankFrontSeqCallbackInstrOperand);
// clang-format on

void Run(VirtualMachine* vm, InstructionMsg* instr_msg) const {
FlatMsgView<RankFrontSeqCallbackInstrOperand> args(instr_msg->operand());
const auto& callback = instr_msg->no_arg_callback();
void Run(const InstructionMsg& instr_msg) const {
FlatMsgView<RankFrontSeqCallbackInstrOperand> args(instr_msg.operand());
const auto& callback = instr_msg.no_arg_callback();
if (args->process_rank() == GlobalProcessCtx::Rank()) {
CHECK(static_cast<bool>(callback));
(*callback)();
Expand All @@ -56,30 +52,69 @@ class RankFrontSeqCallbackInstructionType : public InstructionType {
}
};

class RankFrontSeqInferCallbackInstructionType final : public RankFrontSeqCallbackInstructionType {
class InferRankFrontSeqCallbackInstructionType final : public RankFrontSeqCallbackInstructionType {
public:
RankFrontSeqInferCallbackInstructionType() = default;
~RankFrontSeqInferCallbackInstructionType() override = default;
InferRankFrontSeqCallbackInstructionType() = default;
~InferRankFrontSeqCallbackInstructionType() override = default;

void Infer(VirtualMachine* vm, InstructionMsg* instr_msg) const override { Run(vm, instr_msg); }
void Compute(VirtualMachine* vm, InstructionMsg* instr_msg) const override { /* do nothing */
using stream_type = HostStreamType;

void Infer(Instruction* instruction) const override { Run(instruction->instr_msg()); }
void Compute(Instruction* instruction) const override { /* do nothing */
}
};
COMMAND(
RegisterInstructionType<RankFrontSeqInferCallbackInstructionType>("RankFrontSeqInferCallback"));
RegisterInstructionType<InferRankFrontSeqCallbackInstructionType>("InferRankFrontSeqCallback"));

class ComputeRankFrontSeqCallbackInstructionType final
: public RankFrontSeqCallbackInstructionType {
public:
ComputeRankFrontSeqCallbackInstructionType() = default;
~ComputeRankFrontSeqCallbackInstructionType() override = default;

using stream_type = HostStreamType;

void Infer(Instruction* instruction) const override { /* do nothing */
}
void Compute(Instruction* instruction) const override { Run(instruction->instr_msg()); }
};
COMMAND(RegisterInstructionType<ComputeRankFrontSeqCallbackInstructionType>(
"ComputeRankFrontSeqCallback"));

class RankFrontSeqComputeCallbackInstructionType final
class CtrlInferRankFrontSeqCallbackInstructionType final
: public RankFrontSeqCallbackInstructionType {
public:
RankFrontSeqComputeCallbackInstructionType() = default;
~RankFrontSeqComputeCallbackInstructionType() override = default;
CtrlInferRankFrontSeqCallbackInstructionType() = default;
~CtrlInferRankFrontSeqCallbackInstructionType() override = default;

using stream_type = ControlStreamType;

void Infer(VirtualMachine*, InstructionMsg* instr_msg) const override { Run(*instr_msg); }
void Compute(VirtualMachine*, InstructionMsg* instr_msg) const override { /* do nothing */
;
}
void Infer(Instruction* instruction) const override { UNIMPLEMENTED(); }
void Compute(Instruction* instruction) const override { UNIMPLEMENTED(); }
};
COMMAND(RegisterInstructionType<CtrlInferRankFrontSeqCallbackInstructionType>(
"CtrlInferRankFrontSeqCallback"));

class CtrlComputeRankFrontSeqCallbackInstructionType final
: public RankFrontSeqCallbackInstructionType {
public:
CtrlComputeRankFrontSeqCallbackInstructionType() = default;
~CtrlComputeRankFrontSeqCallbackInstructionType() override = default;

using stream_type = ControlStreamType;

void Infer(VirtualMachine* vm, InstructionMsg* instr_msg) const override { /* do nothing */
void Infer(VirtualMachine*, InstructionMsg* instr_msg) const override { /* do nothing */
}
void Compute(VirtualMachine* vm, InstructionMsg* instr_msg) const override { Run(vm, instr_msg); }
void Compute(VirtualMachine*, InstructionMsg* instr_msg) const override { Run(*instr_msg); }
void Infer(Instruction* instruction) const override { UNIMPLEMENTED(); }
void Compute(Instruction* instruction) const override { UNIMPLEMENTED(); }
};
COMMAND(RegisterInstructionType<RankFrontSeqComputeCallbackInstructionType>(
"RankFrontSeqComputeCallback"));
COMMAND(RegisterInstructionType<CtrlComputeRankFrontSeqCallbackInstructionType>(
"CtrlComputeRankFrontSeqCallback"));

} // namespace vm
} // namespace oneflow
36 changes: 27 additions & 9 deletions oneflow/core/vm/sequential_instruction_type_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ struct GlobalProcessCtxScope {
}
};

TEST(ControlStreamType, front_seq_compute) {
TEST(SequentialInstruction, front_seq_compute) {
GlobalProcessCtxScope scope;
auto vm_desc = ObjectMsgPtr<VmDesc>::New(TestUtil::NewVmResourceDesc().Get());
TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"NewObject"});
TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(),
{"NewObject", "ComputeRankFrontSeqCallback"});
CachedObjectMsgAllocator allocator(20, 100);
auto vm = ObjectMsgPtr<VirtualMachine>::NewFrom(&allocator, vm_desc.Get());
InstructionMsgList list;
Expand All @@ -67,28 +68,45 @@ TEST(ControlStreamType, front_seq_compute) {
NewInstruction("DeleteObject")->add_mut_operand(logical_object_id, AllMirroredObject()));
ASSERT_TRUE(vm->pending_msg_list().empty());
}
volatile bool finished = false;
volatile bool* finished_ptr = &finished;
int64_t sixsixsix = 0;
{
auto instruction = NewInstruction("RankFrontSeqComputeCallback");
auto instruction = NewInstruction("ComputeRankFrontSeqCallback");
instruction->add_int64_operand(GlobalProcessCtx::Rank());
const auto& Callback = [&]() {
*finished_ptr = true;
LOG(ERROR) << "Callback";
const auto Callback = [&]() { sixsixsix = 666; };
*instruction->mutable_no_arg_callback() = std::make_shared<std::function<void()>>(Callback);
list.EmplaceBack(std::move(instruction));
}
bool infer_finished = false;
{
auto instruction = NewInstruction("CtrlInferRankFrontSeqCallback");
instruction->add_int64_operand(GlobalProcessCtx::Rank());
const auto Callback = [&]() { infer_finished = true; };
*instruction->mutable_no_arg_callback() = std::make_shared<std::function<void()>>(Callback);
list.EmplaceBack(std::move(instruction));
}
bool compute_finished = false;
bool is_666 = false;
{
auto instruction = NewInstruction("CtrlComputeRankFrontSeqCallback");
instruction->add_int64_operand(GlobalProcessCtx::Rank());
const auto Callback = [&]() {
is_666 = sixsixsix == 666;
compute_finished = true;
};
*instruction->mutable_no_arg_callback() = std::make_shared<std::function<void()>>(Callback);
list.EmplaceBack(std::move(instruction));
}
vm->Receive(&list);
BlockingCounter bc(1);
std::thread t([&]() {
while (!*finished_ptr) {
while (!(infer_finished && compute_finished)) {
vm->Schedule();
OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); }
}
bc.Decrease();
});
bc.WaitUntilCntEqualZero();
ASSERT_TRUE(is_666);
ASSERT_TRUE(vm->Empty());
t.join();
}
Expand Down
38 changes: 22 additions & 16 deletions oneflow/core/vm/virtual_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void VirtualMachine::TryReleaseFinishedInstructions(
auto* running_instruction_list = stream->mut_running_instruction_list();
auto* front_seq_infer_list = mutable_front_seq_infer_instr_list();
auto* front_seq_compute_list = mutable_front_seq_compute_instr_list();
auto* vm_stat_running_list = mut_vm_stat_running_instruction_list();
while (true) {
auto* instruction_ptr = running_instruction_list->Begin();
if (instruction_ptr == nullptr || !instruction_ptr->Done()) { break; }
Expand All @@ -76,6 +77,7 @@ void VirtualMachine::TryReleaseFinishedInstructions(
} else {
UNIMPLEMENTED();
}
vm_stat_running_list->Erase(instruction_ptr);
stream->DeleteInstruction(running_instruction_list->Erase(instruction_ptr));
}
}
Expand Down Expand Up @@ -365,7 +367,9 @@ void VirtualMachine::DispatchAndPrescheduleInstructions(
ReadyInstructionList* ready_instruction_list) {
PrescheduledInstructionList prescheduled;
auto* active_stream_list = mut_active_stream_list();
auto* vm_stat_running_list = mut_vm_stat_running_instruction_list();
OBJECT_MSG_LIST_FOR_EACH_PTR(ready_instruction_list, instruction) {
vm_stat_running_list->PushBack(instruction);
auto* stream = instruction->mut_stream();
ready_instruction_list->MoveToDstBack(instruction, stream->mut_running_instruction_list());
if (stream->is_active_stream_link_empty()) { active_stream_list->PushBack(stream); }
Expand Down Expand Up @@ -438,25 +442,27 @@ void VirtualMachine::Receive(ObjectMsgPtr<InstructionMsg>&& compute_instr_msg) {
Receive(&instr_msg_list);
}

namespace {

template<typename ContainerT>
void TryRunFrontSeqInstructionsOnList(VirtualMachine* vm, ContainerT* front_seq_list) {
OBJECT_MSG_LIST_FOR_EACH(front_seq_list, instruction) {
const auto& instruction_type = instruction->instr_msg().instr_type_id().instruction_type();
if (!instruction_type.IsFrontSequential()) { break; }
front_seq_list->Erase(instruction.Mutable());
const auto& stream_type = instruction->stream().stream_type();
CHECK(stream_type.SharingVirtualMachineThread());
stream_type.Run(vm, instruction.Mutable());
void VirtualMachine::TryRunFrontSeqInstruction(
ContainerT* front_seq_list, /*out*/ ReadyInstructionList* ready_instruction_list) {
auto* instruction = front_seq_list->Begin();
if (instruction == nullptr) { return; }
const auto& instr_type_id = instruction->instr_msg().instr_type_id();
const auto& instruction_type = instr_type_id.instruction_type();
if (!instruction_type.IsFrontSequential()) { return; }
if (!instruction->is_vm_stat_running_instruction_link_empty()) { return; }
const StreamType& stream_type = instr_type_id.stream_type_id().stream_type();
if (stream_type.SharingVirtualMachineThread()) {
stream_type.Run(this, instruction);
front_seq_list->Erase(instruction);
} else {
ready_instruction_list->EmplaceBack(std::move(instruction));
}
}

} // namespace

void VirtualMachine::TryRunFrontSeqInstructions() {
TryRunFrontSeqInstructionsOnList(this, mutable_front_seq_infer_instr_list());
TryRunFrontSeqInstructionsOnList(this, mutable_front_seq_compute_instr_list());
void VirtualMachine::TryRunFrontSeqInstruction(ReadyInstructionList* ready_instruction_list) {
TryRunFrontSeqInstruction(mutable_front_seq_infer_instr_list(), ready_instruction_list);
TryRunFrontSeqInstruction(mutable_front_seq_compute_instr_list(), ready_instruction_list);
}

void VirtualMachine::Schedule() {
Expand All @@ -466,7 +472,7 @@ void VirtualMachine::Schedule() {
TryReleaseFinishedInstructions(stream, /*out*/ ready_instruction_list);
if (stream->running_instruction_list().empty()) { active_stream_list->Erase(stream); }
}
TryRunFrontSeqInstructions();
TryRunFrontSeqInstruction(/*out*/ ready_instruction_list);
auto* waiting_instruction_list = mut_waiting_instruction_list();
if (pending_msg_list().size() > 0) {
TmpPendingInstrMsgList tmp_pending_msg_list;
Expand Down
7 changes: 6 additions & 1 deletion oneflow/core/vm/virtual_machine.msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ OBJECT_MSG_BEGIN(VirtualMachine);
OBJECT_MSG_DEFINE_MUTEXED_LIST_HEAD(InstructionMsg, instr_msg_link, pending_msg_list);
OBJECT_MSG_DEFINE_LIST_HEAD(Instruction, instruction_link, waiting_instruction_list);
OBJECT_MSG_DEFINE_LIST_HEAD(Instruction, instruction_link, ready_instruction_list);
OBJECT_MSG_DEFINE_LIST_HEAD(Instruction, vm_stat_running_instruction_link,
vm_stat_running_instruction_list);
OBJECT_MSG_DEFINE_LIST_HEAD(Instruction, front_seq_infer_instr_link, front_seq_infer_instr_list);
OBJECT_MSG_DEFINE_LIST_HEAD(Instruction, front_seq_compute_instr_link, front_seq_compute_instr_list);
OBJECT_MSG_DEFINE_LIST_HEAD(Stream, active_stream_link, active_stream_list);
Expand All @@ -78,7 +80,10 @@ OBJECT_MSG_BEGIN(VirtualMachine);
using Id2LogicalObject = VirtualMachine::id2logical_object_ObjectMsgSkipListType;
using ActiveStreamList = VirtualMachine::active_stream_list_ObjectMsgListType;

void TryRunFrontSeqInstructions();
template<typename ContainerT>
void TryRunFrontSeqInstruction(ContainerT* front_seq_list,
/*out*/ ReadyInstructionList* ready_instruction_list);
void TryRunFrontSeqInstruction(/*out*/ ReadyInstructionList* ready_instruction_list);
void ReleaseInstruction(Instruction* instruction,
/*out*/ ReadyInstructionList* ready_instruction_list);
void TryReleaseFinishedInstructions(
Expand Down

0 comments on commit bb4a310

Please sign in to comment.