@@ -77,7 +77,7 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
7777 for (size_t i = 0 ; i < adt.size (); i++) {
7878 ret.push_back (CopyTo (adt[i], ctx));
7979 }
80- return ADT (0 , ret.begin (), ret.end ());
80+ return ADT (adt-> tag , ret.begin (), ret.end ());
8181 }
8282}
8383
@@ -161,11 +161,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
161161 << " The number of provided parameters doesn't match the number of assigned devices" ;
162162 std::vector<ObjectRef> func_args (param_names.size ());
163163 for (int i = 1 ; i < args.size (); ++i) {
164- TVMContext ctx;
165- int device_type = vm_func.params_device_type [i - 1 ];
166- ctx.device_type = DLDeviceType (device_type);
167- // TODO(zhiics) Use virtual device id
168- ctx.device_id = 0 ;
164+ Index device_type = vm_func.params_device_type [i - 1 ];
165+ DLContext ctx = GetContext (device_type);
169166 ObjectRef obj = CopyTo (args[i], ctx);
170167 func_args[i - 1 ] = obj;
171168 }
@@ -178,15 +175,13 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
178175 }
179176}
180177
181- TVMContext VirtualMachine::GetContext (Index device_type) const {
182- CHECK (!ctxs_.empty ()) << " Context has not been initialized yet." ;
183-
184- const auto & cit = std::find_if (ctxs_.begin (), ctxs_.end (), [&device_type](const TVMContext& c) {
185- return device_type == static_cast <Index>(c.device_type );
186- });
178+ inline TVMContext VirtualMachine::GetContext (Index device_type) const {
179+ CHECK_GE (ctxs_.size (), device_type) << " ctxs_ list doesn't contain device:" << device_type;
187180
188- CHECK (cit != ctxs_.end ()) << " device type " << device_type << " not found int the context list." ;
189- return *cit;
181+ auto ctx = ctxs_[device_type];
182+ CHECK_EQ (static_cast <Index>(ctx.device_type ), device_type)
183+ << " device type " << device_type << " has not been initialized int the context list." ;
184+ return ctx;
190185}
191186
192187void VirtualMachine::PushFrame (Index arg_count, Index ret_pc, const VMFunction& vm_func) {
@@ -294,7 +289,14 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
294289void VirtualMachine::Init (const std::vector<TVMContext>& ctxs,
295290 const std::vector<AllocatorType>& alloc_types) {
296291 CHECK_EQ (ctxs.size (), alloc_types.size ());
297- ctxs_ = ctxs;
292+ // Cache the context
293+ for (const auto & it : ctxs) {
294+ auto dev_type = static_cast <size_t >(it.device_type );
295+ if (ctxs_.size () <= dev_type) {
296+ ctxs_.resize (dev_type + 1 );
297+ }
298+ ctxs_[dev_type] = it;
299+ }
298300 for (size_t i = 0 ; i < ctxs.size (); ++i) {
299301 auto alloc = MemoryManager::GetOrCreateAllocator (ctxs[i], alloc_types[i]);
300302 allocators_.emplace (ctxs[i], alloc);
@@ -484,9 +486,7 @@ void VirtualMachine::RunLoop() {
484486 goto main_loop;
485487 }
486488 case Opcode::AllocTensorReg: {
487- DLContext cpu_ctx;
488- cpu_ctx.device_type = kDLCPU ;
489- cpu_ctx.device_id = 0 ;
489+ DLContext cpu_ctx = GetContext (static_cast <Index>(kDLCPU ));
490490 auto shape_obj = ReadRegister (instr.alloc_tensor_reg .shape_register );
491491 NDArray shape_tensor = Downcast<NDArray>(CopyTo (shape_obj, cpu_ctx));
492492 auto shape = ToShape (shape_tensor);
@@ -566,9 +566,7 @@ void VirtualMachine::RunLoop() {
566566 }
567567 }
568568 case Opcode::ReshapeTensor: {
569- DLContext cpu_ctx;
570- cpu_ctx.device_type = kDLCPU ;
571- cpu_ctx.device_id = 0 ;
569+ DLContext cpu_ctx = GetContext (static_cast <Index>(kDLCPU ));
572570 auto tensor_obj = ReadRegister (instr.reshape_tensor .tensor );
573571 NDArray tensor_arr = Downcast<NDArray>(tensor_obj);
574572 // Read the shape from shape tensor
0 commit comments