|
| 1 | +from os import makedirs, path as op |
| 2 | +from shutil import copytree |
| 3 | +from collections import Counter |
| 4 | +import csv |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from PIL import Image |
| 8 | + |
| 9 | +# create a greyscale folder for class labelled images |
| 10 | +greyscale_folder = op.join('labels', 'grayscale') |
| 11 | +if not op.isdir(greyscale_folder): |
| 12 | + makedirs(greyscale_folder) |
| 13 | +labels = np.load('labels.npz') |
| 14 | + |
| 15 | +# write our numpy array labels to images |
| 16 | +# remove empty labels because we don't download images for them |
| 17 | +keys = labels.keys() |
| 18 | +class_freq = Counter() |
| 19 | +image_freq = Counter() |
| 20 | +for key in keys: |
| 21 | + label = labels[key] |
| 22 | + if np.sum(label): |
| 23 | + label_file = op.join(greyscale_folder, '{}.png'.format(key)) |
| 24 | + img = Image.fromarray(label.astype(np.uint8)) |
| 25 | + print('Writing {}'.format(label_file)) |
| 26 | + img.save(label_file) |
| 27 | + # get class frequencies |
| 28 | + unique, counts = np.unique(label, return_counts=True) |
| 29 | + freq = dict(zip(unique, counts)) |
| 30 | + for k, v in freq.items(): |
| 31 | + class_freq[k] += v |
| 32 | + image_freq[k] += 1 |
| 33 | + else: |
| 34 | + keys.remove(key) |
| 35 | + |
| 36 | +# copy our tiles to a folder with a different name |
| 37 | +copytree('tiles', 'images') |
| 38 | + |
| 39 | +# sample the file names and use those to create text files |
| 40 | +np.random.shuffle(keys) |
| 41 | +split_index = int(len(keys) * 0.8) |
| 42 | + |
| 43 | +with open('train.txt', 'w') as train: |
| 44 | + for key in keys[:split_index]: |
| 45 | + train.write('/data/images/{}.png /data/labels/grayscale/{}.png\n'.format(key, key)) |
| 46 | + |
| 47 | +with open('val.txt', 'w') as val: |
| 48 | + for key in keys[split_index:]: |
| 49 | + val.write('/data/images/{}.png /data/labels/grayscale/{}.png\n'.format(key, key)) |
| 50 | + |
| 51 | +# write a csv with class frequencies |
| 52 | +freqs = [dict(label=k, frequency=v, image_count=image_freq[k]) for k, v in class_freq.items()] |
| 53 | +with open('labels/label-stats.csv', 'w') as stats: |
| 54 | + fieldnames = list(freqs[0].keys()) |
| 55 | + writer = csv.DictWriter(stats, fieldnames=fieldnames) |
| 56 | + |
| 57 | + writer.writeheader() |
| 58 | + for f in freqs: |
| 59 | + writer.writerow(f) |
0 commit comments