Skip to content

Commit

Permalink
relax rtol/atol checks on some onnx tests (#2403)
Browse files Browse the repository at this point in the history
relax the error constraints on these tests due to likely
FP compuation accuracy issues.
  • Loading branch information
yangchen-MS authored and tqchen committed Jan 9, 2019
1 parent e36265b commit f607d46
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions nnvm/tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_slice():
_test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))

def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, rtol=1e-7, atol=1e-7):
indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
outdata = outfunc(indata, **npargs)

Expand All @@ -290,7 +290,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)

tvm.testing.assert_allclose(outdata, tvm_out)
tvm.testing.assert_allclose(outdata, tvm_out, rtol=rtol, atol=atol)

def test_floor():
_test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {})
Expand Down Expand Up @@ -863,7 +863,7 @@ def test_binary_ops():
dtype = "float32"
out_shape = in_shape

def verify_binary_ops(op, x, y, out_np, broadcast=None):
def verify_binary_ops(op, x, y, out_np, broadcast=None, rtol=1e-7, atol=1e-7):
if broadcast is None:
z = helper.make_node(op, ['in1', 'in2'], ['out'])
else:
Expand All @@ -879,7 +879,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None):
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)

x = np.random.uniform(size=in_shape).astype(dtype)
y = np.random.uniform(size=in_shape).astype(dtype)
Expand All @@ -890,16 +890,16 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None):
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul",x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None, rtol=1e-5, atol=1e-5)
verify_binary_ops("Div", x, z, x / z, broadcast=True, rtol=1e-5, atol=1e-5)
verify_binary_ops("Sum", x, y, x + y, broadcast=None)

def test_single_ops():
in_shape = (1, 2, 3, 3)
dtype = "float32"
out_shape = in_shape

def verify_single_ops(op, x, out_np):
def verify_single_ops(op, x, out_np, rtol=1e-7, atol=1e-7):
z = helper.make_node(op, ['in1'], ['out'])
graph = helper.make_graph([z],
'_test',
Expand All @@ -915,8 +915,8 @@ def verify_single_ops(op, x, out_np):
x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x)
verify_single_ops("Abs",x, np.abs(x))
verify_single_ops("Reciprocal",x, 1/x)
verify_single_ops("Sqrt",x, np.sqrt(x))
verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5)
verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x))
verify_single_ops("Log",x, np.log(x))
Expand Down Expand Up @@ -1004,7 +1004,9 @@ def test_LogSoftmax():
{},
'float32',
'LogSoftmax',
{'axis': 1})
{'axis': 1},
rtol=1e-5,
atol=1e-5)

if __name__ == '__main__':
# verify_super_resolution_example()
Expand Down

0 comments on commit f607d46

Please sign in to comment.