Skip to content

Commit 465fd1b

Browse files
committed
cache context
1 parent c0048ad commit 465fd1b

File tree

3 files changed

+27
-26
lines changed

3 files changed

+27
-26
lines changed

python/tvm/runtime/vm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,13 @@ def _setup_ctx(self, ctx, memory_cfg):
309309
"""Init context and allocators."""
310310
ctxs = ctx
311311
if not isinstance(ctx, (list, tuple)):
312-
assert isinstance(ctx, tvm.runtime.TVMContext)
312+
if not isinstance(ctx, tvm.runtime.TVMContext):
313+
raise TypeError("ctx is expected to be TVMContext")
313314
ctxs = [ctx]
314-
# CPU is required for executing shape functions
315-
if ctx.device_type != tvm.cpu(0).device_type:
316-
ctxs.append(tvm.cpu())
315+
316+
# CPU is required for executing shape functions
317+
if not any(c.device_type == tvm.cpu().device_type for c in ctxs):
318+
ctxs.append(tvm.cpu())
317319

318320
default_alloc_type = VirtualMachine.POOLED_ALLOCATOR
319321
if memory_cfg is None:

src/runtime/vm/executable.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,8 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
644644
return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
645645
}
646646
case Opcode::AllocStorage: {
647-
DCHECK_GE(instr.fields.size(), 6U);
647+
// Number of fields = 7
648+
DCHECK_GE(instr.fields.size(), 7U);
648649
Index allocation_size = instr.fields[0];
649650
Index alignment = instr.fields[1];
650651

src/runtime/vm/vm.cc

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
7777
for (size_t i = 0; i < adt.size(); i++) {
7878
ret.push_back(CopyTo(adt[i], ctx));
7979
}
80-
return ADT(0, ret.begin(), ret.end());
80+
return ADT(adt->tag, ret.begin(), ret.end());
8181
}
8282
}
8383

@@ -161,11 +161,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
161161
<< "The number of provided parameters doesn't match the number of assigned devices";
162162
std::vector<ObjectRef> func_args(param_names.size());
163163
for (int i = 1; i < args.size(); ++i) {
164-
TVMContext ctx;
165-
int device_type = vm_func.params_device_type[i - 1];
166-
ctx.device_type = DLDeviceType(device_type);
167-
// TODO(zhiics) Use virtual device id
168-
ctx.device_id = 0;
164+
Index device_type = vm_func.params_device_type[i - 1];
165+
DLContext ctx = GetContext(device_type);
169166
ObjectRef obj = CopyTo(args[i], ctx);
170167
func_args[i - 1] = obj;
171168
}
@@ -178,15 +175,13 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
178175
}
179176
}
180177

181-
TVMContext VirtualMachine::GetContext(Index device_type) const {
182-
CHECK(!ctxs_.empty()) << "Context has not been initialized yet.";
183-
184-
const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&device_type](const TVMContext& c) {
185-
return device_type == static_cast<Index>(c.device_type);
186-
});
178+
inline TVMContext VirtualMachine::GetContext(Index device_type) const {
179+
CHECK_GT(ctxs_.size(), device_type) << "ctxs_ list doesn't contain device:" << device_type;
187180

188-
CHECK(cit != ctxs_.end()) << "device type " << device_type << " not found int the context list.";
189-
return *cit;
181+
auto ctx = ctxs_[device_type];
182+
CHECK_EQ(static_cast<Index>(ctx.device_type), device_type)
183+
<< "device type " << device_type << " has not been initialized in the context list.";
184+
return ctx;
190185
}
191186

192187
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
@@ -294,7 +289,14 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
294289
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs,
295290
const std::vector<AllocatorType>& alloc_types) {
296291
CHECK_EQ(ctxs.size(), alloc_types.size());
297-
ctxs_ = ctxs;
292+
// Cache the context
293+
for (const auto& it : ctxs) {
294+
auto dev_type = static_cast<size_t>(it.device_type);
295+
if (ctxs_.size() <= dev_type) {
296+
ctxs_.resize(dev_type + 1);
297+
}
298+
ctxs_[dev_type] = it;
299+
}
298300
for (size_t i = 0; i < ctxs.size(); ++i) {
299301
auto alloc = MemoryManager::GetOrCreateAllocator(ctxs[i], alloc_types[i]);
300302
allocators_.emplace(ctxs[i], alloc);
@@ -484,9 +486,7 @@ void VirtualMachine::RunLoop() {
484486
goto main_loop;
485487
}
486488
case Opcode::AllocTensorReg: {
487-
DLContext cpu_ctx;
488-
cpu_ctx.device_type = kDLCPU;
489-
cpu_ctx.device_id = 0;
489+
DLContext cpu_ctx = GetContext(static_cast<Index>(kDLCPU));
490490
auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
491491
NDArray shape_tensor = Downcast<NDArray>(CopyTo(shape_obj, cpu_ctx));
492492
auto shape = ToShape(shape_tensor);
@@ -566,9 +566,7 @@ void VirtualMachine::RunLoop() {
566566
}
567567
}
568568
case Opcode::ReshapeTensor: {
569-
DLContext cpu_ctx;
570-
cpu_ctx.device_type = kDLCPU;
571-
cpu_ctx.device_id = 0;
569+
DLContext cpu_ctx = GetContext(static_cast<Index>(kDLCPU));
572570
auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor);
573571
NDArray tensor_arr = Downcast<NDArray>(tensor_obj);
574572
// Read the shape from shape tensor

0 commit comments

Comments
 (0)