Skip to content

Commit

Permalink
lsa
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirkCao committed Jun 12, 2019
1 parent 2a87801 commit e485f78
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
21 changes: 20 additions & 1 deletion CH17/lsa.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
#! /usr/bin/env python

# Project: Lihang
# Filename: lsa
# Date: 6/03/19
# Author: 😏 <smirk dot cao at gmail dot com>
# 截断奇异值分解用在count/tf-idf矩阵的时候,叫做潜在语义分析。
import numpy as np


class LSA(object):
def __init__(self, n_components):
self.n_components = n_components
self.components = None
self.singular_values = None
self.explained_variance_ratio = None

def fit(self, x):
pass
u, s, vh = np.linalg.svd(x, full_matrices=False)
max_abs_raws = np.argmax(np.abs(vh), axis=1)
signs = np.sign(vh[range(vh.shape[0]), max_abs_raws])
u *= signs
vh *= signs[:, np.newaxis]
k = self.n_components
#
self.components = vh[:k]
self.singular_values = s[:k]
#
x_transformed = u*s
self.explained_variance = np.var(x_transformed, axis=0)
self.explained_variance_ratio = (self.explained_variance/self.explained_variance.sum())[:k]
self.explained_variance = self.explained_variance[:k]
return x_transformed

31 changes: 31 additions & 0 deletions CH17/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,34 @@ def test_lsa_1701(self):
print("singular_values\n", lsa1.singular_values_)
print("components\n", lsa1.components_)
print("rst\n", rst)

def test_lsa(self):
x = np.array([[0., 0., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 1.],
[0., 1., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 1.],
[0., 0., 0., 0., 0., 2., 0., 0., 1.],
[1., 0., 1., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 0., 0.]])

svd = lsa_sklearn(n_components=3, n_iter=7, random_state=42)
svd.fit(x)
print("\n")
print("lsa sklearn")
print("components_\n", svd.components_)
print(svd.singular_values_)
print(svd.explained_variance_)
print(svd.explained_variance_ratio_)
print(svd.explained_variance_ratio_.sum())

svd_1 = lsa_test(n_components=3)
svd_1.fit(x)
print("lsa test")
print(svd_1.components)
print(svd_1.singular_values)
print(svd_1.explained_variance)
print(svd_1.explained_variance_ratio)

0 comments on commit e485f78

Please sign in to comment.