2626#include < tvm/runtime/packed_func.h>
2727#include < tvm/runtime/registry.h>
2828
29- #include " index_seq.h"
3029#include " tensorflow/core/framework/op_kernel.h"
3130
3231typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -37,6 +36,10 @@ using tensorflow::OpKernel;
3736using tensorflow::OpKernelConstruction;
3837using tensorflow::OpKernelContext;
3938
39+ using tvm::runtime::TVMArgs;
40+ using tvm::runtime::TVMArgsSetter;
41+ using tvm::runtime::TVMRetValue;
42+
4043// Op utility trait template, specialized for each device type
4144template <typename DEVICE_TYPE>
4245class TVMDSOOpTrait ;
@@ -192,7 +195,7 @@ class TVMDSOOpTrait<GPUDevice> {
192195};
193196#endif
194197
195- template <typename DEVICE_TYPE, int NUM_INPUTS >
198+ template <typename DEVICE_TYPE>
196199class TVMDSOOp : public OpKernel {
197200 private:
198201 tvm::runtime::PackedFunc tvm_func;
@@ -225,9 +228,14 @@ class TVMDSOOp : public OpKernel {
225228 }
226229
227230 void Compute (tensorflow::OpKernelContext* context) override {
228- DLTensor args[NUM_INPUTS + 1 ];
229- TensorAsBuf buf_info[NUM_INPUTS];
230- ShapeContainer shapes[NUM_INPUTS];
231+ // the last input is output shape spec
232+ const int num_inputs = context->num_inputs () - 1 ;
233+
234+ // total args = input args + 1
235+ int num_total_args = num_inputs + 1 ;
236+ std::vector<DLTensor> args (num_total_args);
237+ std::vector<TensorAsBuf> buf_info (num_inputs);
238+ std::vector<ShapeContainer> shapes (num_inputs);
231239
232240 tensorflow::Status status;
233241 int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id (context);
@@ -237,7 +245,7 @@ class TVMDSOOp : public OpKernel {
237245
238246 // Get output shape
239247 tensorflow::TensorShape output_shape;
240- auto & output_shape_tensor = context->input (NUM_INPUTS );
248+ auto & output_shape_tensor = context->input (num_inputs );
241249 if (has_static_output_shape) {
242250 // use static output shape
243251 const tensorflow::int64* dims = static_output_shape.data ();
@@ -250,7 +258,7 @@ class TVMDSOOp : public OpKernel {
250258 output_shape = context->input (0 ).shape ();
251259 }
252260
253- for (int i = 0 ; i < NUM_INPUTS ; ++i) {
261+ for (int i = 0 ; i < num_inputs ; ++i) {
254262 // Grab the input tensor
255263 auto & input_tensor = context->input (i);
256264
@@ -279,32 +287,26 @@ class TVMDSOOp : public OpKernel {
279287 output.device_type = device_type;
280288 EnsureAlignment (context, *output_tensor, &output);
281289
282- status = MakeDLTensor (output, dl_ctx, output_shape_ptr, &args[NUM_INPUTS ]);
290+ status = MakeDLTensor (output, dl_ctx, output_shape_ptr, &args[num_inputs ]);
283291 OP_REQUIRES_OK (context, status);
284292
285- apply_variadic_by_ptrs (tvm_func, args);
293+ // Prepare PackedFunc arguments
294+ std::vector<TVMValue> tvm_values (num_total_args);
295+ std::vector<int > tvm_type_codes (num_total_args);
296+ TVMArgsSetter setter (tvm_values.data (), tvm_type_codes.data ());
297+ for (int k = 0 ; k < num_total_args; ++k) {
298+ setter (k, &args[k]);
299+ }
300+ TVMRetValue rv;
301+ tvm_func.CallPacked (TVMArgs (tvm_values.data (), tvm_type_codes.data (), num_total_args), &rv);
286302
287303 output.CopyToOrigin ();
288304 }
289305};
290306
291307#ifdef TF_TVMDSOOP_ENABLE_GPU
292- #define REGISTER_TFTVM_KERNEL (n ) \
293- REGISTER_KERNEL_BUILDER (Name (" TvmDsoOp" #n).Device (tensorflow::DEVICE_CPU), \
294- TVMDSOOp<CPUDevice, n>); \
295- REGISTER_KERNEL_BUILDER (Name (" TvmDsoOp" #n).Device (tensorflow::DEVICE_GPU), \
296- TVMDSOOp<GPUDevice, n>);
308+ REGISTER_KERNEL_BUILDER (Name (" TvmDsoOp" ).Device (tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
309+ REGISTER_KERNEL_BUILDER (Name (" TvmDsoOp" ).Device (tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
297310#else
298- #define REGISTER_TFTVM_KERNEL (n ) \
299- REGISTER_KERNEL_BUILDER (Name (" TvmDsoOp" #n).Device (tensorflow::DEVICE_CPU), \
300- TVMDSOOp<CPUDevice, n>);
311+ REGISTER_KERNEL_BUILDER (Name (" TvmDsoOp" ).Device (tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
301312#endif
302-
303- REGISTER_TFTVM_KERNEL (1 )
304- REGISTER_TFTVM_KERNEL (2 )
305- REGISTER_TFTVM_KERNEL (3 )
306- REGISTER_TFTVM_KERNEL (4 )
307- REGISTER_TFTVM_KERNEL (5 )
308- REGISTER_TFTVM_KERNEL (6 )
309- REGISTER_TFTVM_KERNEL (7 )
310- REGISTER_TFTVM_KERNEL (8 )
0 commit comments