-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_xainano_images.py
96 lines (75 loc) · 2.76 KB
/
generate_xainano_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from xainano_graphics import postprocessor
from trainer.defaults import *
from utilities import parse_arg
from numpy.random import seed
from os import path, makedirs, remove
import png
import multiprocessing as mp
import cv2
# Name of the sub-directory (inside output_dir) that holds this dataset.
dir_name = 'xainano_images'

# Settings read through parse_arg; the second argument is presumably the
# fallback used when the flag is absent -- confirm against utilities.parse_arg.
data_base_dir = parse_arg('--data-base-dir', '/Users/balazs/university/')
output_dir = parse_arg('--output-dir', '/Users/balazs/university/handwritten_images')
number_of_images = int(parse_arg('--count', 200000))
# NOTE(review): ncores is parsed here but nothing in this file reads it; the
# worker pool in main() is sized from mp.cpu_count() instead.
ncores = int(parse_arg('--ncores', 4))

# Index file written next to the images directory: one line per saved image.
data_file = "data.txt"
# Image file name pattern; the placeholder is filled with the image index.
filename_format = 'formula_{:06d}.png'
data = ""

images_path = path.join(output_dir, dir_name, "images")
# Create the directory EAFP-style instead of "check, then create": this module
# can be imported by several processes at once (multiprocessing re-imports the
# main module in every child under the "spawn" start method), and a separate
# existence check would let two of them race into makedirs().
try:
    makedirs(images_path)
except OSError:
    # Already existing as a directory is fine; anything else is a real error.
    if not path.isdir(images_path):
        raise
def image_file_saver(q):
    """Consume rendered formulas from ``q`` and persist them to disk.

    Every regular queue item is a ``(file_path, text, image)`` tuple: ``image``
    is saved to ``file_path`` as an RGB PNG and ``text`` is appended as one
    line to the index file ``data_file`` inside ``output_dir/dir_name``.

    An item whose first element is a single newline character is the shutdown
    sentinel. Only its first element is inspected, so the sentinel may have
    any length.

    Args:
        q: Queue to read from. It must provide ``get()`` and ``task_done()``,
            as the proxy returned by ``multiprocessing.Manager().Queue()``
            does.
    """
    index_path = path.join(output_dir, dir_name, data_file)
    # The context manager closes the index file even if the loop dies on an
    # unexpected error.
    with open(index_path, "w") as f:
        while True:
            item = q.get()
            try:
                # Check for the sentinel before unpacking: the producer may
                # send a sentinel that is shorter than a regular item.
                if item[0] == "\n":
                    break
                file_path, text, image = item
                try:
                    png.from_array(image, 'RGB').save(file_path)
                    f.write(text + "\n")
                    # Flush per image so the index on disk never lags far
                    # behind the images that were already written.
                    f.flush()
                except Exception as exc:
                    # Best effort: report the failure, drop the image and
                    # carry on with the next item.
                    print("There was an exception. Arrgh")
                    print(repr(exc))
                    if path.exists(file_path):
                        print("Deleting file")
                        remove(file_path)
            finally:
                # Acknowledge every item that was taken off the queue;
                # without this a producer calling q.join() blocks forever.
                q.task_done()
def worker_thread(index):
    """Generate one formula, render it and post the result to the queue.

    The generator, config, token parser, post-processor and queue are read
    from attributes stored on this function object; ``worker_init`` is the
    function in this file that sets them.

    Args:
        index: Number substituted into ``filename_format`` to build the image
            file name.

    The item put on the queue is ``(file_path, label_line, image)``, where
    ``label_line`` is the file name, a tab, and the concatenated tokens.
    """
    state = worker_thread

    formula_tokens = []
    # generate_formula fills the list that is passed in.
    state.generator.generate_formula(formula_tokens, state.config)
    rendered = state.token_parser.parse(formula_tokens, state.post_processor)

    image_name = filename_format.format(index)
    label_line = image_name + "\t" + ''.join(formula_tokens)
    state.queue.put((path.join(images_path, image_name), label_line, rendered))
def worker_init(queue):
    """Prepare the per-process state that ``worker_thread`` relies on.

    The state is stored as attributes on the ``worker_thread`` function
    object. ``main`` passes this function as the pool initializer, so it runs
    in each worker process.

    Args:
        queue: Queue that ``worker_thread`` will put its results on.
            ``test_images`` passes ``None`` because it never posts results.
    """
    # Reseed numpy's global RNG first, before any helper object is built;
    # presumably so that worker processes do not all share one random stream.
    seed()

    state = worker_thread
    state.queue = queue
    state.post_processor = postprocessor.Postprocessor()
    # The create_* factories are not defined in this file; presumably they
    # come from the ``trainer.defaults`` star import.
    state.generator = create_generator()
    state.config = create_config()
    state.token_parser = create_token_parser(data_base_dir)
def main():
    """Generate ``number_of_images`` formula images using worker processes.

    A pool of worker processes renders formulas and pushes them onto a shared
    manager queue; a single dedicated process (``image_file_saver``) drains
    the queue and writes the PNG files and the index file.

    Raises:
        Exception: whatever ``image_file_saver`` raised, re-raised here once
            the saver process has finished.
    """
    manager = mp.Manager()
    queue = manager.Queue()

    # One dedicated process does all the disk writes.
    file_image_pool = mp.Pool(1)
    saver_result = file_image_pool.apply_async(image_file_saver, (queue,))

    # NOTE(review): the module-level ``ncores`` setting is not used here; the
    # pool is sized from the CPU count. Confirm which one is intended.
    worker_pool = mp.Pool(mp.cpu_count() * 2, worker_init, [queue])
    worker_pool.map(worker_thread, range(number_of_images))

    worker_pool.close()
    file_image_pool.close()
    worker_pool.join()

    # map() has returned, so every result is already on the queue. The queue
    # is FIFO, so the sentinel put here is seen only after all of them have
    # been consumed; waiting with queue.join() is unnecessary and would hang
    # unless the consumer calls task_done() for every item.
    # The sentinel has three elements so it unpacks like a regular item.
    queue.put(("\n", "\n", None))
    file_image_pool.join()

    # Surface any error that killed the saver instead of dropping it silently.
    saver_result.get()
def test_images():
    """Preview generated formula images one at a time in an OpenCV window.

    Sets up the generation state in the current process (no queue), then
    loops forever: generate a formula, render it, show it, and wait for a key
    press before producing the next one. The loop has no exit condition.
    """
    worker_init(None)
    state = worker_thread

    while 1:
        formula_tokens = []
        state.generator.generate_formula(formula_tokens, state.config)
        preview = state.token_parser.parse(formula_tokens, state.post_processor)
        cv2.imshow("Image", preview)
        # A delay of 0 makes waitKey block until a key is pressed.
        cv2.waitKey(0)
if __name__ == '__main__':
    # NOTE(review): the batch generation entry point below is disabled, so
    # running this script starts the interactive preview loop instead.
    # Re-enable main() (and drop test_images()) to generate the dataset.
    # main()
    test_images()