Skip to content

Commit

Permalink
lua: 简化和修复 cold_word_drop 逻辑 (#923)
Browse files Browse the repository at this point in the history
* update cold_word_drop
* Delete lua/cold_word_drop/debugtool.lua
* Delete lua/cold_word_drop/turndown_freq_words.lua
  • Loading branch information
boomker authored Jun 23, 2024
1 parent 0b59306 commit b2cb11c
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 413 deletions.
92 changes: 0 additions & 92 deletions lua/cold_word_drop/debugtool.lua

This file was deleted.

5 changes: 2 additions & 3 deletions lua/cold_word_drop/drop_words.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
-- Word list returned to whoever requires this module. The cold_word_drop
-- filter looks each candidate's text up in this list and does not emit a
-- candidate whose text is found here.
-- NOTE(review): the table literal and the return statement each appear twice
-- below; this looks like the before/after sides of the diff view — confirm.
local drop_words =
{ "示~例~",
}
return drop_words
{ "示~例~", "肏女人", }
return drop_words
104 changes: 55 additions & 49 deletions lua/cold_word_drop/filter.lua
Original file line number Diff line number Diff line change
@@ -1,54 +1,60 @@
local drop_list = require("cold_word_drop.drop_words")
local hide_list = require("cold_word_drop.hide_words")
local turndown_freq_list = require("cold_word_drop.turndown_freq_words")
local filter = {}

local function filter(input, env)
local idx = 3 -- 降频的词条放到第三个后面, 即第四位, 可在 yaml 里配置
local i = 1
local cands = {}
local context = env.engine.context
local preedit_code = context.input
--- Prepare per-session state for the cold-word filter.
-- Reads the reduce position from the schema config and loads the three word
-- tables onto `env`, where filter.func picks them up.
-- @param env filter environment; must expose env.engine.schema.config
function filter.init(env)
    local schema_config = env.engine.schema.config

    -- Use position 4 when the config lookup yields nil/false.
    local configured_idx = schema_config:get_int("cold_word_reduce/idx")
    env.word_reduce_idx = configured_idx or 4

    -- { field on env, module to require }, loaded in this exact order.
    local word_tables = {
        { "drop_words", "cold_word_drop.drop_words" },
        { "hide_words", "cold_word_drop.hide_words" },
        { "reduce_freq_words", "cold_word_drop.reduce_freq_words" },
    }
    for _, entry in ipairs(word_tables) do
        env[entry[1]] = require(entry[2]) or {}
    end
end

--- Candidate filter: emits ordinary candidates first, pushes "reduce
--- frequency" entries back, and suppresses dropped/hidden entries.
-- Reads the word tables and the reduce position that filter.init stored on
-- `env`. `yield` and `table.find_index` are not defined in this block; they
-- are expected to be supplied from outside it.
-- @param input candidate stream; iterated with input:iter()
-- @param env   filter environment (engine.context, drop_words, hide_words,
--              word_reduce_idx, reduce_freq_words)
function filter.func(input, env)
-- Buffer of candidates whose emission is postponed until the main loop ends.
local cands = {}
local context = env.engine.context
-- Raw input with spaces removed (gsub's second result, the count, is dropped).
local preedit_str = context.input:gsub(" ", "")
local drop_words = env.drop_words
local hide_words = env.hide_words
-- Local countdown; each directly emitted candidate lowers it by one.
local word_reduce_idx = env.word_reduce_idx
local reduce_freq_words = env.reduce_freq_words
for cand in input:iter() do
local cand_text = cand.text:gsub(" ", "")
-- NOTE(review): gsub always returns a string, so the `or preedit_str`
-- fallback looks unreachable — confirm whether it was meant for a nil preedit.
local preedit_code = cand.preedit:gsub(" ", "") or preedit_str

-- Codes under which this candidate's frequency should be reduced ({} if none).
local reduce_freq_list = reduce_freq_words[cand_text] or {}
if word_reduce_idx > 1 then
-- The leading candidates (the original comment says "the first three") exclude
-- entries whose frequency is being reduced, entries to drop (not really
-- deleted, just hidden completely), and entries to hide.
-- A reduce-frequency match is buffered here without consulting
-- drop_words/hide_words.
if reduce_freq_list and table.find_index(reduce_freq_list, preedit_code) then
table.insert(cands, cand)
elseif
not (
table.find_index(drop_words, cand_text)
or (hide_words[cand_text] and table.find_index(hide_words[cand_text], preedit_code))

)
then
-- Neither dropped nor hidden: emit immediately and use up one leading slot.
yield(cand)
word_reduce_idx = word_reduce_idx - 1
end
-- A dropped/hidden candidate reaching this point is neither emitted nor buffered.
else
-- Leading slots are used up: buffer everything that is not dropped/hidden.
if
not (
table.find_index(drop_words, cand_text)
or (hide_words[cand_text] and table.find_index(hide_words[cand_text], preedit_code))
)
then
table.insert(cands, cand)
end
end

-- Stop consuming the stream once 80 candidates have been buffered.
if #cands >= 80 then
break
end
end

-- NOTE(review): the three loops below (through the `for cand in input:iter()`
-- pass-through loop) read i, idx, turndown_freq_list, drop_list and hide_list,
-- none of which are declared in this function. They look like the pre-change
-- implementation interleaved by the diff view rather than part of the new
-- body — confirm against the repository before relying on them.
for cand in input:iter() do
local cpreedit_code = string.gsub(cand.preedit, ' ', '')
if (i <= idx) then
local tfl = turndown_freq_list[cand.text] or nil
-- The first three candidates exclude entries whose frequency is being adjusted,
-- entries to drop (not really deleted, just hidden completely), and entries to hide.
if not
((tfl and table.find_index(tfl, cpreedit_code)) or
table.find_index(drop_list, cand.text) or
(hide_list[cand.text] and table.find_index(hide_list[cand.text], cpreedit_code))
)
then
i = i + 1
---@diagnostic disable-next-line: undefined-global
yield(cand)
else
table.insert(cands, cand)
end
else
table.insert(cands, cand)
end
if (#cands > 50) then
break
end
end
for _, cand in ipairs(cands) do
local cpreedit_code = string.gsub(cand.preedit, ' ', '')
if not
-- Entries to drop and entries to hide are not shown.
(
table.find_index(drop_list, cand.text) or
(hide_list[cand.text] and table.find_index(hide_list[cand.text], cpreedit_code))
)
then
---@diagnostic disable-next-line: undefined-global
yield(cand)
end
end
for cand in input:iter() do
yield(cand)
end
-- Emit the buffered candidates in the order they were collected.
for _, cand in ipairs(cands) do
yield(cand)
end
end

return filter
5 changes: 3 additions & 2 deletions lua/cold_word_drop/hide_words.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-- Maps a candidate's text to the list of input codes under which the
-- cold_word_drop filter suppresses that candidate.
-- NOTE(review): the first entry line and the return statement each appear
-- twice below; this looks like the before/after sides of the diff view — confirm.
local hide_words =
{ ["示~例~"] = { "shil", "shili", },
{ ["示~例~"] = { "shil", "shili", },
["么特瑞"] = { "meter", },
}
return hide_words
return hide_words
48 changes: 48 additions & 0 deletions lua/cold_word_drop/logger.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
-- runLog.lua
-- Copyright (C) 2023 yaoyuan.dou <douyaoyuan@126.com>

local M = {}
local dbgFlg = true

--- Set the module-level debug switch and echo the new value to stdout.
-- @param flg value stored as the debug flag; M.test writes only when it is truthy
function M.setDbg(flg)
    dbgFlg = flg
    print('runLog dbgFlg is ' .. tostring(flg))
end

-- Directory that contains this script, including the trailing separator.
-- debug.getinfo reports a chunk loaded from a file as "@<path>"; capturing
-- everything up to the last "/" or "\" avoids depending on the length of a
-- hard-coded file name. Falls back to "" when the chunk name has no
-- directory part, so the log path is then just the bare file name.
local current_path = debug.getinfo(1, "S").source:match("^@(.*[/\\])") or ""
-- Log file that M.writeLog appends to.
M.logDoc = current_path .. 'runLog.txt'

--- Append one timestamped entry to the log file named by M.logDoc.
-- Each entry has the form "<date time>[<_VERSION>]\t<message>\n".
-- Logging is best-effort: if the file cannot be opened, nothing is written
-- and no error is raised.
-- @param logStr     value to log; nil/false is logged as "nothing", and any
--                   other value is converted with tostring so that tables or
--                   booleans do not raise a concatenation error
-- @param newLineFlg accepted for compatibility but has no effect: every entry
--                   ends with a newline. The earlier
--                   `if not newLineFlg then newLineFlg = true end` forced the
--                   flag to true and it was never read afterwards.
M.writeLog = function(logStr, newLineFlg)
    local message = tostring(logStr or "nothing")

    local f = io.open(M.logDoc, 'a')
    if not f then
        return
    end

    local timeStamp = os.date("%Y/%m/%d %H:%M:%S")
    f:write(timeStamp .. '[' .. _VERSION .. ']' .. '\t' .. message .. '\n')
    f:close()
end

--- Self-test: writes three sample entries through M.writeLog.
-- Nothing is written unless the module-level debug flag is truthy.
-- @param printPrefix optional; replaced by ' ' when nil, and not otherwise
--                    used by this function
function M.test(printPrefix)
    if printPrefix == nil then
        printPrefix = ' '
    end

    if not dbgFlg then
        return
    end

    M.writeLog('this is a test string on new line', true)
    M.writeLog('this is a test string appending the last line', false)
    M.writeLog('runLogDoc is: ' .. M.logDoc, true)
end

--- Module initialisation hook; currently does nothing.
-- Any start-up work the logger needs can be placed in this function.
M.init = function(...)
end

-- Run the hook once while the module is being loaded, then export the table.
M.init()

return M
Loading

0 comments on commit b2cb11c

Please sign in to comment.