Skip to content

Commit

Permalink
Update bm25 matching in QA module.
Browse files Browse the repository at this point in the history
  • Loading branch information
zake7749 committed Dec 6, 2016
1 parent 624813b commit c72c135
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 4 deletions.
151 changes: 151 additions & 0 deletions Chatbot/QuestionAnswering/Matcher/bm25Matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import math

from .matcher import Matcher

class bestMatchingMatcher(Matcher):

"""
基於 bm25 算法取得最佳關聯短語
"""

def __init__(self, segLib="Taiba", removeStopWords=False):
super().__init__(segLib)

self.cleanStopWords = removeStopWords
self.D = 0 # 句子總數

self.wordset = set() # Corpus 中所有詞的集合
self.words_location_record = dict() # 紀錄該詞 (key) 出現在哪幾個句子(id)
self.words_idf = dict() # 紀錄每個詞的 idf 值

self.f = []
self.df = {}
self.idf = {}
self.k1 = 1.5
self.b = 0.75

if removeStopWords:
self.loadStopWords("data/stopwords/chinese_sw.txt")
self.loadStopWords("data/stopwords/specialMarks.txt")

def initialize(self,ngram=1):

assert len(self.titles) > 0, "請先載入短語表"

self.TitlesSegmentation() # 將 self.titles 斷詞為 self.segTitles
#self.calculateIDF() # 依照斷詞後結果, 計算每個詞的 idf value
self.initBM25()

"""NEED MORE DISCUSSION
#for n in range(0,ngram):
# self.addNgram(n)
"""

def initBM25(self):

print("BM25模塊初始化中")

self.D = len(self.segTitles)
self.avgdl = sum([len(title) + 0.0 for title in self.segTitles]) / self.D

for seg_title in self.segTitles:
tmp = {}
for word in seg_title:
if not word in tmp:
tmp[word] = 0
tmp[word] += 1
self.f.append(tmp)
for k, v in tmp.items():
if k not in self.df:
self.df[k] = 0
self.df[k] += 1
for k, v in self.df.items():
self.idf[k] = math.log(self.D-v+0.5)-math.log(v+0.5)

print("BM25模塊初始化完成")

def sim(self, doc, index):
score = 0
for word in doc:
if word not in self.f[index]:
continue
d = len(self.segTitles[index])
score += (self.idf[word]*self.f[index][word]*(self.k1+1)
/ (self.f[index][word]+self.k1*(1-self.b+self.b*d
/ self.avgdl)))
return score

def calculateIDF(self):

# 構建詞集與紀錄詞出現位置的字典
if len(self.wordset) == 0:
self.buildWordSet()
if len(self.words_location_record) == 0:
self.buildWordLocationRecord()

# 計算 idf
for word in self.wordset:
self.words_idf[word] = math.log2((self.D + .5)/(self.words_location_record[word] + .5))


def buildWordLocationRecord(self):
"""
建構詞與詞出現位置(句子id)的字典
"""
for idx,seg_title in enumerate(self.segTitles):
for word in seg_title:
if self.words_location_record[word] is None:
self.words_location_record[word] = set()
self.words_location_record[word].add(idx)

def buildWordSet(self):
"""
建立 Corpus 詞集
"""
for seg_title in self.segTitles:
for word in seg_title:
self.wordset.add(word)

def addNgram(self,n):
"""
擴充 self.seg_titles 為 n-gram
"""
idx = 0

for seg_list in self.segTitles:
ngram = self.generateNgram(n,self.titles[idx])
seg_list = seg_list + ngram
idx += 1

def generateNgram(self,n,sentence):
return [sentence[i:i+n] for i in range(0,len(sentence)-1)]


def joinTitles(self):
self.segTitles = ["".join(title) for title in self.segTitles]

def match(self, query):
"""
讀入使用者 query,若語料庫中存在類似的句子,便回傳該句子與標號
Args:
- query: 使用者欲查詢的語句
"""

seg_query = self.wordSegmentation(query)
max = -1
target = ''
target_idx = -1

for index in range(self.D):
score = self.sim(seg_query, index)
if score > max:
target_idx = index
max = score

# normalization
max = max / self.sim(self.segTitles[target_idx],target_idx)
target = ''.join(self.segTitles[target_idx])
self.similarity = max * 100 #百分制

return target,target_idx
16 changes: 15 additions & 1 deletion Chatbot/QuestionAnswering/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .responsesEvaluate import Evaluator
from .Matcher.fuzzyMatcher import FuzzyMatcher
from .Matcher.wordWeightMatcher import WordWeightMatcher
from .Matcher.bm25Matcher import bestMatchingMatcher
from .Matcher.matcher import Matcher

def getMatcher(matcherType,removeStopWords=False):
Expand All @@ -24,6 +25,8 @@ def getMatcher(matcherType,removeStopWords=False):
return woreWeightMatch()
elif matcherType == "Fuzzy":
return fuzzyMatch(removeStopWords)
elif matcherType == "bm25":
return bm25()
elif matcherType == "Vectorize":
pass #TODO
elif matcherType == "DeepLearning":
Expand All @@ -48,7 +51,7 @@ def matcherTesting(matcherType,removeStopWords=False):
#randomId = random.randrange(0,len(res[targetId]))

evaluator = Evaluator()
candiates = evaluator.getBestResponse(responses=res[targetId],topk=5,debugMode=True)
candiates = evaluator.getBestResponse(responses=res[targetId],topk=5,debugMode=False)
print("以下是相似度前 5 高的回應")
for candiate in candiates:
print("%s %f" % (candiate[0],candiate[1]))
Expand Down Expand Up @@ -81,3 +84,14 @@ def fuzzyMatch(cleansw=False):
#fuzzyMatcher.loadStopWords(path="data/stopwords/chinese_sw.txt")
#fuzzyMatcher.loadStopWords(path="data/stopwords/ptt_words.txt")
#fuzzyMatcher.loadStopWords(path="data/stopwords/specialMarks.txt")

def bm25():

cur_dir = os.getcwd()
os.chdir(os.path.dirname(__file__))
bm25Matcher = bestMatchingMatcher()
bm25Matcher.loadTitles(path="data/Titles.txt")
bm25Matcher.initialize()
os.chdir(cur_dir)

return bm25Matcher
2 changes: 1 addition & 1 deletion Chatbot/QuestionAnswering/qaBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self):
self.general_questions = []
self.path = os.path.dirname(__file__)

self.matcher = getMatcher(matcherType="Fuzzy")
self.matcher = getMatcher(matcherType="bm25")
self.evaluator = Evaluator()
self.moduleTest()

Expand Down
2 changes: 1 addition & 1 deletion Chatbot/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, name="MianBot"):
self.custom_rulebase.model = self.console.rb.model # pass word2vec model

# For QA
self.github_qa_unupdated = True
self.github_qa_unupdated = False
if not self.github_qa_unupdated:
self.answerer = qa.Answerer()

Expand Down
1 change: 0 additions & 1 deletion Chatbot/task_modules/module_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def get_handler(self, domain):
handler = medicine.MedicalListener(self.console)
elif domain=="天氣":
handler = Weather(self.console)

#elif domain=="股票":
# handler = Stock(self.console)
#elif domain=="購買":
Expand Down

0 comments on commit c72c135

Please sign in to comment.