-
Notifications
You must be signed in to change notification settings - Fork 11
/
test.py
78 lines (60 loc) · 2.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import keras
import sys
from keras.models import Model
from keras.utils import plot_model
from keras.models import load_model
from PIL import Image
import matplotlib.pyplot as plt
from random import randint
import imageio
from skimage.util.shape import view_as_blocks
# In IPython/Jupyter, render figures inline. A bare `%matplotlib inline` line
# is a SyntaxError when this file is run with plain `python test.py`, so the
# magic is invoked through the IPython API only when an IPython shell exists.
try:
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821
except NameError:
    pass  # plain Python interpreter: keep matplotlib's default backend

# Test the model on sample (unseen) images and plot the input and output images.
# Usage: python test.py <test_images.npy> <model.hdf5>
if len(sys.argv) < 3:
    sys.exit("usage: python test.py <test_images.npy> <model.hdf5>")

# Load test images
test_images = np.load(sys.argv[1])
# Load model; compile=False because it is only used for predict() below
model = load_model(sys.argv[2], compile=False)
# Normalize inputs
def normalize_batch(imgs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Perform channel-wise z-score normalization: ``(imgs - mean) / std``.

    The defaults are the standard ImageNet per-channel RGB statistics, which
    expect pixel values scaled to [0, 1]. ``mean`` and ``std`` are 1-D and so
    broadcast against the last axis of ``imgs`` (channels-last layout).

    Args:
        imgs: array-like image batch whose last axis is the channel axis.
        mean: per-channel means to subtract.
        std: per-channel standard deviations to divide by.

    Returns:
        A float ndarray of the broadcast shape (normally ``imgs.shape``).
    """
    return (np.asarray(imgs) - np.asarray(mean)) / np.asarray(std)
# Denormalize outputs
def denormalize_batch(
    imgs,
    should_clip=True,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Invert channel-wise z-score normalization: ``imgs * std + mean``.

    The defaults are the standard ImageNet per-channel RGB statistics, so with
    default arguments this undoes a default ``(imgs - mean) / std`` step.

    Args:
        imgs: array-like normalized image batch, channels on the last axis.
        should_clip: if True, clip the result to [0, 1] so it is a valid
            float image (model outputs can land slightly out of range).
        mean: per-channel means to add back.
        std: per-channel standard deviations to multiply by.

    Returns:
        A float ndarray of the broadcast shape (normally ``imgs.shape``).
    """
    imgs = np.asarray(imgs) * np.asarray(std) + np.asarray(mean)
    if should_clip:
        imgs = np.clip(imgs, 0, 1)
    return imgs
# Number of secret/cover pairs pushed through the model in one batch.
BATCH_SIZE = 4


def _to_uint8(imgs):
    """Map float images in [0, 1] to uint8 in [0, 255].

    Values are clipped first so out-of-range floats cannot wrap around when
    cast, and rounded rather than truncated toward zero.
    """
    return np.round(np.clip(imgs, 0.0, 1.0) * 255.0).astype(np.uint8)


# Sampling is without replacement inside each batch, so the test set must
# contain at least BATCH_SIZE images (np.random.choice raises otherwise).
if len(test_images) < BATCH_SIZE:
    sys.exit("need at least %d test images, got %d" % (BATCH_SIZE, len(test_images)))

# Load images as batch. The secret and cover batches are drawn independently,
# so the same image may appear in both.
secretin = test_images[np.random.choice(len(test_images), size=BATCH_SIZE, replace=False)]
coverin = test_images[np.random.choice(len(test_images), size=BATCH_SIZE, replace=False)]

# Perform batch prediction: inputs are [secret, cover], outputs are unpacked
# as (cover, secret); both outputs are still in normalized space.
coverout, secretout = model.predict([normalize_batch(secretin), normalize_batch(coverin)])

# Postprocess outputs: undo normalization (clipped to [0, 1]) and convert to
# uint8. The batch axis is kept (no np.squeeze) so plot() can always index
# im[i], even if BATCH_SIZE is set to 1.
coverout = _to_uint8(denormalize_batch(coverout))
secretout = _to_uint8(denormalize_batch(secretout))

# Convert input images to UINT8 format (0-255).
# NOTE(review): assumes test images are floats in [0, 1] — confirm for the dataset.
coverin = _to_uint8(coverin)
secretin = _to_uint8(secretin)
# Plot the images
def plot(im, title, max_images=4):
    """Draw up to ``max_images`` images from a batch side by side in one row.

    Args:
        im: sequence/array of images indexed along its first axis; each
            ``im[i]`` is passed straight to ``Axes.imshow`` (callers in this
            file pass uint8 batches, presumably (batch, H, W, 3) — confirm).
        title: subplot title prefix; the 1-based image number is appended.
        max_images: maximum number of images to draw; fewer are drawn when
            ``im`` is shorter, instead of raising IndexError.

    Returns:
        The created matplotlib figure.
    """
    count = min(len(im), max_images)
    fig = plt.figure(figsize=(20, 20))
    for i in range(count):
        sub = fig.add_subplot(1, count, i + 1)
        sub.title.set_text(title + " " + str(i + 1))
        sub.imshow(im[i])
    return fig
# Plot secret input and output
plot(secretin, "Secret Input")
plot(secretout, "Secret Output")
# Plot cover input and output
plot(coverin, "Cover Input")
plot(coverout, "Cover Output")
# Display all figures. Without this, a plain `python test.py` run exits
# without ever showing them; inline notebook backends are unaffected.
plt.show()
# Sample run: python test.py test/testdata.npy checkpoints/steg_model-06-0.03.hdf5