Skip to content

Commit

Permalink
pca
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirkCao committed May 31, 2019
1 parent a7b09d7 commit 7fdbacc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
25 changes: 20 additions & 5 deletions CH16/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,30 @@ def __init__(self, n_components=2):
self.n_components_ = n_components
self.explained_variance_ratio_ = None
self.singular_values_ = None
self.u = None
self.vh = None
self.components_ = None

def __str__(self,):
rst = "PCA algorithms:\n"
rst += "n_components: " + str(self.n_components_)
return rst

def fit(self, x):
# check n_components and min(n_samples, n_features)
pass

def fit_transform(x):
return x
n = x.shape[0]
assert n > 1
assert (np.mean(x, axis=1) == np.zeros(n)).all()
x_ = x.T/np.sqrt(n-1)
u, s, vh = np.linalg.svd(x_)
self.vh = vh
self.u = u
self.singular_values_ = s
self.explained_variance_ratio_ = s**2/np.sum(s**2)
print(self.u)
print(self.vh)

def fit_transform(self, x):
self.fit(x)
self.components_ = np.dot(self.vh, x)
return self.components_
18 changes: 16 additions & 2 deletions CH16/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,34 @@ def test_ex1601(self):
# plt.show()

def test_pca(self):
"""
PCA分析
"""
print("\n")
# raw data
x = np.array([[2, 3, 3, 4, 5, 7],
[2, 4, 5, 5, 6, 8]])
x = x-np.mean(x, axis=1).reshape(-1, 1)
print(x)
assert (np.mean(x, axis=1) == np.zeros(2)).all()

# for sklearn x.shape == (n_samples, n_features)
pca_sklearn = skpca(n_components=2)
pca_sklearn.fit(x.T)

print("\n")
print(40*"*"+"sklearn_pca"+40*"*")
print(pca_sklearn.explained_variance_ratio_)
print(pca_sklearn.singular_values_)
print(pca_sklearn.explained_variance_ratio_)
print(pca_sklearn.fit_transform(x.T).T)

print(40*"*"+"smirk_pca"+40*"*")
pca_test = smirkpca(n_components=2)
print(pca_test)
rst = pca_test.fit_transform(x)
print(pca_test.singular_values_)
print(pca_test.explained_variance_ratio_)
print(rst)

def test_pca_get_fig(self):
pass

0 comments on commit 7fdbacc

Please sign in to comment.