
An efficient point/patch-wise "distance" #918

@johnnychen94


This need arose while I was implementing a faster WNNM [1] image denoiser (still WIP, 7x speedup at the time of writing). I'm not sure how broadly this can be used, so I'm opening an issue here to get some early feedback.

Background

Since non-local means filters [2] and BM3D [3], the consensus has been that grouping similar patches by block matching and denoising at the patch level performs better than denoising at the pixel level. A typical block-matching denoising workflow is as follows:

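# For simplicity, boundary conditions are omitted in this and the following snippets.

function block_matching_denoiser(f, img; patch_size=(7, 7))
    # floating-point accumulators (`fill(0, ...)` would create Int arrays)
    out = zeros(Float64, axes(img))
    W = zeros(Float64, axes(img))
    # patch radius as a CartesianIndex, e.g. (3, 3) for 7x7 patches
    r = CartesianIndex(patch_size .÷ 2)
    for p in CartesianIndices(img)
        patch_p = @view img[p-r:p+r]
        matched_patches = block_matching(img, patch_p; num_patches=80)

        # input: m*n*N
        # output: m*n
        patch_out = f(img, patch_p, matched_patches)

        # accumulate the result and count how often each pixel was written
        view(out, matched_patches) .+= patch_out
        view(W, matched_patches) .+= 1
    end

    # weighted summation
    # typically called `patch2img` in MATLAB world
    out ./= W
    return out
end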

The major difference between the algorithms lies in f: non-local means uses a weighted mean, BM3D uses a sophisticated 1D+2D filter, and WNNM uses a low-rank approximation (svd).

There are two computational bottlenecks of this algorithm: block_matching and f. In this issue, I'll just focus on block_matching. (I have ideas on how svd can be optimized for this very specific task but that's out of the scope here.)

Here's a demo of block-matching (from BM3D website)

A block-matching subroutine, in its naive form, is:

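# `f(patch_p, patch_q)` measures how dissimilar two patches are (smaller means more similar).
# `local_neighbor` is a placeholder that returns the candidate patch centers around `p`.
# `patch_size` is assumed to be a tuple such as (7, 7).

function block_matching(f, img, p; num_patches, patch_size, patch_stride, search_size)
    r = CartesianIndex(patch_size .÷ 2)

    patch_p = @view img[p-r:p+r]
    _measure(q) = f(patch_p, view(img, q-r:q+r))
    qs = vec(collect(local_neighbor(p, patch_stride, search_size)))
    dist = _measure.(qs) # the computation bottleneck

    # only the `num_patches` smallest distances are needed, so a partial sort suffices
    matched_points = qs[partialsortperm(dist, 1:num_patches)]
    matched_patches = [q-r:q+r for q in matched_points]
    return matched_patches
end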
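
For reference, here is a minimal runnable version of the naive subroutine, written as a sketch under a few assumptions: the image is a 2D array, the distance is a sum of squared differences, and the search window is clipped so that every patch fits inside the image. The names `ssd`, `naive_block_matching`, and `search_radius` are illustrative and not part of any existing package.

```julia
# sum of squared differences between two equally sized patches
ssd(a, b) = sum(abs2(x - y) for (x, y) in zip(a, b))

function naive_block_matching(f, img::AbstractMatrix, p::CartesianIndex{2};
                              num_patches, patch_size=(7, 7), search_radius=10)
    r = CartesianIndex(patch_size .÷ 2)
    R = CartesianIndices(img)
    # valid patch centers: the whole patch must fit inside the image
    centers = first(R)+r:last(R)-r
    s = CartesianIndex(search_radius, search_radius)
    # candidate centers: the search window around p, clipped to the valid centers
    qs = vec(collect(max(first(centers), p - s):min(last(centers), p + s)))

    patch_p = view(img, p-r:p+r)
    dist = [f(patch_p, view(img, q-r:q+r)) for q in qs]

    # keep the `num_patches` candidates with the smallest distance
    k = min(num_patches, length(qs))
    return [q-r:q+r for q in qs[partialsortperm(dist, 1:k)]]
end
```

Calling `naive_block_matching(ssd, img, p; num_patches=80)` for every `p` recomputes the same pixel differences many times, which is the redundancy discussed next.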

Block matching for one pixel seems fine, but this is only a subroutine in the outer loop, so there is a lot of repeated computation involved. search_size is usually set to a relatively small number for exactly this reason (to reduce computation and memory requirements). This issue is about strategies to remove this unnecessary computation in a memory-friendly way.

Existing implementations are tuned in ways that make them hard to reuse in other code: the optimization is deeply coupled with the for loop, which is bad for both performance (e.g., multithreading becomes a big challenge) and code reuse. This issue also tries to provide an easy-to-use interface with all optimizations transparent to users.

Proposal

Naively, representing the complete pointwise distance result of arrays A and B needs an array with axes (axes(A)..., axes(B)...). This is an extremely large array even for two moderate-size matrices, so we cannot create it directly.
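
To get a feel for the size, here is a quick back-of-the-envelope calculation. The 512×512 `Float32` image is an illustrative assumption, not a number from the benchmark above.

```julia
A = zeros(Float32, 512, 512)

# one entry for every pair of pixels (p, q)
n_entries = length(A)^2
n_bytes = n_entries * sizeof(eltype(A))
n_gib = n_bytes / 2^30   # 256.0 GiB for the dense 4D array
```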

What I have in mind is to create a new array type with the following three properties:

  • lazy evaluation: only store the needed information during initialization and defer all computation until a value is requested
  • cached: store a subset of computation results
  • (optional) symmetric: for some common operations like Euclidean(), the result is symmetric (f(x,y)=f(y,x)) and thus the cache can be used more efficiently.
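
The lazy part of this idea can be sketched as a small `AbstractArray` subtype. This is only an illustration under stated assumptions: `LazyPointwise` is a hypothetical name, the inputs are assumed to be non-empty and 1-based, and no caching or symmetry handling is included yet.

```julia
# Hypothetical sketch: a lazy array of axes (axes(A)..., axes(B)...)
# that stores only `f`, `A`, and `B` and computes entries on demand.
struct LazyPointwise{T,N,F,AT<:AbstractArray,BT<:AbstractArray} <: AbstractArray{T,N}
    f::F
    A::AT
    B::BT
end

function LazyPointwise(f, A::AbstractArray, B::AbstractArray=A)
    # determine the element type by evaluating `f` once (assumes non-empty inputs)
    T = typeof(f(first(A), first(B)))
    N = ndims(A) + ndims(B)
    return LazyPointwise{T,N,typeof(f),typeof(A),typeof(B)}(f, A, B)
end

Base.size(D::LazyPointwise) = (size(D.A)..., size(D.B)...)

function Base.getindex(D::LazyPointwise{T,N}, I::Vararg{Int,N}) where {T,N}
    n = ndims(D.A)
    p = I[1:n]      # index into A
    q = I[n+1:N]    # index into B
    return D.f(D.A[p...], D.B[q...])
end
```

With this, `D = LazyPointwise((x, y) -> abs2(x - y), A)` behaves like the dense 4D array for indexing and `view`, while its memory footprint stays that of `A` itself. Every access recomputes `f`, which is what the caching strategies below are meant to avoid.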

With this in mind, block-matching denoising can be computed quite easily and efficiently:

patch_size = (7, 7)
r_d = CartesianIndex((patch_size..., patch_size...) .÷ 2)

# here we share both abs2 and ssd results globally
point_distances = pointwise((x,y)->abs2(x-y), A)
patch_distances = pointwise(A) do p, q
    center = CartesianIndex(p.I..., q.I...)
    sum(view(point_distances, center-r_d:center+r_d))
end

# a block matching thus becomes a trivial partial sort on patch_distances

Here pointwise returns the array type that I want to add. Because this type mimics the naive 4d array concept, it is quite intuitive to use.

The design of this array type completely depends on how results are cached. There are two caching strategies that sound promising to me.

static window cache

This is exactly the array abstraction of the existing block matching code in all of those algorithms; they only try to find similar patches in a non-local neighborhood (not in the entire image). Each pixel has a cache block of size patch_stride[k]*prod(window_size)/window_size[k] that stores the results for its non-local neighborhood.

More details are needed on how the indices are computed (it takes some mental effort to get this right), but this is the idea.
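
As a rough illustration of the indexing, here is a simplified 2D sketch. It is not the layout described above: it caches one full (2w+1)×(2w+1) window of point distances per pixel and ignores `patch_stride`. `WindowedDistances`, `distance`, and the window radius `w` are hypothetical names, and no thread safety is attempted.

```julia
# Hypothetical sketch: a per-pixel window cache for pointwise distances.
struct WindowedDistances{T,F,AT<:AbstractMatrix}
    f::F
    A::AT
    w::Int                 # window radius
    cache::Array{T,4}      # size (size(A)..., 2w+1, 2w+1)
    filled::BitArray{4}    # marks which cache slots hold a computed value
end

function WindowedDistances(f, A::AbstractMatrix, w::Int)
    T = typeof(f(first(A), first(A)))
    dims = (size(A)..., 2w + 1, 2w + 1)
    return WindowedDistances{T,typeof(f),typeof(A)}(f, A, w, Array{T}(undef, dims), falses(dims))
end

function distance(D::WindowedDistances, p::CartesianIndex{2}, q::CartesianIndex{2})
    o = q - p
    # outside the window: compute directly without caching
    all(abs.(Tuple(o)) .<= D.w) || return D.f(D.A[p], D.A[q])
    # the slot is addressed by the pixel p and the offset of q relative to p
    slot = CartesianIndex(Tuple(p)..., (Tuple(o) .+ (D.w + 1))...)
    if !D.filled[slot]
        D.cache[slot] = D.f(D.A[p], D.A[q])
        D.filled[slot] = true
    end
    return D.cache[slot]
end
```

Because the slot index is derived from `p` and `q - p`, no index needs to be stored alongside the value, which is where the capacity advantage listed below comes from.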

Benefit:

  • The cache is statically coded into the structure, so we could get quite stable performance even with multithreading enabled.
  • The cache is used very efficiently in terms of capacity: the index can be computed at runtime.

Drawback:

  • The cache size is proportional to window_size, which is still a limitation. (Setting window_size=size(img) still gives a very large cache.)

FIFO queueing cache

This simple caching strategy can be quite useful because, in many cases, we iterate over the image sequentially. Whatever is removed from the cache is very unlikely to be needed again in future iterations.

Benefit:

  • Because it is dynamic, we don't need to specify window_size, and the cache size can be arbitrary.

Drawback:

  • The effective cache capacity is only 1/(ndims(A)+1) of the allocated storage because we also need to store the index of the pixel.
  • A race condition might occur and cause cache misses in a multithreaded setting unless we allocate one such array for each thread.
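
The eviction policy itself can be sketched in a few lines. This is a minimal illustration, not the flat value-plus-index layout described above: it uses a `Dict` for lookup plus a ring buffer of keys, which is simpler but carries more overhead. `FIFOCache` and `cached!` are hypothetical names, and the structure is not thread-safe, so one instance per thread is assumed.

```julia
# Hypothetical sketch: a fixed-capacity cache with first-in-first-out eviction.
mutable struct FIFOCache{K,V}
    capacity::Int
    store::Dict{K,V}
    order::Vector{K}   # ring buffer of keys in insertion order
    next::Int          # position in `order` holding the oldest key once full
end

FIFOCache{K,V}(capacity::Int) where {K,V} =
    FIFOCache{K,V}(capacity, Dict{K,V}(), Vector{K}(), 1)

# Return the cached value for `key`, calling `compute()` only on a miss.
function cached!(compute, c::FIFOCache{K,V}, key::K) where {K,V}
    haskey(c.store, key) && return c.store[key]
    val = compute()
    if length(c.order) < c.capacity
        push!(c.order, key)
    else
        # evict the oldest entry and reuse its slot in the ring buffer
        delete!(c.store, c.order[c.next])
        c.order[c.next] = key
        c.next = mod1(c.next + 1, c.capacity)
    end
    c.store[key] = val
    return val
end
```

A lookup would then read `cached!(() -> f(A[p], A[q]), cache, (p, q))`, with the key type chosen to match, for example `Tuple{CartesianIndex{2},CartesianIndex{2}}`.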

others

LRU and other sophisticated cache strategies might not apply to this task because of the additional overhead; each computation might take just hundreds of ns or several μs. I don't have a good estimate yet.

Plans

I plan to put this in a new package. I plan to implement the "static window cache" first as a faithful reimplementation, just to see how fast I can make WNNM.

References

[1] Gu, S., Zhang, L., Zuo, W., & Feng, X. (2014). Weighted nuclear norm minimization with application to image denoising. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 2862-2869).

[2] Buades, A., Coll, B., & Morel, J. M. (2005, June). A non-local algorithm for image denoising. In 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05) (Vol. 2, pp. 60-65). IEEE.

[3] Dabov, K., Foi, A., Katkovnik, V., & Egiazarian, K. (2007). Image denoising by sparse 3-D transform-domain collaborative filtering. IEEE Transactions on image processing, 16(8), 2080-2095.
