
Commit 6f93c3f

add setup.py for local install

1 parent 72818a8 commit 6f93c3f

17 files changed (+180, -8 lines)

ptm/__init__.py

Whitespace-only changes.
File renamed without changes.

ctm.py renamed to ptm/ctm.py

File renamed without changes.

diln.py renamed to ptm/diln.py

File renamed without changes.

ptm/formatted_logger.py

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+import logging
+import os
+import time
+
+default_log_path = './logs'
+
+def formatted_logger(label, level=None, format=None, date_format=None, file_path=None):
+    log = logging.getLogger(label)
+    if level is None:
+        level = logging.INFO
+    elif level.lower() == 'debug':
+        level = logging.DEBUG
+    elif level.lower() == 'info':
+        level = logging.INFO
+    elif level.lower() == 'warn':
+        level = logging.WARN
+    elif level.lower() == 'error':
+        level = logging.ERROR
+    elif level.lower() == 'critical':
+        level = logging.CRITICAL
+    log.setLevel(level)
+
+    if format is None:
+        format = '%(asctime)s %(levelname)s:%(name)s:%(message)s'
+    if date_format is None:
+        date_format = '%Y-%m-%d %H:%M:%S'
+    if file_path is None:
+        if not os.path.exists(default_log_path):
+            os.makedirs(default_log_path)
+        file_path = '%s/%s.%s.log.txt' % (default_log_path, label, time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()))
+
+    formatter = logging.Formatter(format, date_format)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(formatter)
+    file_handler = logging.FileHandler(file_path)
+    file_handler.setFormatter(formatter)
+    log.addHandler(file_handler)
+    log.addHandler(stream_handler)
+    return log
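
A quick usage sketch of the new helper; the 'demo' label and message are illustrative, and the package-level import path assumes the ptm/ layout introduced by this commit:

    from ptm.formatted_logger import formatted_logger

    log = formatted_logger('demo', 'debug')
    log.info('hello')  # echoed to the console (stderr) and appended to ./logs/demo.<timestamp>.log.txt
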
File renamed without changes.

hdsp.py renamed to ptm/hdsp.py

File renamed without changes.

hmm_lda.py renamed to ptm/hmm_lda.py

Lines changed: 4 additions & 0 deletions

@@ -1,6 +1,10 @@
 import numpy as np

 class HMM_LDA:
+    """ implementation of HMM-LDA proposed by Griffiths et al. (2004)
+    Original reference : Integrating topics and syntax, Griffiths, Thomas L and Steyvers, Mark and Blei, David M and Tenenbaum, Joshua B, NIPS 2004
+    """
+
     def __init__(self, num_class, num_topic, num_voca, docs):
         self.C = num_class
         self.K = num_topic
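
For orientation, a toy instantiation based only on the constructor signature shown above; the format of docs (a list of word-id lists) is an assumption, not confirmed by this diff:

    from ptm.hmm_lda import HMM_LDA

    docs = [[0, 1, 2, 1], [2, 3, 0]]  # assumed format: one list of word ids per document
    model = HMM_LDA(num_class=3, num_topic=2, num_voca=4, docs=docs)
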
File renamed without changes.

lda_vb.py renamed to ptm/lda_vb.py

File renamed without changes.
File renamed without changes.

rtm.py renamed to ptm/rtm.py

Lines changed: 25 additions & 7 deletions

@@ -1,7 +1,11 @@
-import numpy as np
+import numpy as np
+import utils
 from scipy.special import gammaln, psi
+from formatted_logger import formatted_logger

-eps = 1e-10
+eps = 1e-20
+
+log = formatted_logger('RTM', 'info')

 class rtm:
     """ implementation of relational topic model by Chang and Blei (2009)
@@ -13,7 +17,7 @@ def __init__(self, num_topic, num_doc, num_voca, doc_ids, doc_cnt, doc_links, rho):
         self.K = num_topic
         self.V = num_voca

-        self.alpha = 1.
+        self.alpha = .1

         self.gamma = np.random.gamma(100., 1./100, [self.D, self.K])
         self.beta = np.random.dirichlet([5]*self.V, self.K)
@@ -35,19 +39,21 @@ def __init__(self, num_topic, num_doc, num_voca, doc_ids, doc_cnt, doc_links, rho):
         self.doc_links = doc_links
         self.rho = rho  # regularization parameter

+        log.info('Initialize RTM: num_voca:%d, num_topic:%d, num_doc:%d' % (self.V, self.K, self.D))
+
     def posterior_inference(self, max_iter):
         for iter in xrange(max_iter):
             self.variation_update()
             self.parameter_estimation()
-            print self.compute_elbo()
+            log.info('%d iter: ELBO = %.3f' % (iter, self.compute_elbo()))

     def compute_elbo(self):
         """ compute evidence lower bound for trained model
         """
         elbo = 0

         e_log_theta = psi(self.gamma) - psi(np.sum(self.gamma, 1))[:,np.newaxis]  # D x K
-        log_beta = np.log(self.beta)
+        log_beta = np.log(self.beta+eps)

         for di in xrange(self.D):
             words = self.doc_ids[di]
@@ -62,7 +68,7 @@ def compute_elbo(self):
             elbo += - np.sum(cnt * self.phi[di] * np.log(self.phi[di]))  # - E_q[log q(z|phi)]

             for adi in self.doc_links[di]:
-                elbo += np.dot(self.eta, self.pi[di]*self.pi[adi])  # E_q[log p(y_{d1,d2}|z_{d1},z_{d2},\eta,\nu)]
+                elbo += np.dot(self.eta, self.pi[di]*self.pi[adi]) + self.nu  # E_q[log p(y_{d1,d2}|z_{d1},z_{d2},\eta,\nu)]

         return elbo

@@ -77,7 +83,7 @@ def variation_update(self):
             cnt = self.doc_cnt[di]
             doc_len = np.sum(cnt)

-            new_phi = np.log(self.beta[:,words]) + e_log_theta[di,:][:,np.newaxis]
+            new_phi = np.log(self.beta[:,words]+eps) + e_log_theta[di,:][:,np.newaxis]

             gradient = np.zeros(self.K)
             for adi in self.doc_links[di]:
@@ -114,6 +120,18 @@ def parameter_estimation(self):
         self.nu = np.log(num_links-np.sum(pi_sum)) - np.log(self.rho*(self.K-1)/self.K + num_links - np.sum(pi_sum))
         self.eta = np.log(pi_sum) - np.log(pi_sum + self.rho * pi_alpha) - self.nu

+    def save_model(self, output_directory, vocab=None):
+        import os
+        if not os.path.exists(output_directory):
+            os.mkdir(output_directory)
+
+        np.savetxt(output_directory+'/eta.txt', self.eta, delimiter='\t')
+        with open(output_directory+'/nu.txt', 'w') as f:
+            f.write('%f\n' % self.nu)
+        np.savetxt(output_directory+'/beta.txt', self.beta, delimiter='\t')
+        np.savetxt(output_directory+'/gamma.txt', self.gamma, delimiter='\t')
+        if vocab is not None:
+            utils.write_top_words(self.beta, vocab, output_directory+'/top_words.csv')

 def main():
     rho = 1
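
A minimal end-to-end sketch exercising the new logging and save_model paths; the toy corpus, link structure, and output directory are illustrative assumptions, not from the repository (Python 2, matching the xrange usage above):

    import numpy as np
    from ptm.rtm import rtm

    doc_ids = [[0, 1, 4], [2, 3], [0, 3, 4]]  # word ids occurring in each document
    doc_cnt = [[2, 1, 1], [3, 1], [1, 2, 1]]  # counts for those word ids
    doc_links = [[1], [0, 2], [1]]            # document adjacency lists

    model = rtm(3, 3, 5, doc_ids, doc_cnt, doc_links, rho=1)  # K=3 topics, D=3 docs, V=5 words
    model.posterior_inference(10)  # logs one ELBO line per iteration via formatted_logger
    model.save_model('rtm_output', vocab=np.array(['a', 'b', 'c', 'd', 'e']))
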
File renamed without changes.
File renamed without changes.

utils.py renamed to ptm/utils.py

Lines changed: 9 additions & 0 deletions

@@ -143,3 +143,12 @@ def convert_wrdcnt_wrdlist(corpus_ids, corpus_cnt):
         corpus.append(doc)
     return corpus

+
+def write_top_words(topic_word_matrix, vocab, filepath, top_words=20, delimiter=',', newline='\n'):
+    with open(filepath, 'w') as f:
+        for ti in xrange(topic_word_matrix.shape[0]):
+            words = vocab[topic_word_matrix[ti, :].argsort()[::-1][:top_words]]
+            f.write('%d' % ti)
+            for word in words:
+                f.write(delimiter + word)
+            f.write(newline)
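
A hypothetical call, assuming a K x V topic-word matrix (as rtm.save_model passes self.beta) and a numpy array of word strings for vocab:

    import numpy as np
    from ptm import utils

    beta = np.random.dirichlet([1.0] * 5, 3)  # 3 topics over a 5-word vocabulary
    vocab = np.array(['apple', 'bee', 'cat', 'dog', 'egg'])
    utils.write_top_words(beta, vocab, 'top_words.csv', top_words=3)
    # each output line: a topic id followed by its 3 highest-probability words
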

whdsp.py renamed to ptm/whdsp.py

Lines changed: 1 addition & 1 deletion

@@ -460,7 +460,7 @@ def __init__(self, vocab, word_ids, word_cnt, num_topics, labels, label_names =

         if type(vocab) == list:
             self.vocab = np.array(vocab)
-        else
+        else:
             self.vocab = vocab

         if type(word_ids[0]) != np.ndarray:

setup.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+from setuptools import setup, find_packages  # Always prefer setuptools over distutils
+from codecs import open  # To use a consistent encoding
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the relevant file
+# with open(path.join(here, 'DESCRIPTION.rst'), encoding='utf-8') as f:
+#     long_description = f.read()
+long_description = open('README.md').read()
+
+setup(
+    name='ptm',
+
+    # Versions should comply with PEP440. For a discussion on single-sourcing
+    # the version across setup.py and the project code, see
+    # https://packaging.python.org/en/latest/development.html#single-sourcing-the-version
+    version='0.0.1',
+
+    description='Probabilistic topic model',
+    long_description=long_description,
+
+    # The project's main homepage.
+    url='https://github.com/arongdari/python-topic-model/',
+
+    # Author details
+    author='Dongwoo Kim',
+    author_email='arongdari@gmail.com',
+
+    # Choose your license
+    license='Apache License 2.0',
+
+    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
+    classifiers=[
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        'Development Status :: 3 - Alpha',
+
+        # Indicate who your project is intended for
+        'Intended Audience :: Developers',
+        'Topic :: Software Development :: Build Tools',
+
+        # Pick your license as you wish (should match "license" above)
+        'License :: OSI Approved :: Apache Software License',
+
+        # Specify the Python versions you support here. In particular, ensure
+        # that you indicate whether you support Python 2, Python 3 or both.
+        # 'Programming Language :: Python :: 2',
+        # 'Programming Language :: Python :: 2.6',
+        'Programming Language :: Python :: 2.7',
+        # 'Programming Language :: Python :: 3',
+        # 'Programming Language :: Python :: 3.2',
+        # 'Programming Language :: Python :: 3.3',
+        # 'Programming Language :: Python :: 3.4',
+    ],
+
+    # What does your project relate to?
+    keywords='topic model lda',
+
+    # You can just specify the packages manually here if your project is
+    # simple. Or you can use find_packages().
+    # packages=find_packages(exclude=['contrib', 'docs', 'tests*']),
+    packages=find_packages(),
+
+    # List run-time dependencies here. These will be installed by pip when your
+    # project is installed. For an analysis of "install_requires" vs pip's
+    # requirements files see:
+    # https://packaging.python.org/en/latest/technical.html#install-requires-vs-requirements-files
+    install_requires=['numpy', 'scipy'],
+
+    # List additional groups of dependencies here (e.g. development dependencies).
+    # You can install these using the following syntax, for example:
+    # $ pip install -e .[dev,test]
+    extras_require={
+        'dev': [],  # 'check-manifest'
+        'test': [],  # 'coverage'
+    },
+
+    # If there are data files included in your packages that need to be
+    # installed, specify them here. If using Python 2.6 or less, then these
+    # have to be included in MANIFEST.in as well.
+    package_data={
+        # 'sample': ['package_data.dat'],
+    },
+
+    # Although 'package_data' is the preferred approach, in some case you may
+    # need to place data files outside of your packages.
+    # see http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files
+    # In this case, 'data_file' will be installed into '<sys.prefix>/my_data'
+    data_files=[],  # [('my_data', ['data/data_file'])]
+
+    # To provide executable scripts, use entry points in preference to the
+    # "scripts" keyword. Entry points provide cross-platform support and allow
+    # pip to create the appropriate form of executable for the target platform.
+    entry_points={
+        # 'console_scripts': [
+        #     'sample=sample:main',
+        # ],
+    },
+)
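
With setup.py in place, the local install named in the commit message is typically "pip install -e ." (editable) or "python setup.py install" run from the repository root; numpy and scipy are pulled in through install_requires, and find_packages() picks up the new ptm package.
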
