From eb4175bd3ddc99a5d902eed30476127a0abdc1dc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 30 Mar 2024 16:30:51 -0400 Subject: [PATCH] [VM] Recycle VMFrame (#16822) This PR recycles the VMFrame in VM which can help a bit when function involves large frames. --- src/runtime/relax_vm/vm.cc | 35 +++++++++++++++++++++--- src/support/ffi_testing.cc | 54 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index d7f943d5f40f..618e68c4fd1f 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -177,6 +177,20 @@ struct VMFrame { VMFrame(Index pc, Index register_file_size) : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} + + void Clear() { + this->caller_return_register = 0; + this->call_arg_values.clear(); + this->call_arg_tcodes.clear(); + for (RegType& reg : register_file) { + reg = nullptr; + } + } + + void ResetForRecycle(Index pc, Index register_file_size) { + this->return_pc = pc; + this->register_file.resize(register_file_size); + } }; class VirtualMachineImpl : public VirtualMachine { @@ -322,6 +336,8 @@ class VirtualMachineImpl : public VirtualMachine { ~FrameGuard() { ICHECK_GT(vm->frames_.size(), 0); vm->pc_ = vm->frames_.back()->return_pc; + vm->frames_.back()->Clear(); + vm->frame_free_list_.emplace_back(std::move(vm->frames_.back())); vm->frames_.pop_back(); } }; @@ -335,7 +351,15 @@ class VirtualMachineImpl : public VirtualMachine { * \return A RAII wrapper that pops the frame when going out of scope. */ FrameGuard PushFrame(Index ret_pc, const VMFuncInfo& vm_func) { - return FrameGuard(this, std::make_unique(ret_pc, vm_func.register_file_size)); + std::unique_ptr new_frame; + if (!frame_free_list_.empty()) { + new_frame = std::move(frame_free_list_.back()); + frame_free_list_.pop_back(); + new_frame->ResetForRecycle(ret_pc, vm_func.register_file_size); + } else { + new_frame = std::make_unique(ret_pc, vm_func.register_file_size); + } + return FrameGuard(this, std::move(new_frame)); } /*! * \brief Write to a VM register. @@ -343,7 +367,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param reg The register to write to. * \param obj The object to write to. */ - void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) { + TVM_ALWAYS_INLINE void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) { ICHECK_LT(reg, frame->register_file.size()); frame->register_file[reg] = obj; } @@ -353,7 +377,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param reg The register to read from. * \return The value of the register. */ - RegType ReadRegister(VMFrame* frame, RegName reg) { + TVM_ALWAYS_INLINE RegType ReadRegister(VMFrame* frame, RegName reg) { if (reg < Instruction::kBeginSpecialReg) { return frame->register_file[reg]; } @@ -425,6 +449,11 @@ class VirtualMachineImpl : public VirtualMachine { * \note: Use unique ptr to avoid re-allocation and copy when frames_ get resized. */ std::vector> frames_; + /*! + * \brief A free list of frame + */ + std::vector> frame_free_list_; + /*! \brief The virtual machine PC. */ Index pc_{0}; /*! \brief The special return register. */ diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 75b5a2527f76..aec57a1eb20d 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,4 +189,58 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); +/** + * Simple event logger that can be used for testing purposes + */ +class TestingEventLogger { + public: + struct Entry { + String event; + double time_us; + }; + + TestingEventLogger() { + entries_.reserve(1024); + start_ = std::chrono::high_resolution_clock::now(); + } + + void Record(String event) { + auto tend = std::chrono::high_resolution_clock::now(); + double time_us = static_cast((tend - start_).count()) / 1e3; + entries_.emplace_back(Entry{event, time_us}); + } + + void Reset() { entries_.clear(); } + + void Dump() const { + for (const Entry& e : entries_) { + LOG(INFO) << e.event << "\t" << e.time_us << " us"; + } + } + + static TestingEventLogger* ThreadLocal() { + thread_local TestingEventLogger inst; + return &inst; + } + + private: + std::chrono::high_resolution_clock::time_point start_; + std::vector entries_; +}; + +TVM_REGISTER_GLOBAL("testing.record_event").set_body([](TVMArgs args, TVMRetValue* rv) { + if (args.size() != 0 && args[0].type_code() == kTVMStr) { + TestingEventLogger::ThreadLocal()->Record(args[0]); + } else { + TestingEventLogger::ThreadLocal()->Record("X"); + } +}); + +TVM_REGISTER_GLOBAL("testing.reset_events").set_body([](TVMArgs args, TVMRetValue* rv) { + TestingEventLogger::ThreadLocal()->Reset(); +}); + +TVM_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() { + TestingEventLogger::ThreadLocal()->Dump(); +}); } // namespace tvm