Skip to content

Commit

Permalink
feat(mvgd): add analytical solution to MVGD transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Aug 31, 2020
1 parent 37556ac commit 12e3912
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 88 deletions.
2 changes: 1 addition & 1 deletion color_matcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

__version__ = '0.2.6'
__version__ = '0.3.0'

from .top_level import ColorMatcher
from .hist_matcher import HistogramMatcher
Expand Down
85 changes: 52 additions & 33 deletions color_matcher/mvgd_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,26 @@ class TransferMVGD(MatcherBaseclass):
def __init__(self, *args, **kwargs):
super(TransferMVGD, self).__init__(*args, **kwargs)

self._fun = kwargs['fun'] if 'fun' in kwargs else self.mkl_solver
self._fun_dict = {'analytical': self.analytical_solver, 'mkl': self.mkl_solver}
self._fun_name = kwargs['fun'] if 'fun' in kwargs else 'mkl' # use MKL as default
self._fun_call = self._fun_dict[self._fun_name] if self._fun_name in self._fun_dict else self.analytical_solver

# initialize variables
self.r = np.reshape(self._src, [-1, self._src.shape[2]])
self.z = np.reshape(self._ref, [-1, self._ref.shape[2]])
self.cov_r, self.cov_z = np.cov(self.r.T), np.cov(self.z.T)
self.mu_r, self.mu_z = np.mean(self.r, axis=0), np.mean(self.z, axis=0)

def _init_vars(self):

self.r = np.reshape(self._src, [-1, self._src.shape[2]])
self.z = np.reshape(self._ref, [-1, self._ref.shape[2]])

self.cov_r = np.cov(self.r.T)
self.cov_z = np.cov(self.z.T)

self.mu_r = np.mean(self.r, axis=0)
self.mu_z = np.mean(self.z, axis=0)

