Skip to content

Commit 48aec99

Browse files
author
GE ZHENG
committed
add test code for inference
1 parent 1c99970 commit 48aec99

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import os
4+
from scipy import misc
5+
from matting_numpy_unpool import generate_trimap
6+
import argparse
7+
import sys
8+
9+
g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
10+
11+
def main(args):
12+
13+
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_fraction)
14+
with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
15+
saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
16+
saver.restore(sess,tf.train.latest_checkpoint('./model'))
17+
image_batch = tf.get_collection('image_batch')[0]
18+
GT_trimap = tf.get_collection('GT_trimap')[0]
19+
pred_mattes = tf.get_collection('pred_mattes')[0]
20+
21+
rgb = misc.imread(args.rgb)
22+
alpha = misc.imread(args.alpha,'L')
23+
trimap = generate_trimap(np.expand_dims(np.copy(alpha),2),np.expand_dims(alpha,2))[:,:,0]
24+
origin_shape = alpha.shape
25+
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3]).astype(np.float32)-g_mean,0)
26+
trimap = np.expand_dims(np.expand_dims(misc.imresize(trimap.astype(np.uint8),[320,320],interp = 'nearest').astype(np.float32),2),0)
27+
28+
feed_dict = {image_batch:rgb,GT_trimap:trimap}
29+
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
30+
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
31+
misc.imshow(final_alpha)
32+
misc.imsave('./alpha.png',final_alpha)
33+
34+
def parse_arguments(argv):
35+
parser = argparse.ArgumentParser()
36+
37+
parser.add_argument('--alpha', type=str,
38+
help='input alpha')
39+
parser.add_argument('--rgb', type=str,
40+
help='input rgb')
41+
parser.add_argument('--trimap_dilation', type=int,
42+
help='the kernel size that used to generate trimap from alpha, 20~40 are suggested',default = 30)
43+
parser.add_argument('--gpu_fraction', type=float,
44+
help='how much gpu is needed, usually 4G is enough',default = 0.4)
45+
return parser.parse_args(argv)
46+
47+
48+
if __name__ == '__main__':
49+
main(parse_arguments(sys.argv[1:]))
50+

0 commit comments

Comments
 (0)