Skip to content

Commit 80aa299

Browse files
committed
TST: Calculate RMS and diff image in C++
The current implementation is not slow, but uses a lot of memory per image. In `compare_images`, we have: - one actual and one expected image as uint8 (2×image) - both converted to int16 (though original is thrown away) (4×) which adds up to 4× the image allocated in this function. Then it calls `calculate_rms`, which has: - a difference between them as int16 (2×) - the difference cast to 64-bit float (8×) - the square of the difference as 64-bit float (though possibly the original difference was thrown away) (8×) which at its peak has 16× the image allocated in parallel. If the RMS is over the desired tolerance, then `save_diff_image` is called, which: - loads the actual and expected images _again_ as uint8 (2× image) - converts both to 64-bit float (throwing away the original) (16×) - calculates the difference (8×) - calculates the absolute value (8×) - multiples that by 10 (in-place, so no allocation) - clips to 0-255 (8×) - casts to uint8 (1×) which at peak uses 32× the image. So at their peak, `compare_images`→`calculate_rms` will have 20× the image allocated, and then `compare_images`→`save_diff_image` will have 36× the image allocated. This is generally not a problem, but on resource-constrained places like WASM, it can sometimes run out of memory just in `calculate_rms`. This implementation in C++ always allocates the diff image, even when not needed, but doesn't have all the temporaries, so it's a maximum of 3× the image size (plus a few scalar temporaries).
1 parent 2c1ec43 commit 80aa299

File tree

3 files changed

+110
-9
lines changed

3 files changed

+110
-9
lines changed

lib/matplotlib/testing/compare.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from PIL import Image
2020

2121
import matplotlib as mpl
22-
from matplotlib import cbook
22+
from matplotlib import cbook, _image
2323
from matplotlib.testing.exceptions import ImageComparisonFailure
2424

2525
_log = logging.getLogger(__name__)
@@ -412,7 +412,7 @@ def compare_images(expected, actual, tol, in_decorator=False):
412412
413413
The two given filenames may point to files which are convertible to
414414
PNG via the `!converter` dictionary. The underlying RMS is calculated
415-
with the `.calculate_rms` function.
415+
in a similar way to the `.calculate_rms` function.
416416
417417
Parameters
418418
----------
@@ -483,17 +483,12 @@ def compare_images(expected, actual, tol, in_decorator=False):
483483
if np.array_equal(expected_image, actual_image):
484484
return None
485485

486-
# convert to signed integers, so that the images can be subtracted without
487-
# overflow
488-
expected_image = expected_image.astype(np.int16)
489-
actual_image = actual_image.astype(np.int16)
490-
491-
rms = calculate_rms(expected_image, actual_image)
486+
rms, abs_diff = _image.calculate_rms_and_diff(expected_image, actual_image)
492487

493488
if rms <= tol:
494489
return None
495490

496-
save_diff_image(expected, actual, diff_image)
491+
Image.fromarray(abs_diff).save(diff_image, format="png")
497492

498493
results = dict(rms=rms, expected=str(expected),
499494
actual=str(actual), diff=str(diff_image), tol=tol)

lib/matplotlib/tests/test_compare_images.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from pathlib import Path
22
import shutil
33

4+
import numpy as np
45
import pytest
56
from pytest import approx
67

8+
from matplotlib._image import calculate_rms_and_diff
79
from matplotlib.testing.compare import compare_images
810
from matplotlib.testing.decorators import _image_directories
11+
from matplotlib.testing.exceptions import ImageComparisonFailure
912

1013