def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: FunctionType = None) -> np.ndarray:
"""
Expand All @@ -41,13 +60,13 @@ def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: Function
:param src: Source image that requires transfer
:param ref: Palette image which serves as reference
:param fun: optional argument to pass a transfer function to solve for covariance matrices
:param t_r: Resulting image after the mapping
:param res: Resulting image after the mapping
:type src: :class:`~numpy:numpy.ndarray`
:type ref: :class:`~numpy:numpy.ndarray`
:type t_r: :class:`~numpy:numpy.ndarray`
:type res: :class:`~numpy:numpy.ndarray`
:return: **t_r**
:return: **res**
:rtype: np.ndarray
"""
Expand All @@ -59,54 +78,54 @@ def transfer(self, src: np.ndarray = None, ref: np.ndarray = None, fun: Function
# check if three color channels are provided
self.validate_color_chs()

# set solver function for transfer matrix (default is MKL)
self._fun = fun if fun is not None else self._fun
# re-initialize variables to account for change in src and ref when passed to self.transfer()
self._init_vars()

r = np.reshape(src, [-1, src.shape[2]])
z = np.reshape(ref, [-1, ref.shape[2]])
# set solver function for transfer matrix
self._fun_call = fun if fun is FunctionType else self._fun_call

cov_r = np.cov(r.T)
cov_z = np.cov(z.T)
# compute transfer matrix
transfer_mat = self._fun_call()

transfer_mat = self._fun(cov_r, cov_z)
# transfer the intensity distributions
res = np.dot((self.r - self.mu_r), transfer_mat) + self.mu_z
res = np.reshape(res, self._src.shape)

mu_r = np.mean(r, axis=0)
mu_z = np.mean(z, axis=0)
return res

t_r = np.dot((r - mu_r), transfer_mat) + mu_z
t_r = np.reshape(t_r, src.shape)

return t_r

@staticmethod
def mkl_solver(cov_r: np.ndarray, cov_z: np.ndarray):
def mkl_solver(self):
"""
This function computes the transfer matrix based on the Monge-Kantorovich linearization.
:param cov_r: Covariance matrix of source image
:param cov_z: Covariance matrix of reference image
:param transfer_mat: Transfer matrix
:type cov_r: :class:`~numpy:numpy.ndarray`
:type cov_z: :class:`~numpy:numpy.ndarray`
:return: **transfer_mat**: Transfer matrix
:type transfer_mat: :class:`~numpy:numpy.ndarray`
:return: **transfer_mat**
:rtype: np.ndarray
"""

[Da2, Ua] = np.linalg.eig(cov_r)
[Da2, Ua] = np.linalg.eig(self.cov_r)
Ua = np.array([Ua[:, 2] * -1, Ua[:, 1], Ua[:, 0] * -1]).T
Da2[Da2 < 0] = 0
Da = np.diag(np.sqrt(Da2[::-1]))
C = np.dot(Da, np.dot(Ua.T, np.dot(cov_z, np.dot(Ua, Da))))
C = np.dot(Da, np.dot(Ua.T, np.dot(self.cov_z, np.dot(Ua, Da))))
[Dc2, Uc] = np.linalg.eig(C)
Dc2[Dc2 < 0] = 0
Dc = np.diag(np.sqrt(Dc2))
Da_inv = np.diag(1. / (np.diag(Da + np.spacing(1))))

transfer_mat = np.dot(Ua, np.dot(Da_inv, np.dot(Uc, np.dot(Dc, np.dot(Uc.T, np.dot(Da_inv, Ua.T))))))
return np.dot(Ua, np.dot(Da_inv, np.dot(Uc, np.dot(Dc, np.dot(Uc.T, np.dot(Da_inv, Ua.T))))))

def analytical_solver(self) -> np.ndarray:
"""
An analytical solution to the linear equation system of MVGDs.
:return: **transfer_mat**: Transfer matrix
:type transfer_mat: :class:`~numpy:numpy.ndarray`
:rtype: np.ndarray
"""

cov_r_inv = np.linalg.inv(self.cov_r)
cov_z_inv = np.linalg.inv(self.cov_z)

return transfer_mat
return np.dot(np.dot(np.linalg.pinv(np.dot(self.z-self.mu_z, cov_z_inv)), self.r-self.mu_r), cov_r_inv).T
2 changes: 1 addition & 1 deletion docs/build/html/.buildinfo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 8ec35f5b9042021743edef2b406e4e61
config: 56bf47be1cda60fc3fb79ed2a8f61ee6
tags: 645f666f9bcd5a90fca523b33c5a78b7
2 changes: 1 addition & 1 deletion docs/build/html/_static/documentation_options.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
VERSION: '0.2.6',
VERSION: '0.3.0',
LANGUAGE: 'None',
COLLAPSE_INDEX: false,
BUILDER: 'html',
Expand Down
41 changes: 24 additions & 17 deletions docs/build/html/apidoc.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>API documentation &#8212; color-matcher 0.2.6 documentation</title>
<title>API documentation &#8212; color-matcher 0.3.0 documentation</title>
<link rel="stylesheet" href="_static/sphinxdoc.css" type="text/css" />
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<script id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
Expand All @@ -29,7 +29,7 @@ <h3>Navigation</h3>
<li class="right" >
<a href="readme.html" title="color-matcher"
accesskey="P">previous</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.2.6 documentation</a> &#187;</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.3.0 documentation</a> &#187;</li>
<li class="nav-item nav-item-this"><a href="">API documentation</a></li>
</ul>
</div>
Expand Down Expand Up @@ -121,23 +121,30 @@ <h2>Class hierarchy<a class="headerlink" href="#class-hierarchy" title="Permalin
<dd><p>Initialize self. See help(type(self)) for accurate signature.</p>
</dd></dl>

<dl class="py method">
<dt id="color_matcher.TransferMVGD.analytical_solver">
<code class="sig-name descname">analytical_solver</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)">numpy.ndarray</a><a class="headerlink" href="#color_matcher.TransferMVGD.analytical_solver" title="Permalink to this definition"></a></dt>
<dd><p>An analytical solution to the linear equation system of MVGDs.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p><strong>transfer_mat</strong>: Transfer matrix</p>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p>np.ndarray</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt id="color_matcher.TransferMVGD.mkl_solver">
<em class="property">static </em><code class="sig-name descname">mkl_solver</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">cov_r</span><span class="p">:</span> <span class="n"><a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)">numpy.ndarray</a></span></em>, <em class="sig-param"><span class="n">cov_z</span><span class="p">:</span> <span class="n"><a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)">numpy.ndarray</a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#color_matcher.TransferMVGD.mkl_solver" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">mkl_solver</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#color_matcher.TransferMVGD.mkl_solver" title="Permalink to this definition"></a></dt>
<dd><p>This function computes the transfer matrix based on the Monge-Kantorovich linearization.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>cov_r</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Covariance matrix of source image</p></li>
<li><p><strong>cov_z</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Covariance matrix of reference image</p></li>
<li><p><strong>transfer_mat</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Transfer matrix</p></li>
</ul>
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p><strong>transfer_mat</strong>: Transfer matrix</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p><strong>transfer_mat</strong></p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>np.ndarray</p>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p>np.ndarray</p>
</dd>
</dl>
</dd></dl>
Expand All @@ -152,11 +159,11 @@ <h2>Class hierarchy<a class="headerlink" href="#class-hierarchy" title="Permalin
<li><p><strong>src</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Source image that requires transfer</p></li>
<li><p><strong>ref</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Palette image which serves as reference</p></li>
<li><p><strong>fun</strong> – optional argument to pass a transfer function to solve for covariance matrices</p></li>
<li><p><strong>t_r</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Resulting image after the mapping</p></li>
<li><p><strong>res</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Resulting image after the mapping</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p><strong>t_r</strong></p>
<dd class="field-even"><p><strong>res</strong></p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>np.ndarray</p>
Expand Down Expand Up @@ -244,7 +251,7 @@ <h3>Navigation</h3>
<li class="right" >
<a href="readme.html" title="color-matcher"
>previous</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.2.6 documentation</a> &#187;</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.3.0 documentation</a> &#187;</li>
<li class="nav-item nav-item-this"><a href="">API documentation</a></li>
</ul>
</div>
Expand Down
41 changes: 24 additions & 17 deletions docs/build/html/color_matcher.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>color_matcher package &#8212; color-matcher 0.2.6 documentation</title>
<title>color_matcher package &#8212; color-matcher 0.3.0 documentation</title>
<link rel="stylesheet" href="_static/sphinxdoc.css" type="text/css" />
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<script id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
Expand All @@ -25,7 +25,7 @@ <h3>Navigation</h3>
<li class="right" >
<a href="py-modindex.html" title="Python Module Index"
>modules</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.2.6 documentation</a> &#187;</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.3.0 documentation</a> &#187;</li>
<li class="nav-item nav-item-this"><a href="">color_matcher package</a></li>
</ul>
</div>
Expand Down Expand Up @@ -140,23 +140,30 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Permalink to this
<dd><p>Initialize self. See help(type(self)) for accurate signature.</p>
</dd></dl>

