-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest.py
62 lines (52 loc) · 1.78 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""This module is used to test the YeNet model."""
import os
from glob import glob

import imageio as io
import numpy as np
import torch

from model import YeNet
# Number of images per forward pass; each batch is built from
# TEST_BATCH_SIZE // 2 stego/cover pairs.
TEST_BATCH_SIZE = 40
# Placeholder locations of the cover and stego test images -- point these at
# the real dataset before running.
COVER_PATH, STEGO_PATH = "/path/to/cover/images/", "/path/to/stego/images/"
# Trained weights; the script reads the "model_state_dict" entry from it.
CHKPT = "./checkpoints/YeNet_model_weights.pt"
# Gather the test image file names. COVER_PATH / STEGO_PATH name directories,
# and glob() only expands wildcard patterns -- globbing the bare directory
# path returns just the directory itself -- so match "*" inside each one.
# glob() returns names in arbitrary order; sorting keeps runs deterministic.
cover_image_names = sorted(glob(os.path.join(COVER_PATH, "*")))
stego_image_names = sorted(glob(os.path.join(STEGO_PATH, "*")))

# Build the network on the GPU and restore the trained weights, which the
# checkpoint stores under the "model_state_dict" key.
model = YeNet().cuda()
ckpt = torch.load(CHKPT)
model.load_state_dict(ckpt["model_state_dict"])
# Switch to evaluation mode so any dropout / batch-norm layers in YeNet use
# their inference-time behaviour instead of their training-time behaviour.
model.eval()
# Reusable CPU staging buffer for one batch of single-channel images; every
# image read below must fit its 256x256 slot.
# pylint: disable=E1101
images = torch.empty((TEST_BATCH_SIZE, 1, 256, 256), dtype=torch.float)
# pylint: enable=E1101
test_accuracy = []  # per-batch accuracy, in percent
num_correct = 0  # correctly classified images across all batches
num_seen = 0  # images evaluated across all batches
pairs_per_batch = TEST_BATCH_SIZE // 2
# Pair each stego image (label 1) with a cover image (label 0). zip() stops at
# the shorter list, so unequal cover/stego counts cannot index past the end.
pairs = list(zip(stego_image_names, cover_image_names))
# Inference only: no_grad() avoids building the autograd graph.
with torch.no_grad():
    for start in range(0, len(pairs), pairs_per_batch):
        chunk = pairs[start : start + pairs_per_batch]
        # Interleave as stego, cover, stego, cover, ... with matching labels.
        batch = [name for pair in chunk for name in pair]
        batch_labels = [1, 0] * len(chunk)
        # The final batch may hold fewer images than the buffer.
        batch_len = len(batch)
        # pylint: disable=E1101
        for i, name in enumerate(batch):
            images[i, 0, :, :] = torch.tensor(io.imread(name))
        image_tensor = images[:batch_len].cuda()
        batch_labels = torch.tensor(batch_labels, dtype=torch.long).cuda()
        # pylint: enable=E1101
        outputs = model(image_tensor)
        prediction = outputs.argmax(dim=1)
        correct = prediction.eq(batch_labels).sum().item()
        num_correct += correct
        num_seen += batch_len
        test_accuracy.append(correct * 100.0 / batch_len)
if num_seen == 0:
    raise SystemExit(
        f"No cover/stego image pairs found in {COVER_PATH!r} and {STEGO_PATH!r}"
    )
# Overall accuracy weights every image equally; a plain mean of the per-batch
# accuracies would over-weight a smaller final batch.
print(f"test_accuracy = {num_correct * 100.0 / num_seen:.2f}")