Skip to content

Commit 3984459

Browse files
baoxinqiwrongtest
authored andcommitted
feat: Use TF list input op def
1 parent 8ac182f commit 3984459

File tree

5 files changed

+41
-196
lines changed

5 files changed

+41
-196
lines changed

apps/tf_tvmdsoop/tests/test_tfop_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def export_gpu_add_lib():
6161

6262
def test_add(session, lib_path, tf_device):
6363
"""test add lib with TensorFlow wrapper"""
64-
module = tf_op.Module(lib_path)
64+
module = tf_op.OpModule(lib_path)
6565

6666
left = tf.placeholder("float32", shape=[4])
6767
right = tf.placeholder("float32", shape=[4])

python/tvm/contrib/tf_op/module.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,11 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape):
6767
elif output_shape is not None:
6868
self.dynamic_output_shape = self._pack_shape_tensor(output_shape)
6969

70-
# delay op initialization to where Func.apply() get called first time
71-
self.tvm_dso_op = None
7270
self.module = load_library.load_op_library('tvm_dso_op.so')
71+
self.tvm_dso_op = self.module.tvm_dso_op
7372

7473
def apply(self, *params):
75-
if self.tvm_dso_op is None:
76-
num_inputs = len(params)
77-
self.tvm_dso_op = getattr(self.module, "tvm_dso_op%s" % num_inputs)
78-
79-
return self.tvm_dso_op(*params,
74+
return self.tvm_dso_op(params,
8075
dynamic_output_shape=self.dynamic_output_shape,
8176
static_output_shape=self.static_output_shape,
8277
has_static_output_shape=self.has_static_output_shape,

src/contrib/tf_op/index_seq.h

Lines changed: 0 additions & 61 deletions
This file was deleted.

src/contrib/tf_op/tvm_dso_op_kernels.cc

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include <tvm/runtime/packed_func.h>
2727
#include <tvm/runtime/registry.h>
2828

29-
#include "index_seq.h"
3029
#include "tensorflow/core/framework/op_kernel.h"
3130

3231
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -37,6 +36,10 @@ using tensorflow::OpKernel;
3736
using tensorflow::OpKernelConstruction;
3837
using tensorflow::OpKernelContext;
3938

39+
using tvm::runtime::TVMArgs;
40+
using tvm::runtime::TVMArgsSetter;
41+
using tvm::runtime::TVMRetValue;
42+
4043
// Op utility trait for diffrent device type template
4144
template <typename DEVICE_TYPE>
4245
class TVMDSOOpTrait;
@@ -192,7 +195,7 @@ class TVMDSOOpTrait<GPUDevice> {
192195
};
193196
#endif
194197

195-
template <typename DEVICE_TYPE, int NUM_INPUTS>
198+
template <typename DEVICE_TYPE>
196199
class TVMDSOOp : public OpKernel {
197200
private:
198201
tvm::runtime::PackedFunc tvm_func;
@@ -225,9 +228,14 @@ class TVMDSOOp : public OpKernel {
225228
}
226229

227230
void Compute(tensorflow::OpKernelContext* context) override {
228-
DLTensor args[NUM_INPUTS + 1];
229-
TensorAsBuf buf_info[NUM_INPUTS];
230-
ShapeContainer shapes[NUM_INPUTS];
231+
// the last input is output shape spec
232+
const int num_inputs = context->num_inputs() - 1;
233+
234+
// total args = input args + 1
235+
int num_total_args = num_inputs + 1;
236+
std::vector<DLTensor> args(num_total_args);
237+
std::vector<TensorAsBuf> buf_info(num_inputs);
238+
std::vector<ShapeContainer> shapes(num_inputs);
231239

232240
tensorflow::Status status;
233241
int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
@@ -237,7 +245,7 @@ class TVMDSOOp : public OpKernel {
237245

238246
// Get output shape
239247
tensorflow::TensorShape output_shape;
240-
auto& output_shape_tensor = context->input(NUM_INPUTS);
248+
auto& output_shape_tensor = context->input(num_inputs);
241249
if (has_static_output_shape) {
242250
// use static output shape
243251
const tensorflow::int64* dims = static_output_shape.data();
@@ -250,7 +258,7 @@ class TVMDSOOp : public OpKernel {
250258
output_shape = context->input(0).shape();
251259
}
252260

253-
for (int i = 0; i < NUM_INPUTS; ++i) {
261+
for (int i = 0; i < num_inputs; ++i) {
254262
// Grab the input tensor
255263
auto& input_tensor = context->input(i);
256264

@@ -279,32 +287,26 @@ class TVMDSOOp : public OpKernel {
279287
output.device_type = device_type;
280288
EnsureAlignment(context, *output_tensor, &output);
281289

282-
status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[NUM_INPUTS]);
290+
status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]);
283291
OP_REQUIRES_OK(context, status);
284292

285-
apply_variadic_by_ptrs(tvm_func, args);
293+
// Prepare PackedFunc arguments
294+
std::vector<TVMValue> tvm_values(num_total_args);
295+
std::vector<int> tvm_type_codes(num_total_args);
296+
TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
297+
for (int k = 0; k < num_total_args; ++k) {
298+
setter(k, &args[k]);
299+
}
300+
TVMRetValue rv;
301+
tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
286302

287303
output.CopyToOrigin();
288304
}
289305
};
290306

291307
#ifdef TF_TVMDSOOP_ENABLE_GPU
292-
#define REGISTER_TFTVM_KERNEL(n) \
293-
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_CPU), \
294-
TVMDSOOp<CPUDevice, n>); \
295-
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_GPU), \
296-
TVMDSOOp<GPUDevice, n>);
308+
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
309+
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
297310
#else
298-
#define REGISTER_TFTVM_KERNEL(n) \
299-
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_CPU), \
300-
TVMDSOOp<CPUDevice, n>);
311+
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
301312
#endif
302-
303-
REGISTER_TFTVM_KERNEL(1)
304-
REGISTER_TFTVM_KERNEL(2)
305-
REGISTER_TFTVM_KERNEL(3)
306-
REGISTER_TFTVM_KERNEL(4)
307-
REGISTER_TFTVM_KERNEL(5)
308-
REGISTER_TFTVM_KERNEL(6)
309-
REGISTER_TFTVM_KERNEL(7)
310-
REGISTER_TFTVM_KERNEL(8)

