Different behavior of bilinear interpolation between resize_images and ONNX Upsample #147

Open
syoyo opened this issue Apr 17, 2019 · 1 comment

@syoyo
Contributor

syoyo commented Apr 17, 2019

I have implemented export of resize_images to ONNX by lowering it to the ONNX Upsample op, as was done for unpooling_2D.

https://github.com/syoyo/onnx-chainer/tree/resize_images

I wanted to submit a PR, but found that a unit test fails because bilinear interpolation behaves differently in Chainer and onnxruntime (and also in TensorFlow/OpenCV, which serve as a reference).

I use resize_images for pyramid pooling, so I need to embed the resize op in the model file.
I've tracked down the problem, and here is a summary:

Input and output

Input shape: (2x2), values [[64, 32], [64, 32]]
Output shape: (4x4) (scale factor 2x)
Result: [64, ***, ***, 32] (only the first row is checked; the two middle values depend on the implementation)
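To make the three competing results concrete, here is a minimal pure-Python sketch that resizes the first row [64, 32] to 4 samples under three source-coordinate mappings. The mode names (align_corners, asymmetric, half_pixel) are borrowed from the coordinate_transformation_mode attribute of the later ONNX Resize op purely as labels. The sketch assumes that Chainer v5.4 behaves like align_corners, onnxruntime's Upsample like asymmetric, and cv2.resize like half_pixel; this matches the numbers reported in this issue but is not taken from their source code. The helpers source_coord and resize_row are hypothetical.

```python
import math


def source_coord(dst, in_size, out_size, mode):
    """Map an output index to a (fractional) input coordinate."""
    if mode == "align_corners":
        # First and last samples of input and output coincide.
        if out_size == 1:
            return 0.0
        return dst * (in_size - 1) / (out_size - 1)
    if mode == "asymmetric":
        # Plain scaling of the index, no pixel-center correction.
        return dst * in_size / out_size
    if mode == "half_pixel":
        # Scale pixel centers rather than pixel indices.
        return (dst + 0.5) * in_size / out_size - 0.5
    raise ValueError("unknown mode: " + mode)


def resize_row(row, out_size, mode):
    """Linearly resample a 1-D row, clamping coordinates to the input range."""
    in_size = len(row)
    out = []
    for dst in range(out_size):
        x = source_coord(dst, in_size, out_size, mode)
        x = min(max(x, 0.0), in_size - 1)  # clamp to [0, in_size - 1]
        x0 = int(math.floor(x))
        x1 = min(x0 + 1, in_size - 1)
        w = x - x0  # weight of the right-hand neighbour
        out.append(row[x0] * (1.0 - w) + row[x1] * w)
    return out


if __name__ == "__main__":
    for mode in ("align_corners", "asymmetric", "half_pixel"):
        print(mode, resize_row([64.0, 32.0], 4, mode))
```

Up to floating-point rounding, the three modes give [64, 53.33, 42.67, 32], [64, 48, 32, 32] and [64, 56, 40, 32], i.e. the Chainer, onnxruntime and cv2.resize rows quoted in this issue. The only difference between them is how an output index is mapped back to an input coordinate; the interpolation itself is identical.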

========================================================= FAILURES ==========================================================
_______________________________________________ TestResizeImages.test_output ________________________________________________

self = <tests.functions_tests.test_arrays.TestResizeImages testMethod=test_output>

    def test_output(self):
    
        # FIXME(syoyo): Currently the test will fail due to different behavior
        # of bilinear interpolation between Chainer and onnxruntime.
        #
        # Currently Chainer will give [64, 53.333336, 42.666668, 32]
        # (same result with tensorflow r1.13.1 with `align_corners=True`),
        # while onnxruntime gives [64, 48, 32, 32]
        # (same result with tensorflow r1.13.1 with `align_corners=False`)
        #
        # However, the expected behavior is [64, 56, 40, 32].
        # (cv2.resize and tensorflow master(r1.14 or r2.0) after this fix:
        #  https://github.com/tensorflow/tensorflow/issues/6720)
    
        # TODO(hamaji): onnxruntime does not support Upsample-9 yet.
        # https://github.com/chainer/onnx-chainer/issues/111
>       self.expect(self.model, self.x, name='resize_images', skip_opset_version=[9])

tests/functions_tests/test_arrays.py:240: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/helper.py:101: in expect
    self.check_out_values(test_path, input_names=graph_input_names)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

test_path = 'out/opset7/test_resize_images', input_names = ['Input_0']

    def check_model_expect(test_path, input_names=None):
        if not ONNXRUNTIME_AVAILABLE:
            raise ImportError('ONNX Runtime is not found on checking module.')
    
        model_path = os.path.join(test_path, 'model.onnx')
        with open(model_path, 'rb') as f:
            onnx_model = onnx.load_model(f)
        sess = rt.InferenceSession(onnx_model.SerializeToString())
        rt_input_names = [value.name for value in sess.get_inputs()]
        rt_output_names = [value.name for value in sess.get_outputs()]
    
        # To detect unexpected inputs created by exporter, check input names
        if input_names is not None:
            assert list(sorted(input_names)) == list(sorted(rt_input_names))
    
        test_data_sets = sorted([
            p for p in os.listdir(test_path) if p.startswith('test_data_set_')])
        for test_data in test_data_sets:
            test_data_path = os.path.join(test_path, test_data)
            assert os.path.isdir(test_data_path)
            inputs, outputs = load_test_data(
                test_data_path, rt_input_names, rt_output_names)
    
            rt_out = sess.run(list(outputs.keys()), inputs)
            for cy, my in zip(outputs.values(), rt_out):
>               np.testing.assert_allclose(cy, my, rtol=1e-5, atol=1e-5)
E               AssertionError: 
E               Not equal to tolerance rtol=1e-05, atol=1e-05
E               
E               Mismatch: 50%
E               Max absolute difference: 10.666668
E               Max relative difference: 0.33333337
E                x: array([[[[64.      , 53.333336, 42.666668, 32.      ],
E                        [64.      , 53.333336, 42.666668, 32.      ],
E                        [64.      , 53.333332, 42.666668, 32.      ],
E                        [64.      , 53.333336, 42.666668, 32.      ]]]], dtype=float32)
E                y: array([[[[64., 48., 32., 32.],
E                        [64., 48., 32., 32.],
E                        [64., 48., 32., 32.],
E                        [64., 48., 32., 32.]]]], dtype=float32)

onnx_chainer/testing/test_onnxruntime.py:62: AssertionError
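The figures in this failure report (50% mismatch, max absolute difference 10.67, max relative difference 0.33) follow directly from the first rows of x and y. Below is a rough pure-Python re-computation, assuming numpy's documented criterion |actual - desired| <= atol + rtol * |desired|, with actual = Chainer output (x) and desired = onnxruntime output (y), and assuming the relative difference is taken against |desired|. The function allclose_report is hypothetical and is not part of numpy.

```python
def allclose_report(actual, desired, rtol=1e-5, atol=1e-5):
    """Summarize mismatches roughly the way assert_allclose's message does."""
    abs_diff = [abs(a - d) for a, d in zip(actual, desired)]
    # An element passes if |a - d| <= atol + rtol * |d| (desired is the reference).
    bad = [diff > atol + rtol * abs(d) for diff, d in zip(abs_diff, desired)]
    rel_diff = [diff / abs(d) for diff, d in zip(abs_diff, desired) if d != 0]
    return {
        "mismatch": sum(bad) / len(bad),
        "max_abs": max(abs_diff),
        "max_rel": max(rel_diff),
    }


# First row of x (Chainer) and y (onnxruntime) from the failure above;
# the other rows differ only by float32 rounding.
chainer_row = [64.0, 53.333336, 42.666668, 32.0]
onnxrt_row = [64.0, 48.0, 32.0, 32.0]
print(allclose_report(chainer_row, onnxrt_row))
```

Only the two middle columns fail, hence the 50% mismatch; the worst element is 42.666668 vs 32, which gives both the absolute difference of about 10.67 and the relative difference of about 10.67 / 32 = 0.33.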

Chainer (v5.4)

It gives [64, 53.333336, 42.666668, 32].
(Reproducible by running tests/functions_tests/test_arrays.py.)

onnxruntime (v0.3.0)

It gives [64, 48, 32, 32].
(Reproducible by running tests/functions_tests/test_arrays.py.)

TensorFlow

import tensorflow as tf
import numpy as np

print(tf.__version__)

# NCHW
img = np.array([[[[64, 32],
                  [64, 32]]]], np.float32)
print(img.shape)

# to NHWC
img = tf.transpose(img, [0, 2, 3, 1])
print(img.shape)

align_corners = False # or True
img = tf.image.resize_bilinear(img, size=[4, 4], align_corners=align_corners)

sess = tf.Session()
ret = sess.run(img)
print(ret)

cv2.resize (should be the ground truth)

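import cv2
import numpy as np

# Same values as above, but transposed: here the values change down each
# column, so the resized values also change down each column.
img = np.array([[64.0, 64.0],
                [32.0, 32.0]], np.float32)

# dsize is (width, height); the default interpolation is cv2.INTER_LINEAR.
size = (4, 4)

resized_img = cv2.resize(img, size)

print(resized_img)

# => each column is (64, 56, 40, 32). FYI the interpolation weights are [0, 0.25, 0.75, 1.0]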
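The weights quoted in the comment above can be derived by hand, assuming cv2.resize with INTER_LINEAR maps each output index to the input pixel-center coordinate and clamps it at the border (the half-pixel convention). With input size 2 and output size 4, and because the input has only two samples, the clamped source coordinate is itself the weight w of the second sample:

```latex
\begin{aligned}
x_{\mathrm{src}} &= \left(x_{\mathrm{dst}} + \tfrac{1}{2}\right)\cdot\tfrac{2}{4} - \tfrac{1}{2},
  \qquad w = \min\bigl(\max(x_{\mathrm{src}}, 0),\, 1\bigr) \\
x_{\mathrm{dst}} = 0, 1, 2, 3 \;&\Rightarrow\; x_{\mathrm{src}} = -0.25,\ 0.25,\ 0.75,\ 1.25
  \;\Rightarrow\; w = 0,\ 0.25,\ 0.75,\ 1 \\
v &= (1 - w)\cdot 64 + w \cdot 32 = 64,\ 56,\ 40,\ 32
\end{aligned}
```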

Chainer's current behavior matches TF r1.13.1 (with align_corners=True), but TF's resize had a known bug:

tensorflow/tensorflow#6720

which was recently fixed in TF's resize_images by this commit: tensorflow/tensorflow@371c96d

What should I do?

The correct behavior is the one in cv2.resize and TF r2.0 ([64, 56, 40, 32]), so we should eventually move in that direction (i.e., fix resize_images in Chainer or implement a new one). In the meantime, there are some options:

  • Disable the unit test for resize_images and send the PR
  • Write a custom ONNX op, or compare against manually supplied expected values
  • Wait for ONNX opset 10 support in onnx-chainer (the Resize op, since Upsample is deprecated in opset 10)

@disktnk
Member

disktnk commented Apr 18, 2019

Thank you for the detailed report. I looked over your commits master...syoyo:resize_images roughly, and they look fine.

wanted to submit PR, but found a unit test fails due to different behavior of bilinear interpolation in Chainer and onnxruntime

Please skip the output value check between Chainer and ONNX Runtime, like this:

def test_output(self):

    self.check_out_values = None  # skip output value check
    self.expect(self.model, self.x)

The test tool does not have a "skip all output value checks" option, so setting self.check_out_values = None skips it (a little hacky...), and only the exported ONNX graph is checked.

weiji14 added a commit to weiji14/deepbedmap that referenced this issue Jul 25, 2019
Doing away with the PixelShuffle (depth2space) upsampling layer, which was quite prone to checkerboard artifacts, and replacing it with NearestNeighbour interpolation instead. Note that previously the Convolutional2D layer was before (pre) the Upsampling; now it is placed after (post), and we've reduced the channels from 256 to 64. This Chainer implementation of ours now follows the ESRGAN PyTorch implementation in https://github.com/xinntao/BasicSR/blob/85d7b14107a2705683b2568f77fe1c684e29f530/codes/models/modules/RRDBNet_arch.py very closely! Experiment training details are logged at https://www.comet.ml/weiji14/deepbedmap/a3b87a2f383243c7a66d18d146aafe61 with an RMSE_test of 73.24; not great, but qualitatively there's less of a checkerboard artifact.

Note that onnx-chainer 1.4.1 raises a warning when exporting the resize_images function that we are suppressing. Full UserWarning is as follows:" `resize_images` is mapped to `Upsampling` ONNX op with bilinear interpolation. Behavior of bilinear interpolation differs from each implementation. See the issue chainer/onnx-chainer#147 for details".
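The commit above mentions suppressing this warning. One standard-library way to do that is a warnings filter scoped to the export call. This is a minimal sketch: the context-manager name is hypothetical, and the regex simply matches the message text quoted above rather than relying on any documented onnx-chainer API.

```python
import contextlib
import warnings


@contextlib.contextmanager
def ignore_resize_images_warning():
    """Hypothetical helper: silence only the resize_images/Upsample warning."""
    with warnings.catch_warnings():
        # `message` is a regex matched against the start of the warning text;
        # the leading ".*" tolerates any prefix before the function name.
        warnings.filterwarnings(
            "ignore",
            message=r".*resize_images.*is mapped to",
            category=UserWarning,
        )
        yield


# Usage sketch: in practice you would wrap the export call,
# e.g. onnx_chainer.export(model, x); here a warning with the same
# text is emitted directly to show that it is filtered out.
with ignore_resize_images_warning():
    warnings.warn(
        "`resize_images` is mapped to `Upsampling` ONNX op with bilinear "
        "interpolation.",
        UserWarning,
    )
```

Scoping the filter with catch_warnings keeps other UserWarnings visible and restores the previous filter state when the block exits, which is safer than a global warnings.filterwarnings call.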