@@ -193,35 +193,26 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause
   return else_branch;
 }
 
-std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
+std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  // TODO(@jroesch): we really need to standaridize the bit width of
-  // all of the shape manipulating code.
-  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
-  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(int_ptr[i]);
-  }
-  return raw_shape;
-}
-
-
-std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
-  std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  // TODO(@jroesch): we really need to standaridize the bit width of
-  // all of the shape manipulating code.
-  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
-  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+
+  if (shape->dtype.bits == 64) {
+    int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(int_ptr[i]);
+    }
+  } else {  // int32
+    int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+    }
   }
   return raw_shape;
 }
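
The unified helper dispatches on `shape->dtype.bits` and widens 32-bit entries to `int64_t`. A self-contained sketch of that widening pattern, using hypothetical names (`ToShapeVector`, with a raw pointer and length standing in for the `DLTensor` fields `data` and `shape[0]`), compiles and runs without TVM:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the unified conversion: a 1-D integer buffer of
// either 32- or 64-bit elements is widened into std::vector<int64_t>.
// `data`, `len`, and `bits` stand in for the DLTensor fields
// (data, shape[0], dtype.bits) that the real helper reads.
std::vector<int64_t> ToShapeVector(const void* data, int64_t len, int bits) {
  assert(bits == 32 || bits == 64);
  std::vector<int64_t> raw_shape;
  raw_shape.reserve(len);
  if (bits == 64) {
    const int64_t* ptr = static_cast<const int64_t*>(data);
    for (int64_t i = 0; i < len; ++i) raw_shape.push_back(ptr[i]);
  } else {  // 32-bit elements are widened to int64_t
    const int32_t* ptr = static_cast<const int32_t*>(data);
    for (int64_t i = 0; i < len; ++i) raw_shape.push_back(static_cast<int64_t>(ptr[i]));
  }
  return raw_shape;
}

int main() {
  int32_t dims32[] = {1, 224, 224, 3};
  for (int64_t d : ToShapeVector(dims32, 4, 32)) std::cout << d << " ";
  std::cout << "\n";  // prints: 1 224 224 3
}
```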
@@ -546,17 +537,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 
   if (const_shape) {
     NDArray shape = const_shape->data;
-    std::vector<int64_t> raw_shape;
-    DLTensor tensor = shape.ToDLPack()->dl_tensor;
-    // TODO(@jroesch): we need to get an RFC done to standarize this
-    if (tensor.dtype.bits == 64) {
-      raw_shape = ToAllocTensorShape64(shape);
-    } else if (tensor.dtype.bits == 32) {
-      raw_shape = ToAllocTensorShape32(shape);
-    } else {
-      LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
-    }
-
+    // TODO(@jroesch): we need to get an RFC done to standardize shape dtype
+    std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
     // Add context field.
     Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
   } else {
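
With the bit-width dispatch folded into `ToAllocTensorShape`, an unsupported dtype now fails on the `CHECK` inside the helper rather than on the `LOG(FATAL)` branch that previously lived at this call site. The streaming `CHECK(...) << ...` form comes from dmlc-core; the following is only a minimal sketch of how such a macro can work, not the actual dmlc-core implementation:

```cpp
#include <cstdlib>
#include <iostream>
#include <sstream>

// Sketch of a streaming fatal-check macro: the temporary FatalStream
// accumulates the message via operator<< and aborts in its destructor
// once the full statement has been evaluated.
struct FatalStream {
  std::ostringstream msg;
  ~FatalStream() {
    std::cerr << msg.str() << std::endl;
    std::abort();
  }
  template <typename T>
  FatalStream& operator<<(const T& v) {
    msg << v;
    return *this;
  }
};

// The dangling-else trick lets `CHECK(cond) << "detail"` parse as one
// statement whether or not a message is streamed after it.
#define CHECK(cond) \
  if (cond) {       \
  } else            \
    FatalStream() << "Check failed: " #cond " "

int main() {
  int bits = 16;  // e.g. a float16 shape tensor reaching the helper
  CHECK(bits == 64 || bits == 32)
      << "The dtype of constant shape must be int32 or int64, but got " << bits;
}
```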