-
Notifications
You must be signed in to change notification settings - Fork 3
/
mask_generate.py
74 lines (57 loc) · 2.58 KB
/
mask_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sys
import random
import math
import numpy as np
import skimage.io
from PIL import Image
from tqdm import tqdm
import argparse
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from coco_config import *
def _str2bool(value):
    """Convert a command-line option string to a bool.

    argparse with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--is_resize False`` used to *enable* resizing.
    This converter parses the usual spellings explicitly instead.

    Args:
        value: The raw option string (or an already-converted bool, e.g. the
            ``const`` value used for a bare ``--is_resize``).

    Returns:
        True for 1/true/t/yes/y/on, False for 0/false/f/no/n/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./images', help='Image dir for mask generation')
parser.add_argument('--root_dir', type=str, default='./', help='root directory of project')
parser.add_argument('--out_dir', type=str, default='./masks', help='directory of mask to save')
parser.add_argument('--model_dir', type=str, default="./models/mask_rcnn_coco.h5", help='Local path to trained weights file')
# Class names to include in the generated mask; each must be present in class_names.
parser.add_argument('--object_list', type=str, nargs='+', default=['car', 'truck'],
                    help='objects to segment, to be provided as list of str, should be from class_names')
# nargs='?' + const=True: a bare "--is_resize" enables resizing, and the explicit
# "--is_resize True" / "--is_resize False" forms keep working (and now mean what they say).
parser.add_argument('--is_resize', type=_str2bool, nargs='?', const=True, default=False, help='to resize segmented mask')
parser.add_argument('--resize_dim', type=int, default=256, help='size of the data resize (squared assumed)')
args = parser.parse_args()
# Module-level name for the selected classes, consumed further down the script.
object_list = args.object_list
print(args)
# Download the COCO-trained weights if they are not already present.
# NOTE: despite its name, --model_dir is the path of the weights *file*.
if not os.path.exists(args.model_dir):
    utils.download_trained_weights(args.model_dir)

# Create the output directory for the masks; exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(args.out_dir, exist_ok=True)

# InferenceConfig is provided by the star import of coco_config.
config = InferenceConfig()
print('Configuration Input is')
print(config)

# Create the Mask R-CNN model in inference mode. model_dir=root_dir is
# presumably where mrcnn writes its logs/checkpoints — confirm against mrcnn.
model = modellib.MaskRCNN(mode="inference", model_dir=args.root_dir, config=config)
# Load the MS-COCO weights; by_name=True presumably matches layers by name.
model.load_weights(args.model_dir, by_name=True)
# Collect the JPEG images to segment. The extension match is case-insensitive
# (.jpg/.JPG/.jpeg) and the list is sorted so runs process files in a stable order.
file_names = sorted(
    file for file in os.listdir(args.image_dir)
    if file.lower().endswith(('.jpg', '.jpeg'))
)

# Map the requested object names to their indices in class_names
# (list.index raises ValueError for a name that is not a known class).
id_list = [class_names.index(name) for name in object_list]

# Run detection and save one mask per image: 255 where any selected object
# was detected, 0 elsewhere.
for file_name in tqdm(file_names):
    image = skimage.io.imread(os.path.join(args.image_dir, file_name))
    # Normalise to 3 channels: expand grayscale (2-D) images and drop a
    # 4th (alpha) channel, so the detector always receives RGB-shaped input.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    result = model.detect([image])[0]

    # Union of the instance masks (indexed on the last axis) whose class is
    # one of the requested objects. A logical OR cannot overflow, unlike the
    # previous summation into uint8, which wrapped to 0 after 256 overlaps.
    keep = np.isin(result['class_ids'], id_list)
    selected = np.any(result['masks'][:, :, keep], axis=2)
    mask = np.where(selected, 255, 0).astype(np.uint8)

    img = Image.fromarray(mask)
    if args.is_resize:
        # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS,
        # which was removed in Pillow 10. It smooths edges, so a resized mask
        # contains intermediate grey values rather than strictly 0/255.
        img = img.resize((args.resize_dim, args.resize_dim), Image.LANCZOS)
    # NOTE(review): the output keeps the input's .jpg name, so the mask is
    # JPEG-compressed (lossy) and may not stay strictly binary; consider PNG
    # if downstream consumers allow a different filename.
    img.save(os.path.join(args.out_dir, 'mask_' + file_name))