Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy][TVM] TVM reduce added, support initial value #16818

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 176 additions & 31 deletions contrib/tvmop/core/fromnumeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,21 @@


import tvm
from .. import defop
from .. import defop, AllTypes, RealTypes
from ..utils import reduce_axes, assign_by_req


def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
def _compute_with_initial(itype, otype, ndim, reducer, reduce1st_dim, req):
    """Build the TVM compute graph for a reduction that takes an explicit initial value.

    Parameters
    ----------
    itype : str
        Input dtype of the tensor being reduced.
    otype : str
        Output dtype of the reduction result.
    ndim : int
        Rank of the input placeholder.
    reducer : tvm reducer (e.g. tvm.sum, tvm.min, tvm.max)
        Reduction operator applied along the reduced axes and used to
        combine the reduced result with the initial value.
    reduce1st_dim : int (0 or 1)
        Whether the first dimension is reduced; axes alternate from there.
    req : str
        Write request type, 'kWriteTo' or 'kAddTo'.

    Returns
    -------
    (schedule, input, init, output_placeholder, final_output, tensor_list)
    """
    # Alternating reduce/keep pattern, e.g. reduce1st_dim=1, ndim=4 -> [1, 0, 1, 0].
    axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=itype)
    # The initial value is passed as a float64 scalar and cast inside assign_by_req.
    init = tvm.var('init', dtype='float64')
    reduce_output = reduce_axes(a, axes, reducer, otype)
    # Fix: combine the reduced result with `init` using the caller-supplied
    # `reducer` (tvm.min/tvm.max/...), not a hard-coded tvm.sum — otherwise
    # min/max reductions with an initial value would use the wrong combiner.
    output_placeholder, final_output = assign_by_req(reduce_output, req, init, reducer, itype=itype)
    s = tvm.create_schedule(final_output.op)
    return s, a, init, output_placeholder, final_output, [reduce_output, final_output]


def _compute(itype, otype, ndim, reducer, reduce1st_dim, req):
axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=itype)
reduce_output = reduce_axes(a, axes, tvm.sum, otype)
Expand All @@ -30,34 +40,169 @@ def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
return s, a, output_placeholder, final_output, [reduce_output, final_output]


