Commit 82e7f99: add normalizer
Edward Zhu committed Aug 13, 2015
1 parent b2c9862
Showing 11 changed files with 736 additions and 33 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,8 +1,10 @@
/*
samples/*
*.png
*.tif

.DS_Store
*.dSYM

# Compiled Lua sources
luac.out
4 changes: 4 additions & 0 deletions ctc/CMakeLists.txt
@@ -0,0 +1,4 @@
CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.6)
FIND_PACKAGE(Torch REQUIRED)
SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
36 changes: 36 additions & 0 deletions ctc/ctc.cc
@@ -0,0 +1,36 @@
extern "C" {
#include <luaT.h>
#include <TH/TH.h>
}

#include <cmath>
#include <iostream>
#include <vector>

static int ctc_print(lua_State * L)
{
THDoubleTensor * input = (THDoubleTensor *)luaT_checkudata(L, 1, "torch.DoubleTensor");
int h = input->size[0];
int w = input->size[1];

double * data = THDoubleTensor_data(input);

for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
printf("%.4lf\t", data[i * w + j]);
}
printf("\n");
}

return 0;
}

static const struct luaL_reg ctc[] = {
{"ctc_print", ctc_print},
{NULL, NULL}
};

LUA_EXTERNC int luaopen_ctc(lua_State *L) {
luaL_openlib(L, "ctc", ctc, 0);
return 1;
}
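As committed, ctc/CMakeLists.txt only locates Torch and extends the module path; there is no build target for ctc.cc yet. Assuming the file is eventually compiled into a ctc shared library that Lua can find on package.cpath, usage would look something like this sketch (the tensor values are illustrative):

require 'torch'
require 'ctc'                                -- runs luaopen_ctc, which registers a global ctc table

local t = torch.DoubleTensor(2, 3):fill(0.5)
ctc.ctc_print(t)                             -- prints two rows of three "0.5000" entries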
1 change: 1 addition & 0 deletions ctc_log.lua
@@ -234,6 +234,7 @@ function ctc.getCTCCostAndGrad(outputTable, target)

-- calculate log(p(z|x))
local pzx = logs.log_add(fvs[T][L_1], fvs[T][L_1-1])


-- calculate backwardVariable (in log space)
local bvs = ctc.__getBackwardVariable(outputTable, alignedTable, targetMatrix)
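logs.log_add is not part of this diff, but from its use above it presumably computes log(exp(a) + exp(b)) without leaving log space. A minimal sketch of the usual log-sum-exp formulation, assuming that is what the helper does:

-- Hypothetical stand-in for logs.log_add: log(exp(a) + exp(b)),
-- evaluated against the larger argument so exp() never overflows.
local function log_add(a, b)
    if a < b then a, b = b, a end
    return a + math.log(1 + math.exp(b - a))
end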
54 changes: 54 additions & 0 deletions gaussian_test.py
@@ -0,0 +1,54 @@
from scipy import misc, ndimage
from scipy.ndimage import filters
import matplotlib.pyplot as plt
from pylab import *
from numpy import *
import PIL


def pil2array(im, alpha=0):
    # Convert a PIL image to a numpy byte array; grayscale, RGB and
    # RGBA are handled directly, anything else is converted to "L".
    if im.mode == "L":
        a = fromstring(im.tostring(), 'B')
        a.shape = im.size[1], im.size[0]
        return a
    if im.mode == "RGB":
        a = fromstring(im.tostring(), 'B')
        a.shape = im.size[1], im.size[0], 3
        return a
    if im.mode == "RGBA":
        a = fromstring(im.tostring(), 'B')
        a.shape = im.size[1], im.size[0], 4
        if not alpha:
            a = a[:, :, :3]
        return a
    return pil2array(im.convert("L"))


im = PIL.Image.open('bq01_006-1.png')
im = pil2array(im)

im = im / 255.0

print(im)

h = im.shape[0]
w = im.shape[1]

# Smooth heavily in both directions, then mix in a thin horizontal
# average so every column has a usable maximum.
smooth = filters.gaussian_filter(im, (h * 0.5, h * 1.0), mode='constant')
smooth += 0.001 * filters.uniform_filter(smooth, (h * 0.5, w), mode='constant')

print(smooth.shape)

# The centerline is the row of maximum smoothed intensity in each
# column, smoothed again to suppress jitter.
a = argmax(smooth, axis=0)
a = filters.gaussian_filter(a, h * 0.3)

center = array(a, 'i')
# print(center)

# The mean absolute deviation of the ink pixels from the centerline
# gives the band radius used for normalization.
deltas = abs(arange(h)[:, newaxis] - center[newaxis, :])
mad = mean(deltas[im != 0])
r = int(1 + 4 * mad)

plt.imshow(smooth, cmap=cm.gray)
plot(center)
plt.show()
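For reference, my reading of what this script computes (notation mine, not the author's): with S the smoothed image and I the set of nonzero (ink) pixels,

c(x) = \operatorname*{arg\,max}_{y} S(y, x), \qquad \mathrm{MAD} = \frac{1}{|I|} \sum_{(y,x) \in I} \lvert y - c(x) \rvert, \qquad r = \lfloor 1 + 4 \cdot \mathrm{MAD} \rfloor

where c(x) is additionally Gaussian-smoothed in the code; r would then serve as the half-height of the band around the centerline that a line normalizer rescales.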
50 changes: 22 additions & 28 deletions main.lua
@@ -1,17 +1,14 @@
require 'nn'
require 'cunn'
require 'rnn'
require 'image'
require 'optim'

require 'ctc_log'
require 'utils/decoder'

mnist = require 'mnist'

DROPOUT_RATE = 0.4
MOMENTUM = 0.9
LEARNING_RATE = 1e-3
MAXOUTNORM = 2

local input_size = 28
local hidden_size = 100
@@ -34,42 +31,39 @@ torch.manualSeed(450)

params, grad_params = net:getParameters()

state = {
learningRate = 1e-3,
momentum = 0.5
}

for i = 1, 100000 do
local no = torch.random() % 10 + 1
local no = torch.random() % 100 + 1
local sample = mnist.traindataset()[no]
local im = sample.x:double():t()
local target = torch.Tensor{sample.y + 1}


local result = torch.zeros(im:size()[2])
local feval = function(params)
net:forget()

net:forget()
outputTable = net:forward(im)

outputTable = net:forward(im)
loss, grad = ctc.getCTCCostAndGrad(outputTable, target)

if i % 20 == 0 then
print(target[1] - 1)
print(decoder.decodeTable(outputTable))
print(loss)
end

loss, grad = ctc.getCTCCostAndGrad(outputTable, target)
-- net:zeroGradParameters()

net:backward(im, grad)

grad_params:cmul(torch.eq(grad_params, grad_params):double())

if i % 20 == 0 then
print(target[1] - 1)
print(decoder.decodeTable(outputTable))
print(loss)
return loss, grad_params
end

-- net:zeroGradParameters()


net:updateGradInput(im, grad)
net:accGradParameters(im, grad, MOMENTUM)
grad_params:clamp(-10, 10)


-- print(gradParams)

net:updateParameters(LEARNING_RATE)
-- net:maxParamNorm(2)



optim.sgd(feval, params, state)
end
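One line in the new feval deserves a note: grad_params:cmul(torch.eq(grad_params, grad_params):double()) relies on the fact that NaN ~= NaN, so the elementwise equality mask is 0 exactly at the NaN entries and the multiplication zeroes them out before the clamp. A small illustration (values made up):

local g = torch.DoubleTensor{1, 0/0, 3}    -- 0/0 produces nan
g:cmul(torch.eq(g, g):double())            -- mask is {1, 0, 1}, so the nan becomes 0
print(g)                                   -- 1, 0, 3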