1114
# Tests of the image comparison algorithm.
@@ -71,3 +74,27 @@ def test_image_comparison_expect_rms(im1, im2, tol, expect_rms, tmp_path,
7174
else:
7275
assert results is not None
7376
assert results['rms'] == approx(expect_rms, abs=1e-4)
77+
78+
79+
def test_invalid_input():
80+
img = np.zeros((16, 16, 4), dtype=np.uint8)
81+
82+
with pytest.raises(ImageComparisonFailure,
83+
match='must be 3-dimensional, but is 2-dimensional'):
84+
calculate_rms_and_diff(img[:, :, 0], img)
85+
with pytest.raises(ImageComparisonFailure,
86+
match='must be 3-dimensional, but is 5-dimensional'):
87+
calculate_rms_and_diff(img, img[:, :, :, np.newaxis, np.newaxis])
88+
with pytest.raises(ImageComparisonFailure,
89+
match='must be RGB or RGBA but has depth 2'):
90+
calculate_rms_and_diff(img[:, :, :2], img)
91+
92+
with pytest.raises(ImageComparisonFailure,
93+
match=r'expected size: \(16, 16, 4\) actual size \(8, 16, 4\)'):
94+
calculate_rms_and_diff(img, img[:8, :, :])
95+
with pytest.raises(ImageComparisonFailure,
96+
match=r'expected size: \(16, 16, 4\) actual size \(16, 6, 4\)'):
97+
calculate_rms_and_diff(img, img[:, :6, :])
98+
with pytest.raises(ImageComparisonFailure,
99+
match=r'expected size: \(16, 16, 4\) actual size \(16, 16, 3\)'):
100+
calculate_rms_and_diff(img, img[:, :, :3])

src/_image_wrapper.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <pybind11/pybind11.h>
22
#include <pybind11/numpy.h>
33

4+
#include <algorithm>
5+
46
#include "_image_resample.h"
57
#include "py_converters.h"
68

@@ -202,6 +204,80 @@ image_resample(py::array input_array,
202204
}
203205

204206

207+
// This is used by matplotlib.testing.compare to calculate RMS and a difference image.
208+
static py::tuple
209+
calculate_rms_and_diff(py::array_t<unsigned char> expected_image,
210+
py::array_t<unsigned char> actual_image)
211+
{
212+
for (const auto & [image, name] : {std::pair{expected_image, "Expected"},
213+
std::pair{actual_image, "Actual"}})
214+
{
215+
if (image.ndim() != 3) {
216+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
217+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
218+
py::set_error(
219+
ImageComparisonFailure,
220+
"{name} image must be 3-dimensional, but is {ndim}-dimensional"_s.format(
221+
"name"_a=name, "ndim"_a=image.ndim()));
222+
throw py::error_already_set();
223+
}
224+
}
225+
226+
auto height = expected_image.shape(0);
227+
auto width = expected_image.shape(1);
228+
auto depth = expected_image.shape(2);
229+
230+
if (depth != 3 && depth != 4) {
231+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
232+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
233+
py::set_error(
234+
ImageComparisonFailure,
235+
"Image must be RGB or RGBA but has depth {depth}"_s.format(
236+
"depth"_a=depth));
237+
throw py::error_already_set();
238+
}
239+
240+
if (height != actual_image.shape(0) || width != actual_image.shape(1) ||
241+
depth != actual_image.shape(2)) {
242+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
243+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
244+
py::set_error(
245+
ImageComparisonFailure,
246+
"Image sizes do not match expected size: {expected_image.shape} "_s
247+
"actual size {actual_image.shape}"_s.format(
248+
"expected_image"_a=expected_image, "actual_image"_a=actual_image));
249+
throw py::error_already_set();
250+
}
251+
auto expected = expected_image.unchecked<3>();
252+
auto actual = actual_image.unchecked<3>();
253+
254+
py::ssize_t diff_dims[3] = {height, width, 3};
255+
py::array_t<unsigned char> diff_image(diff_dims);
256+
auto diff = diff_image.mutable_unchecked<3>();
257+
258+
double total = 0.0;
259+
for (auto i = 0; i < height; i++) {
260+
for (auto j = 0; j < width; j++) {
261+
for (auto k = 0; k < depth; k++) {
262+
auto pixel_diff = static_cast<double>(expected(i, j, k)) -
263+
static_cast<double>(actual(i, j, k));
264+
265+
total += pixel_diff*pixel_diff;
266+
267+
if (k != 3) { // Hard-code a fully solid alpha channel by omitting it.
268+
diff(i, j, k) = static_cast<unsigned char>(std::clamp(
269+
abs(pixel_diff) * 10, // Expand differences in luminance domain.
270+
0.0, 255.0));
271+
}
272+
}
273+
}
274+
}
275+
total = total / (width * height * depth);
276+
277+
return py::make_tuple(sqrt(total), diff_image);
278+
}
279+
280+
205281
PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
206282
{
207283
py::enum_<interpolation_e>(m, "_InterpolationType")
@@ -234,4 +310,7 @@ PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
234310
"norm"_a = false,
235311
"radius"_a = 1,
236312
image_resample__doc__);
313+
314+
m.def("calculate_rms_and_diff", &calculate_rms_and_diff,
315+
"expected_image"_a, "actual_image"_a);
237316
}

0 commit comments

Comments
 (0)