Skip to content

Commit

Permalink
Basically works.
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward Zhu committed Aug 17, 2015
1 parent 1b1750e commit 9076a60
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 45 deletions.
1 change: 1 addition & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Build the project in ./build. `make -C build` runs make inside the build
# directory without changing the caller's working directory, and the script's
# exit status is make's own (the original `cd ..` tail could overwrite it).
make -C build
23 changes: 20 additions & 3 deletions ctc_log.lua
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function ctc.__getFilledTargetFromString(target)
end

function ctc.__getFilledTarget(target)
local filled = torch.zeros((#target)[1] * 2 + 1)
local filled = torch.zeros(#target * 2 + 1)
for i = 1, (#filled)[1] do
if i % 2 == 0 then
filled[i] = target[i / 2]
Expand Down Expand Up @@ -205,12 +205,20 @@ function ctc.getCTCCostAndGrad(outputTable, target)


targetClasses = ctc.__getFilledTarget(target)

-- print(targetClasses)

targetMatrix = ctc.__getOnehotMatrix(targetClasses, class_num)



outputTable = ctc.__toMatrix(outputTable, class_num)

outputTable = outputTable:cmax(1e-4)
local total = outputTable:sum(2):expand(outputTable:size()[1], outputTable:size()[2])
outputTable = torch.cdiv(outputTable, total)


-- print(outputTable)

for i = 1, (#outputTable)[1] do
Expand All @@ -221,6 +229,8 @@ function ctc.getCTCCostAndGrad(outputTable, target)





-- get aligned_table
-- outputTable: Tx(cls+1)
-- target: L'x(cls+1) --> targetT : (cls+1)xL'
Expand All @@ -240,12 +250,19 @@ function ctc.getCTCCostAndGrad(outputTable, target)
local bvs= ctc.__getBackwardVariable(outputTable, alignedTable, targetMatrix)

local fb = fvs + bvs

-- calculate gradient matrix (Tx(cls+1))
local grad = ctc.__getGrad(fb, pzx, class_num, outputTable, targetClasses)

--[[
print("=========FVS=========")
print(fvs:t())
print("=========BVS=========")
print(bvs:t())
print("=========GRAD=========")
print(grad)
]]

-- print(grad)
grad = nn.SplitTable(1):forward(grad)

return -pzx, grad
Expand Down
4 changes: 1 addition & 3 deletions loader.lua
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ function Loader.__getNormalizedImage(src)
ones = torch.ones(h, w)

im = ones - im

normalizer.normalize(im, output)

return output
end

Expand All @@ -55,7 +53,7 @@ function Loader:load(file)
local f = assert(io.open(file, "r"))
for line in f:lines() do
local src = line
local im = Loader.__getNormalizedImage(src)
local im = Loader.__getNormalizedImage(src):t()

local gt = src:gsub(".png", ".gt.txt")
local cf = assert(io.open(gt, "r"))
Expand Down
60 changes: 36 additions & 24 deletions main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,35 @@ require 'loader'
require 'ctc_log'
require 'utils/decoder'

mnist = require 'mnist'
base = 0

timer = torch.Timer()

function show_log(log)
local now = timer:time().real
local cost = now - base
base = now
print(string.format("[%.4f][%.4f]%s", now, cost, log))
end

DROPOUT_RATE = 0.4

local input_size = 64
local hidden_size = 100
local class_num = 10



show_log("Loading samples...")

loader = Loader()
loader:load("1.txt")
codec = loader:codec()

show_log(string.format("Loading finished. Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size))

local class_num = codec.codec_size

show_log("Building networks...")

local net = nn.Sequential()

Expand All @@ -33,50 +55,41 @@ torch.manualSeed(450)
params, grad_params = net:getParameters()

state = {
learningRate = 1e-3,
momentum = 0.5
learningRate = 1e-4,
momentum = 0.9
}

loader = Loader()
loader:load("1.txt")
codec = loader:codec()

local sample = loader:pick()
local im = sample.img
local target = codec:encode(sample.gt)

raw = image.load(sample.src, 1)

print(raw[1])
show_log(string.format("Start training with learning rate = %.4f, momentum = %.4f", state.learningRate, state.momentum))

print(im)

--[[
for i = 1, 100000 do
local sample = loader:pick()
local im = sample.img
local target = codec:encode(sample.gt)

print(im)
local feval = function(params)
net:forget()

outputTable = net:forward(im)

loss, grad = ctc.getCTCCostAndGrad(outputTable, target)

if i % 20 == 0 then
print(sample.gt)
print(decoder.best_path_decode(outputTable))
print(loss)
if i % 10 == 0 then
print("")
show_log("EPOCH\t" .. i)
show_log("TARGET\t" .. sample.gt)
show_log("OUTPUT\t" .. decoder.best_path_decode(outputTable, codec))
show_log("LOSS\t" .. loss)
end

-- net:zeroGradParameters()

-- print(grad_params)

net:backward(im, grad)

grad_params:cmul(torch.eq(grad_params, grad_params):double())
grad_params:clamp(-5, 5)

return loss, grad_params
end
Expand All @@ -85,4 +98,3 @@ for i = 1, 100000 do
end


]]
18 changes: 15 additions & 3 deletions normalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,35 @@ static double bilinear(double * in, int w, int h, double x, double y) {

// printf("(%d, %d)\n", xi, yi);

if (xi > w - 1 || yi > h - 1 || x < 0 || y < 0) {
return 0;
}

xi = xi < 0 ? 0 : xi;
yi = yi < 0 ? 0 : yi;


xi = xi > w - 1 ? w - 1 : xi;
yi = yi > h - 1 ? h - 1 : yi;


xt = xt > w - 1 ? w - 1 : xt;
yt = yt > h - 1 ? h - 1 : yt;

printf("(%d, %d)\n", xi, yi);


double p00 = in[yi * w + xi];
double p01 = in[yt * w + xi];
double p10 = in[yi * w + xt];
double p11 = in[yt * w + xt];

p00 * (1.0 - xf) * (1.0 - yf) + p10 * xf * (1.0 - yf) + p01 * (1.0 - xf) * yf + p11 * xf * yf;
double result = p00 * (1.0 - xf) * (1.0 - yf) + p10 * xf * (1.0 - yf) + p01 * (1.0 - xf) * yf + p11 * xf * yf;
if (result < 0) {
printf("warning result < 0. %.4lf, %.4lf, %.4f, %.4f\n" \
"%.4f, %.4f, %.4f, %.4f\n", x, y, xf, yf, p00, p01, p10, p11);
}

return result;
}

static void measure(THDoubleTensor * src, double * & center, double & mean, int & r) {
Expand Down Expand Up @@ -159,7 +171,7 @@ static void measure(THDoubleTensor * src, double * & center, double & mean, int
mean = sy / s1;
r = int(mean * RANGE_RATE + 1);

printf("mean = %lf r = %d\n", mean, r);
/* printf("mean = %lf r = %d\n", mean, r); */
}

static void normalize
Expand Down
2 changes: 1 addition & 1 deletion test_normalizer.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
require 'image'
require 'normalizer'

im = image.load("bq01_006.png", 1)
im = image.load("26.png", 1)

if im:dim() == 3 then
im = im[1]
Expand Down
20 changes: 9 additions & 11 deletions utils/decoder.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,25 @@ decoder = {}
function decoder.best_path_decode(outputTable, codec)
local result = {}

local class_num = #(outputTable[1])[1]

local class_num = outputTable[1]:size()[1]
local last_max_class = nil;
local last_max = -1;

for i = 1, #outputTable do
local max_val, max = torch.max(outputTable[i], 1)
max = max[1]

if max == class_num then
if last_max ~= -1 and last_max_class ~= nil then
table.insert(result, last_max_class)
last_max = -1
last_max_class = nil
end
else
if max_val > last_max then
last_max = max_val
last_max_class = max
max_val = max_val[1]

if max ~= last_max_class then
if max ~= class_num then
table.insert(result, max)
end
last_max_class = max
end


end

return codec:decode(result)
Expand Down

0 comments on commit 9076a60

Please sign in to comment.