diff --git a/build.sh b/build.sh
new file mode 100644
index 0000000..f8de632
--- /dev/null
+++ b/build.sh
@@ -0,0 +1 @@
+cd build && make && cd ..
diff --git a/ctc_log.lua b/ctc_log.lua
index 258e85c..536365e 100644
--- a/ctc_log.lua
+++ b/ctc_log.lua
@@ -43,7 +43,7 @@ function ctc.__getFilledTargetFromString(target)
 end
 
 function ctc.__getFilledTarget(target)
-	local filled = torch.zeros((#target)[1] * 2 + 1)
+	local filled = torch.zeros(#target * 2 + 1)
 	for i = 1, (#filled)[1] do
 		if i % 2 == 0 then
 			filled[i] = target[i / 2]
@@ -205,12 +205,20 @@ function ctc.getCTCCostAndGrad(outputTable, target)
 
 	targetClasses = ctc.__getFilledTarget(target)
 
+
+	-- print(targetClasses)
+
 	targetMatrix = ctc.__getOnehotMatrix(targetClasses, class_num)
 
 	outputTable = ctc.__toMatrix(outputTable, class_num)
 
+	outputTable = outputTable:cmax(1e-4)
+	local total = outputTable:sum(2):expand(outputTable:size()[1], outputTable:size()[2])
+	outputTable = torch.cdiv(outputTable, total)
+
+	-- print(outputTable)
 
 	for i = 1, (#outputTable)[1] do
@@ -221,6 +229,8 @@ function ctc.getCTCCostAndGrad(outputTable, target)
 
+
+
 	-- get aligned_table
 	-- outputTable: Tx(cls+1)
 	-- target: L'x(cls+1) --> targetT : (cls+1)xL'
 
@@ -240,12 +250,19 @@ function ctc.getCTCCostAndGrad(outputTable, target)
 
 	local bvs= ctc.__getBackwardVariable(outputTable, alignedTable, targetMatrix)
 	local fb = fvs + bvs
-
+
 	-- calculate gradient matrix (Tx(cls+1))
 	local grad = ctc.__getGrad(fb, pzx, class_num, outputTable, targetClasses)
 
+	--[[
+	print("=========FVS=========")
+	print(fvs:t())
+	print("=========BVS=========")
+	print(bvs:t())
+	print("=========GRAD=========")
+	print(grad)
+	]]
 
-	-- print(grad)
 	grad = nn.SplitTable(1):forward(grad)
 
 	return -pzx, grad
diff --git a/loader.lua b/loader.lua
index e5cb0b4..6fb5dd6 100644
--- a/loader.lua
+++ b/loader.lua
@@ -44,9 +44,7 @@ function Loader.__getNormalizedImage(src)
 
 	ones = torch.ones(h, w)
 	im = ones - im
-
 	normalizer.normalize(im, output)
-
 	return output
 end
 
@@ -55,7 +53,7 @@ function Loader:load(file)
 	local f = assert(io.open(file, "r"))
 	for line in f:lines() do
 		local src = line
-		local im = Loader.__getNormalizedImage(src)
+		local im = Loader.__getNormalizedImage(src):t()
 		local gt = src:gsub(".png", ".gt.txt")
 
 		local cf = assert(io.open(gt, "r"))
diff --git a/main.lua b/main.lua
index 973257b..f1bbd93 100644
--- a/main.lua
+++ b/main.lua
@@ -7,13 +7,35 @@
 require 'loader'
 require 'ctc_log'
 require 'utils/decoder'
 
-mnist = require 'mnist'
+base = 0
+
+timer = torch.Timer()
+
+function show_log(log)
+	local now = timer:time().real
+	local cost = now - base
+	base = now
+	print(string.format("[%.4f][%.4f]%s", now, cost, log))
+end
 
 DROPOUT_RATE = 0.4
 
 local input_size = 64
 local hidden_size = 100
-local class_num = 10
+
+
+
+show_log("Loading samples...")
+
+loader = Loader()
+loader:load("1.txt")
+codec = loader:codec()
+
+show_log(string.format("Loading finished. Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size))
+
+local class_num = codec.codec_size
+
+show_log("Building networks...")
 
 local net = nn.Sequential()
@@ -33,32 +55,18 @@ torch.manualSeed(450)
 
 params, grad_params = net:getParameters()
 
 state = {
-	learningRate = 1e-3,
-	momentum = 0.5
+	learningRate = 1e-4,
+	momentum = 0.9
 }
 
-loader = Loader()
-loader:load("1.txt")
-codec = loader:codec()
-
-local sample = loader:pick()
-local im = sample.img
-local target = codec:encode(sample.gt)
-
-raw = image.load(sample.src, 1)
-
-print(raw[1])
+show_log(string.format("Start training with learning rate = %.4f, momentum = %.4f", state.learningRate, state.momentum))
 
-print(im)
---[[
 for i = 1, 100000 do
 	local sample = loader:pick()
 	local im = sample.img
 	local target = codec:encode(sample.gt)
 
-	print(im)
-
 	local feval = function(params)
 		net:forget()
 
@@ -66,17 +74,22 @@ for i = 1, 100000 do
 
 		loss, grad = ctc.getCTCCostAndGrad(outputTable, target)
 
-		if i % 20 == 0 then
-			print(sample.gt)
-			print(decoder.best_path_decode(outputTable))
-			print(loss)
+		if i % 10 == 0 then
+			print("")
+			show_log("EPOCH\t" .. i)
+			show_log("TARGET\t" .. sample.gt)
+			show_log("OUTPUT\t" .. decoder.best_path_decode(outputTable, codec))
+			show_log("LOSS\t" .. loss)
 		end
 
 		-- net:zeroGradParameters()
+
+		-- print(grad_params)
+
 		net:backward(im, grad)
 
 		grad_params:cmul(torch.eq(grad_params, grad_params):double())
+		grad_params:clamp(-5, 5)
 
 		return loss, grad_params
 	end
@@ -85,4 +98,3 @@ for i = 1, 100000 do
 
 end
 
-]]
diff --git a/normalizer.cc b/normalizer.cc
index 9ec6553..5fdac80 100644
--- a/normalizer.cc
+++ b/normalizer.cc
@@ -106,23 +106,35 @@ static double bilinear(double * in, int w, int h, double x, double y) {
 
 	// printf("(%d, %d)\n", xi, yi);
 
+	if (xi > w - 1 || yi > h - 1 || x < 0 || y < 0) {
+		return 0;
+	}
+
 	xi = xi < 0 ? 0 : xi; yi = yi < 0 ? 0 : yi;
+
+	xi = xi > w - 1 ? w - 1 : xi; yi = yi > h - 1 ? h - 1 : yi;
+	xt = xt > w - 1 ? w - 1 : xt; yt = yt > h - 1 ? h - 1 : yt;
 
-	printf("(%d, %d)\n", xi, yi);
+
 	double p00 = in[yi * w + xi];
 	double p01 = in[yt * w + xi];
 	double p10 = in[yi * w + xt];
 	double p11 = in[yt * w + xt];
 
-	p00 * (1.0 - xf) * (1.0 - yf) + p10 * xf * (1.0 - yf) + p01 * (1.0 - xf) * yf + p11 * xf * yf;
+	double result = p00 * (1.0 - xf) * (1.0 - yf) + p10 * xf * (1.0 - yf) + p01 * (1.0 - xf) * yf + p11 * xf * yf;
+	if (result < 0) {
+		printf("warning result < 0. %.4lf, %.4lf, %.4f, %.4f\n" \
+			"%.4f, %.4f, %.4f, %.4f\n", x, y, xf, yf, p00, p01, p10, p11);
+	}
+
+	return result;
 }
 
 static void measure(THDoubleTensor * src, double * & center, double & mean, int & r) {
@@ -159,7 +171,7 @@ static void measure(THDoubleTensor * src, double * & center, double & mean, int
 	mean = sy / s1;
 	r = int(mean * RANGE_RATE + 1);
 
-	printf("mean = %lf r = %d\n", mean, r);
+	/* printf("mean = %lf r = %d\n", mean, r); */
 }
 
 static void normalize
diff --git a/test_normalizer.lua b/test_normalizer.lua
index f4d3d0b..6d00407 100644
--- a/test_normalizer.lua
+++ b/test_normalizer.lua
@@ -1,7 +1,7 @@
 require 'image'
 require 'normalizer'
 
-im = image.load("bq01_006.png", 1)
+im = image.load("26.png", 1)
 
 if im:dim() == 3 then
 	im = im[1]
diff --git a/utils/decoder.lua b/utils/decoder.lua
index c66de16..ce51e1e 100644
--- a/utils/decoder.lua
+++ b/utils/decoder.lua
@@ -3,8 +3,8 @@ decoder = {}
 
 function decoder.best_path_decode(outputTable, codec)
 	local result = {}
 
-	local class_num = #(outputTable[1])[1]
+	local class_num = outputTable[1]:size()[1]
 
 	local last_max_class = nil;
 	local last_max = -1;
@@ -12,18 +12,16 @@ function decoder.best_path_decode(outputTable, codec)
 		local max_val, max = torch.max(outputTable[i], 1)
 		max = max[1]
 
-		if max == class_num then
-			if last_max ~= -1 and last_max_class ~= nil then
-				table.insert(result, last_max_class)
-				last_max = -1
-				last_max_class = nil
-			end
-		else
-			if max_val > last_max then
-				last_max = max_val
-				last_max_class = max
+		max_val = max_val[1]
+
+		if max ~= last_max_class then
+			if max ~= class_num then
+				table.insert(result, max)
 			end
+			last_max_class = max
 		end
+
+
 	end
 
 	return codec:decode(result)
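
Note on the utils/decoder.lua change above: the rewritten loop is the standard best-path (greedy) CTC rule, i.e. take the arg-max class per frame, collapse consecutive repeats, and drop the blank label (the last class index). The sketch below is not part of the patch; it restates that rule in plain Lua on ordinary tables (no torch tensors) and returns raw label indices instead of calling codec:decode, so it can be run standalone to check the behaviour.

-- Minimal best-path CTC decoding sketch (plain Lua, no torch).
-- Assumptions: `outputs` is a list of per-frame score tables and the
-- blank label is the LAST class index (class_num), as in the patch above.
local function best_path_decode(outputs, class_num)
	local result = {}
	local last = nil
	for t = 1, #outputs do
		-- arg-max over the class scores of frame t
		local best, best_val = 1, outputs[t][1]
		for c = 2, class_num do
			if outputs[t][c] > best_val then
				best, best_val = c, outputs[t][c]
			end
		end
		-- collapse rule: emit only when the label changes, never emit blank
		if best ~= last then
			if best ~= class_num then
				table.insert(result, best)
			end
			last = best
		end
	end
	return result
end

-- Example with 3 classes where class 3 is blank; the frames decode to {1, 1}.
local frames = {{0.9, 0.05, 0.05}, {0.8, 0.1, 0.1}, {0.1, 0.1, 0.8}, {0.7, 0.2, 0.1}}
print(table.concat(best_path_decode(frames, 3), " "))  -- prints: 1 1

A related note on the main.lua training loop: grad_params:cmul(torch.eq(grad_params, grad_params):double()) zeroes out NaN entries (NaN ~= NaN, so torch.eq yields 0 exactly at NaN positions), and grad_params:clamp(-5, 5) then bounds the surviving gradient values before the optimizer step.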