@defop(name='sum_cpu', target='cpu', itype=['bool'],
otype=['float32', 'float64', 'int32', 'int64'],
@defop(name='sum_cpu', target='cpu', itype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'],
otype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
attrs=["reduce1st_dim", "req"])
def _sum_cpu(itype, otype, ndim, reduce1st_dim, req):
s, a, output_placeholder, final_output, tensor_list = _compute_sum(
itype, otype, ndim, reduce1st_dim, req)
for t in tensor_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [a, output_placeholder, final_output]


@defop(name='sum_gpu', target='gpu', itype=['bool'],
otype=['float32', 'float64', 'int32', 'int64'],
initial=[True, False], attrs=["reduce1st_dim", "req", "initial"])
def _sum_cpu(itype, otype, ndim, reduce1st_dim, req, initial):
    """CPU schedule for np.sum, with or without an explicit initial value.

    Returns the schedule plus the argument list expected by the generated
    kernel: [a, (init,) output_placeholder, final_output].
    """
    if initial:
        s, a, init, output_placeholder, final_output, tensor_list = \
            _compute_with_initial(itype, otype, ndim, tvm.sum, reduce1st_dim, req)
        in_args = [a, init]
    else:
        s, a, output_placeholder, final_output, tensor_list = _compute(
            itype, otype, ndim, tvm.sum, reduce1st_dim, req)
        in_args = [a]
    # Identical scheduling for both branches: fuse all output axes of each
    # stage and parallelize the fused axis across CPU threads.
    for t in tensor_list:
        fused = s[t].fuse(*t.op.axis)
        s[t].parallel(fused)
    return s, in_args + [output_placeholder, final_output]


@defop(name='sum_gpu', target='gpu', itype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'],
otype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
initial=[True, False], attrs=["reduce1st_dim", "req", "initial"])
def _sum_gpu(itype, otype, ndim, reduce1st_dim, req, initial):
    """GPU schedule for np.sum, with or without an explicit initial value.

    Returns the schedule plus the argument list expected by the generated
    kernel: [a, (init,) output_placeholder, final_output].
    """
    if initial:
        s, a, init, output_placeholder, final_output, tensor_list = \
            _compute_with_initial(itype, otype, ndim, tvm.sum, reduce1st_dim, req)
        in_args = [a, init]
    else:
        s, a, output_placeholder, final_output, tensor_list = _compute(
            itype, otype, ndim, tvm.sum, reduce1st_dim, req)
        in_args = [a]
    # Identical scheduling for both branches: fuse each stage's output axes,
    # split into 64-thread chunks, and bind to CUDA blocks/threads.
    num_threads = 64
    for t in tensor_list:
        fused = s[t].fuse(*t.op.axis)
        bx, tx = s[t].split(fused, factor=num_threads)
        s[t].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[t].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s, in_args + [output_placeholder, final_output]


@defop(name='min_cpu', target='cpu', itype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'],
otype=['float32', 'float64', 'int8', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
initial=[True, False], attrs=["reduce1st_dim", "req", "initial"])
def _min_cpu(itype, otype, ndim, reduce1st_dim, req, initial):
    """CPU schedule for np.min, with or without an explicit initial value.

    Returns the schedule plus the argument list expected by the generated
    kernel: [a, (init,) output_placeholder, final_output].
    """
    if initial:
        s, a, init, output_placeholder, final_output, tensor_list = \
            _compute_with_initial(itype, otype, ndim, tvm.min, reduce1st_dim, req)
        in_args = [a, init]
    else:
        s, a, output_placeholder, final_output, tensor_list = _compute(
            itype, otype, ndim, tvm.min, reduce1st_dim, req)
        in_args = [a]
    # Identical scheduling for both branches: fuse all output axes of each
    # stage and parallelize the fused axis across CPU threads.
    for t in tensor_list:
        fused = s[t].fuse(*t.op.axis)
        s[t].parallel(fused)
    return s, in_args + [output_placeholder, final_output]


@defop(name='min_gpu', target='gpu', itype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'],
otype=['float32', 'float64', 'int8', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
initial=[True, False], attrs=["reduce1st_dim", "req", "initial"])
def _min_gpu(itype, otype, ndim, reduce1st_dim, req, initial):
    """GPU schedule for np.min, with or without an explicit initial value.

    Returns the schedule plus the argument list expected by the generated
    kernel: [a, (init,) output_placeholder, final_output].
    """
    if initial:
        s, a, init, output_placeholder, final_output, tensor_list = \
            _compute_with_initial(itype, otype, ndim, tvm.min, reduce1st_dim, req)
        in_args = [a, init]
    else:
        s, a, output_placeholder, final_output, tensor_list = _compute(
            itype, otype, ndim, tvm.min, reduce1st_dim, req)
        in_args = [a]
    # Identical scheduling for both branches: fuse each stage's output axes,
    # split into 64-thread chunks, and bind to CUDA blocks/threads.
    num_threads = 64
    for t in tensor_list:
        fused = s[t].fuse(*t.op.axis)
        bx, tx = s[t].split(fused, factor=num_threads)
        s[t].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[t].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s, in_args + [output_placeholder, final_output]


@defop(name='max_cpu', target='cpu', itype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'],
otype=['float32', 'float64', 'int8', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
initial=[True, False], attrs=["reduce1st_dim", "req", "initial"])
def _max_cpu(itype, otype, ndim, reduce1st_dim, req, initial):
    """CPU schedule for np.max, with or without an explicit initial value.

    Returns the schedule plus the argument list expected by the generated
    kernel: [a, (init,) output_placeholder, final_output].
    """
    if initial:
        s, a, init, output_placeholder, final_output, tensor_list = \
            _compute_with_initial(itype, otype, ndim, tvm.max, reduce1st_dim, req)
        in_args = [a, init]
    else:
        s, a, output_placeholder, final_output, tensor_list = _compute(
            itype, otype, ndim, tvm.max, reduce1st_dim, req)
        in_args = [a]
    # Identical scheduling for both branches: fuse all output axes of each
    # stage and parallelize the fused axis across CPU threads.
    for t in tensor_list:
        fused = s[t].fuse(*t.op.axis)
        s[t].parallel(fused)
    return s, in_args + [output_placeholder, final_output]


@defop(name='max_gpu', target='gpu', itype=['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'],
otype=['float32', 'float64', 'int8', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
attrs=["reduce1st_dim", "req"])
def _sum_gpu(itype, otype, ndim, reduce1st_dim, req):
s, a, output_placeholder, final_output, tensor_list = _compute_sum(
itype, otype, ndim, reduce1st_dim, req)
num_threads = 64
for t in tensor_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_threads)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [a, output_placeholder, final_output]
initial=[True, False], attrs=["reduce1st_dim", "req", "initial"])
def _max_gpu(itype, otype, ndim, reduce1st_dim, req, initial):
    """GPU schedule for np.max, with or without an explicit initial value.

    Returns the schedule plus the argument list expected by the generated
    kernel: [a, (init,) output_placeholder, final_output].
    """
    if initial:
        s, a, init, output_placeholder, final_output, tensor_list = \
            _compute_with_initial(itype, otype, ndim, tvm.max, reduce1st_dim, req)
        in_args = [a, init]
    else:
        s, a, output_placeholder, final_output, tensor_list = _compute(
            itype, otype, ndim, tvm.max, reduce1st_dim, req)
        in_args = [a]
    # Identical scheduling for both branches: fuse each stage's output axes,
    # split into 64-thread chunks, and bind to CUDA blocks/threads.
    num_threads = 64
    for t in tensor_list:
        fused = s[t].fuse(*t.op.axis)
        bx, tx = s[t].split(fused, factor=num_threads)
        s[t].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[t].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s, in_args + [output_placeholder, final_output]
28 changes: 22 additions & 6 deletions contrib/tvmop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,27 @@
RealTypes = ["float32", "float64", "float16"]


def assign_by_req(a, req, initial=None, reducer=None, itype=None, otype=None):
    """Produce the output placeholder and final write stage for a reduction.

    Parameters
    ----------
    a : tvm.Tensor
        The already-reduced result to be written out.
    req : str
        'kAddTo' accumulates into the existing output; anything else overwrites.
    initial : tvm scalar expr, optional
        Initial value of the reduction; combined into the result via `reducer`.
    reducer : callable, optional
        Binary combiner (e.g. tvm.min, tvm.max) used with `initial`.
        Required whenever `initial` is given.
    itype : str, optional
        Original input dtype, used to cast `initial` when `otype` is absent.
    otype : str, optional
        Output dtype; when given, results are cast to it.

    Returns
    -------
    (b, c) : the output placeholder and the final compute stage.
    """
    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
    if req == "kAddTo":
        if initial is not None:
            # `initial` is routed through float32 to avoid the nvcc error of
            # being unable to convert half type to long/int/char types.
            # NOTE(review): the float32 round-trip is only applied on the
            # no-otype path below; the otype path casts directly — confirm
            # this asymmetry is intended.
            c = tvm.compute(a.shape, lambda *idx: reducer(a[idx].astype(otype) +
                            b[idx], initial.astype(otype))
                            if otype else reducer(a[idx] + b[idx],
                            initial.astype(itype).astype("float32").astype(a.dtype)))
        else:
            c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) + b[idx]
                            if otype else a[idx] + b[idx])
    else:
        if initial is not None:
            c = tvm.compute(a.shape, lambda *idx: reducer(a[idx],
                            initial.astype(a.dtype)).astype(otype)
                            if otype else reducer(a[idx],
                            initial.astype(itype).astype("float32").astype(a.dtype)))
        else:
            c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) if otype else a[idx])
    return b, c


