
Backport of #16827, #16791 and #16888 to 1.6 branch #16901

Merged · 3 commits · Nov 26, 2019
Changes from 1 commit
Mixed precision binary op backward (use in) for numpy (#16791)
* mixed precision binary op backward

* reduce unix cpu runtime
haojin2 authored and ptrendx committed Nov 25, 2019
commit 39821a43dbc2c8d4a27145f4c4107dbbfc2e4389
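As context for the diff below, a minimal sketch of the behavior this commit enables (assuming a build of this branch with the numpy interface; shapes, dtypes, and values are illustrative): gradients now flow through mixed-precision numpy binary ops such as `np.multiply`.

```python
import mxnet as mx
from mxnet import np, npx, autograd
npx.set_np()

a = np.ones((3, 2), dtype='int32')    # integer operand: receives no gradient
b = np.ones((3, 1), dtype='float32')  # float operand: differentiable
b.attach_grad()
with autograd.record():
    y = np.multiply(a, b)             # mixed int32/float32, broadcast over axis 1
y.backward()
print(b.grad)                         # sum of (dy * a) over the broadcast axis: all 2.0
```

Mixed float/float pairs (e.g. float16 with float32) are also handled, with the gradient cast back to each input's own dtype.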
17 changes: 16 additions & 1 deletion src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -137,7 +137,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
"FCompute<cpu>",
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
#endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
mshadow_op::left>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
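The template arguments `mshadow_op::right` and `mshadow_op::left` encode the product rule for `mul`: each operand's gradient is `ograd` times the other operand, sum-reduced over any broadcast axes. A NumPy reference sketch of what the registered backward computes (the helper names are illustrative, not MXNet API):

```python
import numpy as np

def mul_backward_reference(ograd, lhs, rhs):
    """Reference gradients of broadcast multiply: returns (dL/dlhs, dL/drhs)."""
    def collapse_to(grad, shape):
        # Sum-reduce grad over the axes that broadcasting expanded.
        extra = grad.ndim - len(shape)
        grad = grad.sum(axis=tuple(range(extra)))
        axes = tuple(i for i, n in enumerate(shape) if n == 1 and grad.shape[i] != 1)
        return grad.sum(axis=axes, keepdims=True).reshape(shape)
    return collapse_to(ograd * rhs, lhs.shape), collapse_to(ograd * lhs, rhs.shape)
```

This corresponds to the `lgrad`/`rgrad` lambdas added to `test_np_mixed_precision_binary_funcs` further down, which return the unreduced per-element gradients.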
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -64,6 +64,10 @@ NNVM_REGISTER_OP(_npi_multiply)
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
#endif

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::right,
mshadow_op::left>);

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);

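The GPU registration reuses the same templated kernel, so, assuming a CUDA-enabled build, the CPU example above runs unchanged on a GPU context:

```python
import mxnet as mx
from mxnet import np, npx, autograd
npx.set_np()

ctx = mx.gpu(0)  # assumes a CUDA build of this branch
a = np.ones((3, 2), dtype='int32', ctx=ctx)
b = np.ones((3, 1), dtype='float32', ctx=ctx)
b.attach_grad()
with autograd.record():
    y = np.multiply(a, b)
y.backward()
print(b.grad)  # computed by the FCompute<gpu> registration above
```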
104 changes: 102 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -25,6 +25,7 @@
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

#include <algorithm>
#include <vector>
#include <string>

@@ -391,11 +392,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
}

template<typename xpu, typename LOP, typename ROP>
-void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
+void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);

@@ -406,7 +409,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
return;
}

-  PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
const TBlob& ograd = inputs[0];
const TBlob& lgrad = outputs[0];
const TBlob& rgrad = outputs[1];

if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
// If at least one input is a float, the output type matches one of the inputs,
// so two of the three tensors share the same data type
Stream<xpu> *s = ctx.get_stream<xpu>();
mxnet::TShape new_lshape, new_rshape, new_oshape;
using namespace broadcast;
const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
&new_lshape, &new_rshape, &new_oshape) != 0;

// Prepare all the temporary memory
size_t workspace_size_l = 0, workspace_size_r = 0;
TBlob temp_tblob; // The TBlob for casted input data
TBlob temp_igrad; // The TBlob for casted grad results
size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
Tensor<xpu, 1, char> workspace;

MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
});
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
size_t cast_tensor_size = tensor_size * sizeof(OType);
// Allocate the temporary memory now
Tensor<xpu, 1, char> temp_space =
ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(workspace_size + cast_tensor_size * 2), s);
// Tensor for temp_tblob
Tensor<xpu, 1, OType> temp_tblob_tensor(
reinterpret_cast<OType*>(temp_space.dptr_),
Shape1(tensor_size), s);
// Tensor for temp_igrad
Tensor<xpu, 1, OType> temp_igrad_tensor(
reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
Shape1(tensor_size), s);
temp_tblob =
TBlob(temp_tblob_tensor)
.reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
temp_igrad =
TBlob(temp_igrad_tensor)
.reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
if (temp_igrad.Size() != 0) {
Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
}
workspace =
Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
});
// Cast the input whose type differs from ograd's into temp_tblob
CastCompute<xpu>(
attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});
if (!need_bc) {
if (lhs.type_flag_ != ograd.type_flag_) {
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
} else {
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
}
} else {
if (lhs.type_flag_ != ograd.type_flag_) {
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
workspace, new_lshape, new_rshape, new_oshape);
});
});
} else {
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
workspace, new_lshape, new_rshape, new_oshape);
});
});
}
}

// If both inputs are floating-point, cast the temporary igrad back to the
// data type of the input that differs from ograd
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
if (lhs.type_flag_ != ograd.type_flag_) {
CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
} else {
CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
}
}
} else {
// Both inputs are integer types; backward computation is not
// supported in this case.
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
}
}

} // namespace op
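The single temp-space allocation in `NumpyBinaryBackwardUseIn` above is carved into two cast buffers plus the reduction scratch area. A plain-Python sketch of the layout arithmetic (function and parameter names are illustrative):

```python
def temp_space_layout(tensor_size, otype_bytes, ws_l, ws_r):
    """Byte layout of the workspace requested via kTempSpace.

    tensor_size: element count of the operand whose dtype differs from ograd's
    otype_bytes: sizeof(OType), the output/ograd element size
    ws_l, ws_r:  reduce-workspace sizes for lgrad and rgrad
    """
    workspace_size = max(ws_l, ws_r)        # the two reductions run one at a time
    cast_bytes = tensor_size * otype_bytes
    total = workspace_size + 2 * cast_bytes
    return total, {
        'temp_tblob': (0, cast_bytes),               # casted input data
        'temp_igrad': (cast_bytes, 2 * cast_bytes),  # grad accumulated in OType
        'workspace':  (2 * cast_bytes, total),       # broadcast-reduce scratch
    }

total, layout = temp_space_layout(tensor_size=6, otype_bytes=4, ws_l=128, ws_r=96)
# total == 176; temp_tblob -> [0, 24), temp_igrad -> [24, 48), workspace -> [48, 176)
```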
26 changes: 26 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -699,6 +699,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mshadow::Tensor<xpu, 1, char>& workspace,
const mxnet::TShape& new_lshape,
const mxnet::TShape& new_rshape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob lgrad = outputs[0].reshape(new_lshape);
const TBlob rgrad = outputs[1].reshape(new_rshape);
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob lhs = inputs[1].reshape(new_lshape);
const TBlob rhs = inputs[2].reshape(new_rshape);
if (ograd.Size() != 0) {
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
ograd, lhs, rhs);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
ograd, lhs, rhs);
}
}

template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_unary_op.h
@@ -453,8 +453,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
-      if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
-          req[0] != kWriteInplace) {
+      if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
+           req[0] != kWriteInplace) && outputs[0].Size() != 0) {
Assign(out, req[0], tcast<DstDType>(data));
}
});
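The added `outputs[0].Size() != 0` condition makes `CastCompute` a no-op on empty tensors, a path exercised by the zero-size shape pairs added to the test below; a small sketch (assuming this branch):

```python
import mxnet as mx
from mxnet import np, npx, autograd
npx.set_np()

a = np.ones((3, 0), dtype='int32')    # empty operand
b = np.ones((3, 0), dtype='float32')
b.attach_grad()
with autograd.record():
    y = np.multiply(a, b)             # empty result
y.backward()                          # the casts must not touch zero-size buffers
print(b.grad.shape)                   # (3, 0)
```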
20 changes: 12 additions & 8 deletions tests/python/unittest/test_numpy_op.py
Expand Up @@ -1684,7 +1684,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
@with_seed()
@use_np
def test_np_mixed_precision_binary_funcs():
-    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
+    itypes = [np.bool, np.int8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
+    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, ltype, rtype):
class TestMixedBinary(HybridBlock):
def __init__(self, func):
super(TestMixedBinary, self).__init__()
@@ -1718,13 +1720,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

funcs = {
-        'add': (-1.0, 1.0),
-        'subtract': (-1.0, 1.0),
-        'multiply': (-1.0, 1.0),
+        'add': (-1.0, 1.0, None, None),
+        'subtract': (-1.0, 1.0, None, None),
+        'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
+                     lambda y, x1, x2: _np.broadcast_to(x1, y.shape))
}

shape_pairs = [((3, 2), (3, 2)),
((3, 2), (3, 1)),
((3, 0), (3, 0)),
((3, 1), (3, 0)),
((0, 2), (1, 2)),
((2, 3, 4), (3, 1)),
@@ -1734,16 +1738,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
itypes = [np.bool, np.int8, np.int32, np.int64]
ftypes = [np.float16, np.float32, np.float64]
for func, func_data in funcs.items():
-        low, high = func_data
+        low, high, lgrad, rgrad = func_data
for lshape, rshape in shape_pairs:
for type1, type2 in itertools.product(itypes, ftypes):
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)

for type1, type2 in itertools.product(ftypes, ftypes):
if type1 == type2:
continue
-            check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
+            check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)


@with_seed()