Skip to content

Commit 4dcdc6b

Browse files
author
Pannous
committed
1 parent 5e5a5e3 commit 4dcdc6b

File tree

4 files changed

+24
-11
lines changed

4 files changed

+24
-11
lines changed

extensions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,16 @@ def is_dir(x, must_exist=True):
227227
# print "\n"
228228
# x
229229

230+
231+
def exists(x):
232+
return os.path.isfile(x)
233+
234+
def is_dir(x):
235+
return os.path.isfile(x) and os.path.isdir(x)
236+
237+
def is_file(x):#Is it a file, or a directory?
238+
return os.path.isfile(x) and not os.path.isdir(x)
239+
230240
def is_a(self, clazz):
231241
if self is clazz: return True
232242
try:

mouse_prediction.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ def get_mouse_position():
5454
# todo make model robust to extra text
5555
argmax = numpy.argmax(lines) # most white
5656
argmin = numpy.argmin(lines) # most black
57-
if(argmax<argmin):
58-
mat[:,:argmax,:]=1. # fill white above
59-
if(argmin<argmax):
60-
mat[:,argmax:,:]=1. # fill white below
57+
# if(argmax<argmin):
58+
# mat[:,:argmax,:]=1. # fill white above
59+
# if(argmin<argmax):
60+
# mat[:,argmax:,:]=1. # fill white below
6161
# todo: what if invert image!?
6262

6363
tensor = mat
@@ -87,7 +87,7 @@ def get_mouse_position():
8787
histogram = numpy.histogram(mat, bins=10, range=None, normed=False, weights=None, density=None)
8888
print(argmax)
8989

90-
words = predict_tensor([tensor])
90+
words = predict_tensor(tensor)
9191
if len(words) > 0:
9292
best = words[0]
9393
else:

text_recognizer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
import sys
55

66
import numpy as np
7+
import keras.models
78
from PIL import Image # python -m pip install --upgrade Pillow # WTF
8-
from keras.models import load_model
99

1010
# weight_file = 'best_weights.h5'
1111
# weight_file = 'current_weights.h5'
1212

1313
# weight_file = 'weights_ascii.h5' # learned on noisy data
1414
# weight_file = 'weights_ascii_easy.h5' # no freckles
15-
weight_file = 'weights_ascii_clean.h5' # pure text
15+
# weight_file = 'weights_ascii_clean.h5' # pure text
16+
model_file = 'model1000.h5'
1617

1718
chars = u'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZäöüÄÖÜß0123456789!@#$%^&*()[]{}-_=+\\|"\'`;:/.,?><~ '
1819

@@ -22,7 +23,7 @@
2223

2324
def load_model():
2425
global model
25-
model = load_model(weight_file)
26+
model = keras.models.load_model(model_file)
2627
# model.load_weights(weight_file, reshape=True, by_name=True)
2728

2829
def predict_tensor(tensor):
@@ -33,8 +34,9 @@ def predict_tensor(tensor):
3334
tensor = tensor.transpose((2, 1, 0)) # 4*w*h
3435
tensor = tensor[:, :, :, np.newaxis]
3536

37+
print(tensor.shape)
3638
if not model: load_model()
37-
prediction = model.predict(tensor, batch_size=1, verbose=1)
39+
prediction = model.predict([tensor], batch_size=1, verbose=1)
3840
result = decode_results(prediction)
3941
return result
4042

@@ -72,6 +74,5 @@ def decode_results(prediction):
7274
# image = image.transpose(Image.FLIP_TOP_BOTTOM)
7375
tensor = np.array(image) / 255.0 # RGBA: h*w*4
7476
print(tensor.shape)
75-
76-
words = predict_tensor([tensor])
77+
words = predict_tensor(tensor)
7778
print(words)

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def train(run_name, start_epoch, stop_epoch, img_w):
565565
# sgd = Adam()
566566

567567
model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)
568+
model.save(os.path.join(MODEL_DIR, 'model%03d.h5' % (start_epoch + 1)))
568569

569570
# for l in model.layers:
570571
# if not "conv" in l.name and not "dense1" in l.name:
@@ -591,6 +592,7 @@ def train(run_name, start_epoch, stop_epoch, img_w):
591592
def last_epoch():
592593
maxi=0
593594
for date in os.listdir(MODEL_DIR):
595+
if not os.path.isdir(date): continue
594596
for f in os.listdir(MODEL_DIR+"/"+date):
595597
if not f.startswith("weights"): continue
596598
if len(f)==12:

0 commit comments

Comments
 (0)