Skip to content

Commit

Permalink
[VM] Recycle VMFrame (#16822)
Browse files Browse the repository at this point in the history
This PR recycles the VMFrame in VM which can help a bit
when function involves large frames.
  • Loading branch information
tqchen authored Mar 30, 2024
1 parent 5053a4f commit eb4175b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
35 changes: 32 additions & 3 deletions src/runtime/relax_vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,20 @@ struct VMFrame {

VMFrame(Index pc, Index register_file_size)
: return_pc(pc), register_file(register_file_size), caller_return_register(0) {}

void Clear() {
this->caller_return_register = 0;
this->call_arg_values.clear();
this->call_arg_tcodes.clear();
for (RegType& reg : register_file) {
reg = nullptr;
}
}

void ResetForRecycle(Index pc, Index register_file_size) {
this->return_pc = pc;
this->register_file.resize(register_file_size);
}
};

class VirtualMachineImpl : public VirtualMachine {
Expand Down Expand Up @@ -322,6 +336,8 @@ class VirtualMachineImpl : public VirtualMachine {
~FrameGuard() {
ICHECK_GT(vm->frames_.size(), 0);
vm->pc_ = vm->frames_.back()->return_pc;
vm->frames_.back()->Clear();
vm->frame_free_list_.emplace_back(std::move(vm->frames_.back()));
vm->frames_.pop_back();
}
};
Expand All @@ -335,15 +351,23 @@ class VirtualMachineImpl : public VirtualMachine {
* \return A RAII wrapper that pops the frame when going out of scope.
*/
FrameGuard PushFrame(Index ret_pc, const VMFuncInfo& vm_func) {
return FrameGuard(this, std::make_unique<VMFrame>(ret_pc, vm_func.register_file_size));
std::unique_ptr<VMFrame> new_frame;
if (!frame_free_list_.empty()) {
new_frame = std::move(frame_free_list_.back());
frame_free_list_.pop_back();
new_frame->ResetForRecycle(ret_pc, vm_func.register_file_size);
} else {
new_frame = std::make_unique<VMFrame>(ret_pc, vm_func.register_file_size);
}
return FrameGuard(this, std::move(new_frame));
}
/*!
* \brief Write to a VM register.
* \param frame current vm frame.
* \param reg The register to write to.
* \param obj The object to write to.
*/
void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) {
TVM_ALWAYS_INLINE void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) {
ICHECK_LT(reg, frame->register_file.size());
frame->register_file[reg] = obj;
}
Expand All @@ -353,7 +377,7 @@ class VirtualMachineImpl : public VirtualMachine {
* \param reg The register to read from.
* \return The value of the register.
*/
RegType ReadRegister(VMFrame* frame, RegName reg) {
TVM_ALWAYS_INLINE RegType ReadRegister(VMFrame* frame, RegName reg) {
if (reg < Instruction::kBeginSpecialReg) {
return frame->register_file[reg];
}
Expand Down Expand Up @@ -425,6 +449,11 @@ class VirtualMachineImpl : public VirtualMachine {
* \note: Use unique ptr to avoid re-allocation and copy when frames_ get resized.
*/
std::vector<std::unique_ptr<VMFrame>> frames_;
/*!
* \brief A free list of frame
*/
std::vector<std::unique_ptr<VMFrame>> frame_free_list_;

/*! \brief The virtual machine PC. */
Index pc_{0};
/*! \brief The special return register. */
Expand Down
54 changes: 54 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,58 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian
TVM_REGISTER_GLOBAL("testing.AcceptsVariant")
.set_body_typed([](Variant<String, Integer> arg) -> String { return arg->GetTypeKey(); });

/**
* Simple event logger that can be used for testing purposes
*/
class TestingEventLogger {
public:
struct Entry {
String event;
double time_us;
};

TestingEventLogger() {
entries_.reserve(1024);
start_ = std::chrono::high_resolution_clock::now();
}

void Record(String event) {
auto tend = std::chrono::high_resolution_clock::now();
double time_us = static_cast<double>((tend - start_).count()) / 1e3;
entries_.emplace_back(Entry{event, time_us});
}

void Reset() { entries_.clear(); }

void Dump() const {
for (const Entry& e : entries_) {
LOG(INFO) << e.event << "\t" << e.time_us << " us";
}
}

static TestingEventLogger* ThreadLocal() {
thread_local TestingEventLogger inst;
return &inst;
}

private:
std::chrono::high_resolution_clock::time_point start_;
std::vector<Entry> entries_;
};

TVM_REGISTER_GLOBAL("testing.record_event").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args.size() != 0 && args[0].type_code() == kTVMStr) {
TestingEventLogger::ThreadLocal()->Record(args[0]);
} else {
TestingEventLogger::ThreadLocal()->Record("X");
}
});

TVM_REGISTER_GLOBAL("testing.reset_events").set_body([](TVMArgs args, TVMRetValue* rv) {
TestingEventLogger::ThreadLocal()->Reset();
});

TVM_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() {
TestingEventLogger::ThreadLocal()->Dump();
});
} // namespace tvm

0 comments on commit eb4175b

Please sign in to comment.