Skip to content

Commit d15429c

Browse files
icemelon authored and wweic committed
[Fix] Fix a few bugs when dtype is fp16 (apache#4088)
* Fix layer norm for fp16 * [Fix] Fix arange for fp16 * [Fix] Fix mxnet frontend for fp16 * [Fix] Fix arange for fp16 * remove comments * x * fix nnvm
1 parent 5b2cfaa commit d15429c

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,12 +615,17 @@ def _mx_arange(inputs, attrs):
615615
if attrs.get_int("repeat", 1) != 1:
616616
raise tvm.error.OpAttributeUnimplemented(
617617
'Attribute "repeat" is not supported in operator arange.')
618-
new_attrs = {}
619-
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
618+
dtype = attrs.get_str("dtype", "float32")
620619
stop = attrs.get_str("stop", "None")
621-
new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
622-
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
623-
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
620+
if stop == "None":
621+
stop = None
622+
else:
623+
stop = _expr.const(float(stop), dtype=dtype)
624+
new_attrs = {}
625+
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0), dtype=dtype)
626+
new_attrs["stop"] = stop
627+
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0), dtype=dtype)
628+
new_attrs["dtype"] = dtype
624629
return _op.arange(**new_attrs)
625630

626631

@@ -863,7 +868,8 @@ def _mx_contrib_div_sqrt_dim(inputs, _):
863868
assert len(inputs) == 1
864869
ndim = len(_infer_type(inputs[0]).checked_type.shape)
865870
dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
866-
sqrt_dim = _op.sqrt(dim.astype('float32'))
871+
dtype = _infer_type(inputs[0]).checked_type.dtype
872+
sqrt_dim = _op.sqrt(dim.astype(dtype))
867873
out = inputs[0] / sqrt_dim
868874
return out
869875

python/tvm/relay/frontend/nnvm_common.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .. import expr as _expr
2222
from .. import op as _op
2323
from .common import get_relay_op
24+
from .common import infer_type as _infer_type
2425

2526
def _warn_not_used(attr, op='nnvm'):
2627
import warnings
@@ -123,20 +124,22 @@ def _elemwise_sum(inputs, _, _dtype='float32'):
123124

124125

125126
def _binop_scalar(new_op):
126-
def _impl(inputs, attrs, odtype='float32'):
127+
def _impl(inputs, attrs, odtype=None):
127128
assert len(inputs) == 1
128129
scalar = attrs.get_float("scalar")
129-
# Note: binary scalar only works for float op for now
130+
if odtype is None:
131+
odtype = _infer_type(inputs[0]).checked_type.dtype
130132
scalar = _expr.const(scalar, dtype=odtype)
131133
return new_op(inputs[0], scalar)
132134
return _impl
133135

134136

135137
def _rbinop_scalar(new_op):
136-
def _impl(inputs, attrs, odtype='float32'):
138+
def _impl(inputs, attrs, odtype=None):
137139
assert len(inputs) == 1
138140
scalar = attrs.get_float("scalar")
139-
# Note: binary scalar only works for float op for now
141+
if odtype is None:
142+
odtype = _infer_type(inputs[0]).checked_type.dtype
140143
scalar = _expr.const(scalar, dtype=odtype)
141144
return new_op(scalar, inputs[0])
142145
return _impl

src/relay/op/tensor/transform.cc

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/expr_operator.h>
2929
#include <tvm/ir.h>
3030
#include <tvm/data_layout.h>
31+
#include <tvm/runtime/packed_func.h>
3132
#include <topi/transform.h>
3233
#include <topi/elemwise.h>
3334
#include <topi/broadcast.h>
@@ -1139,11 +1140,41 @@ and type as the input array.
11391140
TVM_REGISTER_NODE_TYPE(ArangeAttrs);
11401141

11411142
double ToScalar(const runtime::NDArray& array) {
1142-
if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
1143-
return reinterpret_cast<int32_t*>(array->data)[0];
1144-
} else {
1145-
return reinterpret_cast<float*>(array->data)[0];
1143+
if (array->dtype.code == kDLInt) {
1144+
if (array->dtype.bits == 8) {
1145+
return reinterpret_cast<int8_t*>(array->data)[0];
1146+
} else if (array->dtype.bits == 16) {
1147+
return reinterpret_cast<int16_t*>(array->data)[0];
1148+
} else if (array->dtype.bits == 32) {
1149+
return reinterpret_cast<int32_t*>(array->data)[0];
1150+
} else if (array->dtype.bits == 64) {
1151+
return reinterpret_cast<int64_t*>(array->data)[0];
1152+
}
1153+
} else if (array->dtype.code == kDLUInt) {
1154+
if (array->dtype.bits == 8) {
1155+
return reinterpret_cast<uint8_t*>(array->data)[0];
1156+
} else if (array->dtype.bits == 16) {
1157+
return reinterpret_cast<uint16_t*>(array->data)[0];
1158+
} else if (array->dtype.bits == 32) {
1159+
return reinterpret_cast<uint32_t*>(array->data)[0];
1160+
} else if (array->dtype.bits == 64) {
1161+
return reinterpret_cast<uint64_t*>(array->data)[0];
1162+
}
1163+
} else if (array->dtype.code == kDLFloat) {
1164+
#if (__ARM_FP16_FORMAT_IEEE == 1)
1165+
if (array->dtype.bits == 16) {
1166+
return reinterpret_cast<__fp16*>(array->data)[0];
1167+
}
1168+
#endif
1169+
if (array->dtype.bits == 32) {
1170+
return reinterpret_cast<float*>(array->data)[0];
1171+
} else if (array->dtype.bits == 64) {
1172+
return reinterpret_cast<double*>(array->data)[0];
1173+
}
11461174
}
1175+
LOG(FATAL) << "Unknown data type: " << tvm::runtime::TVMType2String(array->dtype);
1176+
// make compiler happy
1177+
return -std::numeric_limits<double>::infinity();
11471178
}
11481179

11491180
bool ArangeRel(const Array<Type>& types,

src/relay/pass/simplify_inference.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
7575
const auto param = attrs.as<LayerNormAttrs>();
7676
CHECK(param);
7777

78-
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
78+
Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
7979
Expr mean = Mean(data, {param->axis}, true, false);
8080
Expr var = Variance(data, mean, {param->axis}, true, false);
8181
Expr denom = Sqrt(Add(var, epsilon));

0 commit comments

Comments (0)