src/contrib/tf_op/tvm_dso_ops.cc

Lines changed: 10 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -19,104 +19,13 @@
1919

2020
#include "tensorflow/core/framework/op.h"
2121

22-
#define REGISTER_TFTVM_OP(n) \
23-
REGISTER_OP("TvmDsoOp" #n) \
24-
.Output("output: output_dtype") \
25-
.Attr("lib_path: string") \
26-
.Attr("func_name: string") \
27-
.Attr("output_dtype: {int32, int64, float} = DT_FLOAT") \
28-
.Attr("static_output_shape: list(int) >= 0 = []") \
29-
.Attr("has_static_output_shape: bool")
30-
31-
REGISTER_TFTVM_OP(1).Input("input: T").Attr("T: type").Input("dynamic_output_shape: int64");
32-
33-
REGISTER_TFTVM_OP(2)
34-
.Input("input1: T1")
35-
.Attr("T1: type")
36-
.Input("input2: T2")
37-
.Attr("T2: type")
38-
.Input("dynamic_output_shape: int64");
39-
40-
REGISTER_TFTVM_OP(3)
41-
.Input("input1: T1")
42-
.Attr("T1: type")
43-
.Input("input2: T2")
44-
.Attr("T2: type")
45-
.Input("input3: T3")
46-
.Attr("T3: type")
47-
.Input("dynamic_output_shape: int64");
48-
49-
REGISTER_TFTVM_OP(4)
50-
.Input("input1: T1")
51-
.Attr("T1: type")
52-
.Input("input2: T2")
53-
.Attr("T2: type")
54-
.Input("input3: T3")
55-
.Attr("T3: type")
56-
.Input("input4: T4")
57-
.Attr("T4: type")
58-
.Input("dynamic_output_shape: int64");
59-
60-
REGISTER_TFTVM_OP(5)
61-
.Input("input1: T1")
62-
.Attr("T1: type")
63-
.Input("input2: T2")
64-
.Attr("T2: type")
65-
.Input("input3: T3")
66-
.Attr("T3: type")
67-
.Input("input4: T4")
68-
.Attr("T4: type")
69-
.Input("input5: T5")
70-
.Attr("T5: type")
71-
.Input("dynamic_output_shape: int64");
72-
73-
REGISTER_TFTVM_OP(6)
74-
.Input("input1: T1")
75-
.Attr("T1: type")
76-
.Input("input2: T2")
77-
.Attr("T2: type")
78-
.Input("input3: T3")
79-
.Attr("T3: type")
80-
.Input("input4: T4")
81-
.Attr("T4: type")
82-
.Input("input5: T5")
83-
.Attr("T5: type")
84-
.Input("input6: T6")
85-
.Attr("T6: type")
86-
.Input("dynamic_output_shape: int64");
87-
88-
REGISTER_TFTVM_OP(7)
89-
.Input("input1: T1")
90-
.Attr("T1: type")
91-
.Input("input2: T2")
92-
.Attr("T2: type")
93-
.Input("input3: T3")
94-
.Attr("T3: type")
95-
.Input("input4: T4")
96-
.Attr("T4: type")
97-
.Input("input5: T5")
98-
.Attr("T5: type")
99-
.Input("input6: T6")
100-
.Attr("T6: type")
101-
.Input("input7: T7")
102-
.Attr("T7: type")
103-
.Input("dynamic_output_shape: int64");
104-
105-
REGISTER_TFTVM_OP(8)
106-
.Input("input1: T1")
107-
.Attr("T1: type")
108-
.Input("input2: T2")
109-
.Attr("T2: type")
110-
.Input("input3: T3")
111-
.Attr("T3: type")
112-
.Input("input4: T4")
113-
.Attr("T4: type")
114-
.Input("input5: T5")
115-
.Attr("T5: type")
116-
.Input("input6: T6")
117-
.Attr("T6: type")
118-
.Input("input7: T7")
119-
.Attr("T7: type")
120-
.Input("input8: T8")
121-
.Attr("T8: type")
122-
.Input("dynamic_output_shape: int64");
22+
REGISTER_OP("TvmDsoOp")
23+
.Input("input_args: ListT")
24+
.Attr("ListT: list({int8, int32, int64, float16, float32})")
25+
.Input("dynamic_output_shape: int64")
26+
.Output("output: output_dtype")
27+
.Attr("lib_path: string")
28+
.Attr("func_name: string")
29+
.Attr("output_dtype: {int32, int64, float} = DT_FLOAT")
30+
.Attr("static_output_shape: list(int) >= 0 = []")
31+
.Attr("has_static_output_shape: bool");

0 commit comments

Comments
 (0)