Expand All @@ -42,12 +56,14 @@ def get_index(idx, ridx):
j += (val == 0)
k += (val != 0)
return tuple(ret)

ishape = X.shape
odim = (len(ishape) + 1 - axes[0]) // 2
oshape = [tvm.size_var() for _ in range(odim)]
ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype(atype)
# input casted to float32 first to avoid the nvcc error of unable to convert half tpe to
# long or int char types.
ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype("float32").astype(atype)
if atype else X[get_index(idx, ridx)],
axis=ridx), name='ret')
return ret
20 changes: 12 additions & 8 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ inline bool NeedSafeAcc(int itype, int otype) {

void TVMOpReduce(const OpContext& ctx, const TBlob& input,
const dmlc::optional<mxnet::Tuple<int>>& axis,
const dmlc::optional<double> initial,
const TBlob& output, const OpReqType req, const std::string& reducer_name);

template<typename xpu, typename reducer, bool safe_acc_hint = false, bool normalize = false,
Expand All @@ -255,9 +256,12 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
if (req[0] == kNullOp) return;
const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
#if !MXNET_USE_TVM_OP
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
LOG(FATAL) << "Please add USE_TVM_OP = 1 as a compile flag to enable TVM-generated kernels "
"to support initial value";
}
#endif // MXNET_USE_TVM_OP
Stream<xpu>* s = ctx.get_stream<xpu>();
if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
using namespace mxnet_op;
Expand All @@ -269,14 +273,14 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place";
#if MXNET_USE_TVM_OP
// If boolean ndarray, use the kernel generated by TVM
if (inputs[0].type_flag_ == mshadow::kBool) {
if ((inputs[0].type_flag_ == mshadow::kBool \
|| inputs[0].type_flag_ == mshadow::kFloat16 || inputs[0].type_flag_ == mshadow::kFloat32 \
|| inputs[0].type_flag_ == mshadow::kFloat64 || inputs[0].type_flag_ == mshadow::kInt8 \
|| inputs[0].type_flag_ == mshadow::kInt32 || inputs[0].type_flag_ == mshadow::kInt64)
&& std::is_same<reducer, mshadow_op::sum>::value) {
std::string reducer_name;
if (std::is_same<reducer, mshadow_op::sum>::value) {
reducer_name = "sum";
} else {
LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays";
}
TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name);
reducer_name = "sum";
TVMOpReduce(ctx, inputs[0], param.axis, param.initial, outputs[0], req[0], reducer_name);
if (normalize) {
using namespace mshadow::expr;
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
Expand Down
39 changes: 37 additions & 2 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
*/

#if MXNET_USE_TVM_OP
#include <tvm/runtime/packed_func.h>
#include "../tvmop/op_module.h"
#include "../tvmop/op_module.h"
#endif // MXNET_USE_TVM_OP

Expand Down Expand Up @@ -77,6 +79,7 @@ TBlob PrependAxes(const TBlob& src, const int dst_ndim);
void TVMOpReduce(const OpContext& ctx,
const TBlob& input,
const dmlc::optional<mxnet::Tuple<int>>& axis,
const dmlc::optional<double> initial,
const TBlob& output,
const OpReqType req,
const std::string& reducer_name) {
Expand Down Expand Up @@ -119,8 +122,39 @@ void TVMOpReduce(const OpContext& ctx,
func_name << reducer_name << "_"
<< (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU ? "cpu" : "gpu")
<< "reduce1st_dim_" << reduce1st_dim
<< "req_" << (req == kWriteTo ? "kWriteTo" : "kAddTo");
tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm});
<< "req_" << (req == kWriteTo ? "kWriteTo" : "kAddTo")
<< "initial_" << (initial.has_value() ? "True" : "False");

if (initial.has_value()) {
std::vector<int> type_codes;
std::vector<TVMValue> values;
const size_t num_args = 4; // initial scalar
type_codes.resize(num_args);
values.resize(num_args);

// input tensor setup
type_codes[0] = kTVMDLTensorHandle;
values[0].v_handle = const_cast<DLTensor*>(&(input_tvm.dltensor()));

// scalar param setup
type_codes[1] = kDLFloat;
values[1].v_float64 = initial.value();

// output tensor setup
type_codes[2] = kTVMDLTensorHandle;
values[2].v_handle = const_cast<DLTensor*>(&(output_tvm.dltensor()));

// output tensor setup
type_codes[3] = kTVMDLTensorHandle;
values[3].v_handle = const_cast<DLTensor*>(&(output_tvm.dltensor()));

tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 4);
tvm::runtime::TVMOpModule::Get()->CallEx(func_name.str(), ctx, \
{input_tvm, output_tvm, output_tvm}, tvm_args);
} else {
tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, \
{input_tvm, output_tvm, output_tvm});
}
#else
LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag to enable TVM-generated kernels.";
#endif // MXNET_USE_TVM_OP
Expand All @@ -138,6 +172,7 @@ NNVM_REGISTER_OP(_np_sum)
return std::vector<std::string>{"a"};
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_argument("init", "double", "Initial scalar input")
.add_arguments(NumpyReduceAxesParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true>)
.set_attr<FResourceRequest>("FResourceRequest",
Expand Down
Loading