forked from lumina37/rotate-captcha-crack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_RotNet.py
52 lines (40 loc) · 1.74 KB
/
test_RotNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
from pathlib import Path
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from rotate_captcha_crack.common import device
from rotate_captcha_crack.criterion import dist_onehot
from rotate_captcha_crack.dataset import ImgTsSeqFromPath, ValDataset
from rotate_captcha_crack.helper import default_num_workers
from rotate_captcha_crack.model import RotNet, WhereIsMyModel
def main() -> None:
    """Evaluate a trained RotNet checkpoint on a folder of PNG images.

    Loads ``best.pth`` from the model directory selected by ``--index``
    (resolved via ``WhereIsMyModel``), runs the model over every ``*.png``
    file in the dataset root, and prints the mean angular error.

    Exits with an argparse error if the batch size is not positive or no
    images are found (previously this surfaced as a ZeroDivisionError).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--index", "-i", type=int, default=-1, help="Use which index")
    parser.add_argument(
        "--dataset-root",
        type=Path,
        default=Path("./datasets/use_img"),
        help="Directory containing the *.png test images",
    )
    parser.add_argument("--batch-size", type=int, default=128, help="Evaluation batch size")
    opts = parser.parse_args()

    if opts.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Sorted for a reproducible batch composition; the mean itself is order-independent.
    img_paths = sorted(opts.dataset_root.glob('*.png'))
    if not img_paths:
        parser.error(f"no *.png images found in {opts.dataset_root}")

    test_dataset = ValDataset(ImgTsSeqFromPath(img_paths))
    # Keep the final partial batch so every test image is evaluated; the
    # running average below is weighted by batch size so it stays exact.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=opts.batch_size,
        num_workers=default_num_workers(),
        drop_last=False,
    )

    model = RotNet(train=False)
    model_path = WhereIsMyModel(model).with_index(opts.index).model_dir / "best.pth"
    print(f"Use model: {model_path}")
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(str(model_path), map_location=device))
    model.to(device=device)
    model.eval()

    total_degree_diff = 0.0
    sample_count = 0
    with torch.no_grad():
        for source, target in test_dataloader:
            source = source.to(device=device)
            target = target.to(device=device)
            predict: Tensor = model(source)
            # NOTE(review): assumes dist_onehot returns the batch-mean distance
            # normalized to [0, 1); multiplying by the class count (dim 1 of
            # the prediction) converts it to class units, i.e. degrees only
            # when the model has 360 classes — confirm against the criterion.
            degree_diff = float(dist_onehot(predict, target)) * predict.shape[1]
            batch_size = source.shape[0]
            # Weight each batch by its size so a short final batch does not skew the mean.
            total_degree_diff += degree_diff * batch_size
            sample_count += batch_size

    print(f"test_loss: {total_degree_diff/sample_count:.4f} degrees")


if __name__ == '__main__':
    main()