Skip to content

Commit 47918ea

Browse files
committed
Use union in allocate_storage struct
1 parent 48b4b57 commit 47918ea

File tree

4 files changed

+52
-18
lines changed

4 files changed

+52
-18
lines changed

include/tvm/runtime/vm/bytecode.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,16 +197,18 @@ struct Instruction {
197197
RegName* free_vars;
198198
};
199199
struct /* AllocStorage Operands */ {
200-
/*! \brief The size of the allocation. */
201-
RegName allocation_size;
202200
/*! \brief The alignment of the allocation. */
203201
Index alignment;
204202
/*! \brief The hint of the dtype. */
205203
DLDataType dtype_hint;
206204
/*! \brief The number of dimensions. */
207205
uint32_t ndim;
208-
/*! \brief The shape of tensor. */
209-
int64_t* shape;
206+
union {
207+
/*! \brief The shape of tensor. */
208+
int64_t* shape;
209+
/*! \brief The size of the allocation. */
210+
RegName allocation_size;
211+
};
210212
/*! \brief The index of the device on which the allocation will be made. */
211213
Index device_index;
212214
} alloc_storage;

src/runtime/vm/bytecode.cc

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,11 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType
350350
instr.alloc_storage.dtype_hint = dtype_hint;
351351
instr.alloc_storage.device_index = device_index;
352352
instr.alloc_storage.ndim = static_cast<uint32_t>(shape.size());
353-
instr.alloc_storage.shape = new int64_t[shape.size()];
354-
for (size_t i = 0; i < shape.size(); ++i) {
355-
instr.alloc_storage.shape[i] = shape[i];
353+
if (instr.alloc_storage.ndim > 0) {
354+
instr.alloc_storage.shape = new int64_t[shape.size()];
355+
for (size_t i = 0; i < shape.size(); ++i) {
356+
instr.alloc_storage.shape[i] = shape[i];
357+
}
356358
}
357359
return instr;
358360
}
@@ -626,10 +628,15 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
626628
break;
627629
}
628630
case Opcode::AllocStorage: {
629-
os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " "
630-
<< instr.alloc_storage.alignment << " "
631-
<< "[" << StrJoin<int64_t>(instr.alloc_storage.shape, 0, instr.alloc_storage.ndim) << "] "
632-
<< DLDataType2String(instr.alloc_storage.dtype_hint) << " "
631+
os << "alloc_storage $" << instr.dst << " ";
632+
if (instr.alloc_storage.ndim > 0) {
633+
os << "[" << StrJoin<int64_t>(instr.alloc_storage.shape, 0, instr.alloc_storage.ndim)
634+
<< "] ";
635+
} else {
636+
os << "$" << instr.alloc_storage.allocation_size << " " << instr.alloc_storage.alignment
637+
<< " ";
638+
}
639+
os << DLDataType2String(instr.alloc_storage.dtype_hint) << " "
633640
<< instr.alloc_storage.device_index;
634641
break;
635642
}

src/runtime/vm/profiler/vm.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,21 @@ void VirtualMachineDebug::OpStartHook(Instruction instr) {
129129
{{"Argument Shapes",
130130
profiling::ShapeString(shape_tensor, instr.alloc_tensor_reg.dtype)}});
131131
} else if (instr.op == Opcode::AllocStorage) {
132-
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
133132
std::ostringstream shape;
134-
shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << size << "]";
133+
if (instr.alloc_storage.ndim > 0) {
134+
std::string shape_str = "[";
135+
for (uint32_t i = 0; i < instr.alloc_storage.ndim; ++i) {
136+
if (i > 0) {
137+
shape_str += ", ";
138+
}
139+
shape_str += std::to_string(instr.alloc_storage.shape[i]);
140+
}
141+
shape_str += "]";
142+
shape << DLDataType2String(instr.alloc_storage.dtype_hint) << shape_str;
143+
} else {
144+
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
145+
shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << size << "]";
146+
}
135147
Device dev = GetDevice(instr.alloc_storage.device_index);
136148
prof_.operator*().StartCall("VM::AllocStorage", dev,
137149
{{"VM::Argument Shapes", String(shape.str())}});

src/runtime/vm/vm.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -820,22 +820,35 @@ void VirtualMachine::RunLoop(const std::vector<Index>& output_tensor_reg_indices
820820
}
821821
case Opcode::AllocStorage: {
822822
OpStartHook(instr);
823-
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
824-
auto alignment = instr.alloc_storage.alignment;
825823

826824
auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
827825
Allocator* allocator = GetAllocator(instr.alloc_storage.device_index);
828826
ICHECK(allocator) << "Did you forget to init the VirtualMachine with devices?";
829-
VLOG(2) << "allocating with allocation_size=" << size << ", alignment=" << alignment
830-
<< ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint)
831-
<< ", device_index=" << instr.alloc_storage.device_index;
832827
std::string mem_scope = exec_->virtual_devices[instr.alloc_storage.device_index].second;
833828

834829
if (instr.alloc_storage.ndim > 0) {
830+
std::string shape = "[";
831+
for (uint32_t i = 0; i < instr.alloc_storage.ndim; ++i) {
832+
if (i > 0) {
833+
shape += ", ";
834+
}
835+
shape += std::to_string(instr.alloc_storage.shape[i]);
836+
}
837+
shape += "]";
838+
VLOG(2) << "allocating with ndims=" << instr.alloc_storage.ndim << ", shape=" << shape
839+
<< ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint)
840+
<< ", device_index=" << instr.alloc_storage.device_index
841+
<< ", memory_scope=" << mem_scope;
835842
storage_obj->buffer =
836843
allocator->Alloc(instr.alloc_storage.ndim, instr.alloc_storage.shape,
837844
instr.alloc_storage.dtype_hint, mem_scope);
838845
} else {
846+
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
847+
auto alignment = instr.alloc_storage.alignment;
848+
VLOG(2) << "allocating with allocation_size=" << size << ", alignment=" << alignment
849+
<< ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint)
850+
<< ", device_index=" << instr.alloc_storage.device_index
851+
<< ", memory_scope=" << mem_scope;
839852
storage_obj->buffer = allocator->Alloc(size, alignment, instr.alloc_storage.dtype_hint);
840853
}
841854
Storage storage(storage_obj);

0 commit comments

Comments
 (0)