-
Notifications
You must be signed in to change notification settings - Fork 1
/
lda.py
46 lines (37 loc) · 1.55 KB
/
lda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from gensim.models import LdaModel
from gensim.models import LdaMulticore
from gensim.test.utils import common_texts
from gensim.corpora.dictionary import Dictionary
import numpy as np
import os
def cluster_questions(topic_num, res_path, q_path='datasets\\DialogQA\\Qall.txt', a_path='datasets\\DialogQA\\Aall.txt'):
    """Group questions by the dominant LDA topic of their paired answers.

    An LDA model with ``topic_num`` topics is trained on the
    whitespace-tokenised lines of ``a_path``.  Line ``i`` of ``q_path`` is
    then assigned to the most probable topic of answer line ``i``, and each
    topic's questions are written to ``<res_path>/<topic index>.txt``.

    Args:
        topic_num: Number of LDA topics, and therefore of output files.
        res_path: Output directory; created if it does not exist yet.
        q_path: UTF-8 text file with one question per line.
        a_path: UTF-8 text file with one answer per line, aligned
            line-by-line with ``q_path``.

    Returns:
        The value of ``LdaModel.log_perplexity`` evaluated on the answer
        corpus the model was trained on.

    Raises:
        IndexError: If ``q_path`` has more lines than ``a_path``.
    """
    with open(a_path, 'r', encoding='utf-8') as f:
        answer_tokens = [line.split() for line in f]
    with open(q_path, 'r', encoding='utf-8') as f:
        # Lines keep their trailing newline; they are written back verbatim.
        questions = list(f)

    dictionary = Dictionary(answer_tokens)
    corpus = [dictionary.doc2bow(tokens) for tokens in answer_tokens]
    lda = LdaModel(corpus, num_topics=topic_num)

    clusters = [[] for _ in range(topic_num)]
    print('Questions : ', len(questions))
    perp = lda.log_perplexity(corpus)

    for i, question in enumerate(questions):
        # Reuse the bag-of-words already built for answer i.
        doc_topics = lda[corpus[i]]
        if not doc_topics:
            # No topic was reported for this answer, so there is nothing to
            # rank; skip rather than reuse the previous question's topic.
            continue
        # max() keeps the first entry on ties, like a strict '>' scan would.
        topic, _ = max(doc_topics, key=lambda pair: pair[1])
        clusters[topic].append(question)

    # exist_ok makes repeated runs into the same directory safe.
    os.makedirs(res_path, exist_ok=True)
    for top, cluster in enumerate(clusters):
        out_file = os.path.join(res_path, str(top) + '.txt')
        with open(out_file, 'w', encoding='utf-8') as f:
            f.writelines(cluster)
    return perp