<dl class="py method">
<dt id="color_matcher.mvgd_matcher.TransferMVGD.analytical_solver">
<code class="sig-name descname">analytical_solver</code><span class="sig-paren">(</span><span class="sig-paren">)</span> &#x2192; <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)">numpy.ndarray</a><a class="headerlink" href="#color_matcher.mvgd_matcher.TransferMVGD.analytical_solver" title="Permalink to this definition"></a></dt>
<dd><p>An analytical solution to the linear equation system of MVGDs.</p>
<dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p><strong>transfer_mat</strong>: Transfer matrix</p>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p>np.ndarray</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt id="color_matcher.mvgd_matcher.TransferMVGD.mkl_solver">
<em class="property">static </em><code class="sig-name descname">mkl_solver</code><span class="sig-paren">(</span><em class="sig-param"><span class="n">cov_r</span><span class="p">:</span> <span class="n"><a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)">numpy.ndarray</a></span></em>, <em class="sig-param"><span class="n">cov_z</span><span class="p">:</span> <span class="n"><a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)">numpy.ndarray</a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#color_matcher.mvgd_matcher.TransferMVGD.mkl_solver" title="Permalink to this definition"></a></dt>
<code class="sig-name descname">mkl_solver</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#color_matcher.mvgd_matcher.TransferMVGD.mkl_solver" title="Permalink to this definition"></a></dt>
<dd><p>This function computes the transfer matrix based on the Monge-Kantorovich linearization.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>cov_r</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Covariance matrix of source image</p></li>
<li><p><strong>cov_z</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Covariance matrix of reference image</p></li>
<li><p><strong>transfer_mat</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Transfer matrix</p></li>
</ul>
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p><strong>transfer_mat</strong>: Transfer matrix</p>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p><strong>transfer_mat</strong></p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>np.ndarray</p>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p>np.ndarray</p>
</dd>
</dl>
</dd></dl>
Expand All @@ -171,11 +178,11 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Permalink to this
<li><p><strong>src</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Source image that requires transfer</p></li>
<li><p><strong>ref</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Palette image which serves as reference</p></li>
<li><p><strong>fun</strong> – optional argument to pass a transfer function to solve for covariance matrices</p></li>
<li><p><strong>t_r</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Resulting image after the mapping</p></li>
<li><p><strong>res</strong> (<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v1.19)"><code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code></a>) – Resulting image after the mapping</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p><strong>t_r</strong></p>
<dd class="field-even"><p><strong>res</strong></p>
</dd>
<dt class="field-odd">Return type</dt>
<dd class="field-odd"><p>np.ndarray</p>
Expand Down Expand Up @@ -314,7 +321,7 @@ <h3>Navigation</h3>
<li class="right" >
<a href="py-modindex.html" title="Python Module Index"
>modules</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.2.6 documentation</a> &#187;</li>
<li class="nav-item nav-item-0"><a href="index.html">color-matcher 0.3.0 documentation</a> &#187;</li>
<li class="nav-item nav-item-this"><a href="">color_matcher package</a></li>
</ul>
</div>
Expand Down
Loading

0 comments on commit 12e3912

Please sign in to comment.