Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
sanity fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly authored and Ubuntu committed Feb 19, 2020
1 parent 256ed79 commit f043142
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
#if !MXNET_USE_TVM_OP
if (param.initial.has_value()) {
LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag to enable TVM-generated kernels \
to support initial value";
LOG(FATAL) << "Please add USE_TVM_OP = 1 as a compile flag to enable TVM-generated kernels "
"to support initial value";
}
#endif // MXNET_USE_TVM_OP
Stream<xpu>* s = ctx.get_stream<xpu>();
Expand Down
10 changes: 6 additions & 4 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
*/

#if MXNET_USE_TVM_OP
#include "../tvmop/op_module.h"
#include <tvm/runtime/packed_func.h>
#include "../tvmop/op_module.h"
#include "../tvmop/op_module.h"
#endif // MXNET_USE_TVM_OP

#include "np_broadcast_reduce_op.h"
Expand Down Expand Up @@ -128,7 +128,7 @@ void TVMOpReduce(const OpContext& ctx,
if (initial.has_value()) {
std::vector<int> type_codes;
std::vector<TVMValue> values;
const size_t num_args = 4; // initial scalar
const size_t num_args = 4; // initial scalar
type_codes.resize(num_args);
values.resize(num_args);

Expand All @@ -149,9 +149,11 @@ void TVMOpReduce(const OpContext& ctx,
values[3].v_handle = const_cast<DLTensor*>(&(output_tvm.dltensor()));

tvm::runtime::TVMArgs tvm_args(&values[0], &type_codes[0], 4);
tvm::runtime::TVMOpModule::Get()->CallEx(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm}, tvm_args);
tvm::runtime::TVMOpModule::Get()->CallEx(func_name.str(), ctx, \
{input_tvm, output_tvm, output_tvm}, tvm_args);
} else {
tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm});
tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, \
{input_tvm, output_tvm, output_tvm});
}
#else
LOG(FATAL) << "Please add USE_TVM_OP=1 as a compile flag to enable TVM-generated kernels.";
Expand Down

0 comments on commit f043142

Please sign in to comment.