Skip to content

Commit

Permalink
fix bug of oposite initialization of code and dictionary and added te…
Browse files Browse the repository at this point in the history
…st case #6
  • Loading branch information
tksmd committed Mar 28, 2018
1 parent db3c41f commit 3f2840b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 9 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,17 @@ python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

### Testing

You can run all test cases just like this

```
python -m unittest tests/test_*.py
```

Or run specific test case as follows

```
python -m unittest test_decomposition_ksvd.TestKSVD.test_ksvd
```
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
numpy
scipy
scikit-learn

scikit-learn>=0.19.0
18 changes: 11 additions & 7 deletions spmimage/decomposition/ksvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,20 @@ def _ksvd(Y: np.ndarray, n_components: int, k0: int, tol: float, max_iter: int,
set to True.
"""

if code_init is None:
A = Y[:, :n_components]
if dict_init is None:
A = Y[:n_components, :]
else:
A = code_init
A = dict_init
A = np.dot(A, np.diag(1. / np.sqrt(np.diag(np.dot(A.T, A)))))

if dict_init is None:
X = np.zeros((A.shape[1], Y.shape[1]))
if code_init is None:
X = np.zeros((Y.shape[0], A.shape[1]))
else:
X = dict_init
X = code_init

Y = Y.T
X = X.T
A = A.T

errors = [np.linalg.norm(Y - A.dot(X), 'fro')]
k = -1
Expand All @@ -78,7 +82,7 @@ def _ksvd(Y: np.ndarray, n_components: int, k0: int, tol: float, max_iter: int,
if np.abs(errors[-1] - errors[-2]) < tol:
break

return A, X, errors, k + 1
return A.T, X, errors, k + 1


class KSVD(BaseEstimator, SparseCodingMixin):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_decomposition_ksvd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from unittest import TestCase
from typing import Tuple

from spmimage.decomposition import KSVD

import numpy as np


class TestKSVD(TestCase):

def generate_input(self, dict_size: Tuple[int, int], k0: int, n_samples: int) -> Tuple[np.ndarray, np.ndarray]:
# random dictionary base
A0 = np.random.randn(*dict_size)
X = np.zeros((dict_size[0], n_samples))
for i in range(n_samples):
# select k0 components from dictionary
X[:, i] = np.dot(A0[:, np.random.permutation(range(dict_size[1]))[:k0]], np.random.randn(k0))
return A0, X.T

def test_ksvd(self):
np.random.seed(0)
k0 = 4
n_samples = 512
dict_size = (24, 32)
max_iter = 100
A0, X = self.generate_input(dict_size, k0, n_samples)
model = KSVD(n_components=dict_size[1], k0=k0, max_iter=max_iter)
model.fit(X)

norm = np.linalg.norm(model.components_ - A0.T, ord='fro')

self.assertTrue(model.error_[-1] < 75)
self.assertTrue(norm < 50)
self.assertTrue(model.n_iter_ <= max_iter)

code = model.transform(X)
reconstructed = np.dot(code, model.components_)
reconstruct_error = np.linalg.norm(reconstructed - X, ord='fro')
print(reconstruct_error)

0 comments on commit 3f2840b

Please sign in to comment.