Skip to content

Commit

Permalink
test code
Browse files Browse the repository at this point in the history
  • Loading branch information
yjyjy131 committed Jun 26, 2022
1 parent 0ebce7e commit 8632456
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 47 deletions.
14 changes: 6 additions & 8 deletions SPFCN/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from torch.backends import cudnn

from .dataset import get_training_set, get_validating_set
from .dataset import get_training_set, get_validating_set, get_testing_set
from .model.network import SlotNetwork
from .train import auto_train, auto_validate
from .test import auto_test


def setup(seed):
Expand All @@ -23,13 +24,10 @@ def slot_network_training(data_num, batch_size, epoch, input_res, device_id=0,
epoch_limit=epoch, save_path="parameters/")


# TODO
def slot_network_testing(model_path, device_id=0):
def slot_network_testing(params_path, data_num, batch_size, input_res, device_id=0, num_workers=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id)

#load model
net = SlotNetwork([32, 44, 64, 92, 128], device_id=device_id)

#Test
auto_test(get_testing_set(), net, ...)
# Test
auto_test(get_testing_set(data_num, batch_size, input_res, device_id, num_workers=0), net, device_id, params_path)
12 changes: 6 additions & 6 deletions SPFCN/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ def get_validating_set(data_size: int,
dataset=vps_set, batch_size=batch_size, shuffle=False)


# TODO
def get_testing_set(data_size: int,
batch_size: int,
resolution: int = 224,
device_id: int = 0):
batch_size: int,
resolution: int = 224,
device_id: int = 0,
num_workers: int = 0.):
assert 0 < data_size < 1538 and 0 < batch_size and 0 < resolution
vps_set = VisionParkingSlotDataset(
image_path="./data/testing/image/",
label_path="./data/testing/label/",
data_size=data_size,
resolution=resolution)
if device_id < 0:
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=4)
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=num_workers)
else:
return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
dataset=vps_set, batch_size=batch_size, shuffle=False)
dataset=vps_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)


1 change: 1 addition & 0 deletions SPFCN/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .network import SlotNetwork
from .detector import SlotDetector
2 changes: 1 addition & 1 deletion SPFCN/model/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class SlotDetector(object):
def __init__(self, device_id: int, **kwargs):
self.device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
self.config = self.update_config(**kwargs)

print(self.config)
self.network = SlotNetwork(self.config['dim_encoder'], device_id)
self.network.merge()
self.network.load_state_dict(torch.load(self.config['parameter_path'], map_location=self.device))
Expand Down
28 changes: 28 additions & 0 deletions SPFCN/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
import torch
import dill
from .tester import Tester


@torch.no_grad()
def auto_test(dataset,
network,
device_id: int = 0,
load_path: str = None):
device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)

try:
net_path = load_path + '.pkl'
assert os.path.exists(net_path)
network.load_state_dict(torch.load(net_path, map_location=device))
except RuntimeError:
net_path = load_path + '.pt'
assert os.path.exists(net_path)
network = torch.load(net_path, map_location=device)
network=dill.loads(network)
network.eval()

auto_tester = Tester(dataset, network, device)
auto_tester.step()
auto_tester.get_network_inference_time()
auto_tester.get_detector_inference_time()
170 changes: 170 additions & 0 deletions SPFCN/test/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from time import time

import cv2
import numpy as np
import torch


class Tester(object):
def __init__(self, dataset, network, device):
self.dataset = dataset
self.network = network.to(device)
self.device = device

self.const_h = torch.ones((1, 224)).to(device)
self.const_w = torch.ones((224, 1)).to(device)
self.mark_threshold = 0.1
self.direct_threshold = 0.95
self.distance_threshold = 40
self.elliptic_coefficient = 1.6

self.mgt_threshold = 4.6
self.iou_threshold = 0.95

def step(self):
self.dataset.refresh()
testing_image, testing_label = self.dataset.next()
index = 0
mark_gt_count, mark_re_count, mark_co_count = 0, 0, 0
slot_gt_count, slot_re_count, slot_co_count = 0, 0, 0
while testing_image is not None and testing_label is not None:
gt_mark = testing_label[0:1, 0:1]
gt_direction = testing_label[0:1, 1:]
gt_mark_count, gt_mark_map, gt_slot_count, gt_slot_list = \
self.slot_detect(gt_mark[0, 0], gt_direction, True)
mark_gt_count += gt_mark_count
slot_gt_count += gt_slot_count

re_mark, re_direction = self.network(testing_image)
re_mark_count, re_mark_map, re_slot_count, re_slot_list = \
self.slot_detect(re_mark[0, 0], re_direction, False)
mark_re_count += re_mark_count
slot_re_count += re_slot_count

for ind in range(re_mark_count):
re_x = int(re_mark_map[ind, 0])
re_y = int(re_mark_map[ind, 1])
angle = re_mark_map[ind, 2] * gt_direction[0, 0, re_x, re_y]
angle += re_mark_map[ind, 3] * gt_direction[0, 1, re_x, re_y]
distance = gt_mark[0, 0, re_x - 1:re_x + 2, re_y - 1:re_y + 2].sum()
if angle > self.direct_threshold and distance > self.mgt_threshold:
mark_co_count += 1

for ind in range(re_slot_count):
re_pt = re_slot_list[ind]
for jnd in range(gt_slot_count):
gt_pt = gt_slot_list[jnd]
mask_gt = cv2.fillConvexPoly(np.zeros([224, 224], dtype="uint8"), np.array(gt_pt), 1)
mask_re = cv2.fillConvexPoly(np.zeros([224, 224], dtype="uint8"), np.array(re_pt), 1)
count_and = np.sum(cv2.bitwise_and(mask_re, mask_gt))
count_or = np.sum(cv2.bitwise_or(mask_re, mask_gt))
if count_and > self.iou_threshold * count_or:
slot_co_count += 1

testing_image, testing_label = self.dataset.next()
index += 1

mark_precision, mark_recall, slot_precision, slot_recall = -1, -1, -1, -1

try:
mark_precision = mark_co_count / mark_re_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")

try:
mark_recall = mark_co_count / mark_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")

try:
slot_precision = slot_co_count / slot_re_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")

try:
slot_recall = slot_co_count / slot_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")

print("\rIndex: {}, Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(index, mark_precision, mark_recall, slot_precision, slot_recall), end='')
print('\r' + ' ' * 50, end="")
print("Total score - Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(mark_precision, mark_recall, slot_precision, slot_recall))

def get_network_inference_time(self):
def foo(img):
_, _ = self.network(img)

print('\rNetwork ' + self.get_inference_time(foo))

def get_detector_inference_time(self):
def foo(img):
mark, direction = self.network(img)
self.slot_detect(mark[0, 0], direction, False)

print('\rDetector ' + self.get_inference_time(foo))

def slot_detect(self, mark, direction, gt=False):
# Mark detection
if gt:
mark_prediction = torch.nonzero(mark == 1)
else:
mark_prediction = torch.nonzero((mark > self.mark_threshold) *
(mark > torch.cat((mark[1:, :], self.const_h), dim=0)) *
(mark > torch.cat((self.const_h, mark[:-1, :]), dim=0)) *
(mark > torch.cat((mark[:, 1:], self.const_w), dim=1)) *
(mark > torch.cat((self.const_w, mark[:, :-1]), dim=1)))

mark_count = len(mark_prediction)
mark_map = torch.zeros([mark_count, 4]).to(self.device)
mark_map[:, 0:2] = mark_prediction
for item in mark_map:
item[2:] = direction[0, :, item[0].int(), item[1].int()]

# Distance map generate
distance_map = torch.zeros([mark_count, mark_count]).to(self.device)
for i in range(0, mark_count - 1):
for j in range(i + 1, mark_count):
if mark_map[i, 2] * mark_map[j, 2] + mark_map[i, 3] * mark_map[j, 3] > self.direct_threshold:
distance = torch.pow(torch.pow(mark_map[i, 0] - mark_map[j, 0], 2) +
torch.pow(mark_map[i, 1] - mark_map[j, 1], 2), 0.5)
distance_map[i, j] = distance
distance_map[j, i] = distance

# Slot check
slot_list = []
for i in range(0, mark_count - 1):
for j in range(i + 1, mark_count):
distance = distance_map[i, j]
if distance > self.distance_threshold and \
(distance_map[i] + distance_map[j] < self.elliptic_coefficient * distance).sum() == 2:
slot_length = 120 if distance < 80 else 60
vx = torch.abs(mark_map[i, 0] - mark_map[j, 0]) / distance
vy = torch.abs(mark_map[i, 1] - mark_map[j, 1]) / distance
delta_x = -slot_length * vx if mark_map[i, 2] < 0 else slot_length * vx
delta_y = -slot_length * vy if mark_map[i, 3] < 0 else slot_length * vy

slot_list.append(((int(mark_map[i, 1]), int(mark_map[i, 0])),
(int(mark_map[j, 1]), int(mark_map[j, 0])),
(int(mark_map[j, 1] + delta_x), int(mark_map[j, 0] + delta_y)),
(int(mark_map[i, 1] + delta_x), int(mark_map[i, 0] + delta_y))))
break

return mark_count, mark_map, len(slot_list), slot_list

def get_inference_time(self, foo):
self.dataset.refresh()
testing_image, _ = self.dataset.next()
foo(testing_image)
index = 0
time_step = 0
while testing_image is not None:
timestamp = time()
foo(testing_image)
time_step += time() - timestamp
testing_image, _ = self.dataset.next()
index += 1
print("\rIndex: {}, Inference Time: {:.1f}ms".format(index, 1e3 * time_step / index), end="")
print('\r' + ' ' * 40, end="")
return "Inference Time: {:.1f}ms".format(1e3 * time_step / index)
11 changes: 6 additions & 5 deletions SPFCN/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import dill
from datetime import datetime

import torch
Expand All @@ -18,10 +19,10 @@ def auto_validate(dataset,
network.load_state_dict(torch.load(load_path, map_location=device))
network.eval()

auto_validator = Validator(dataset, network, device)
auto_validator.step()
auto_validator.get_network_inference_time()
auto_validator.get_detector_inference_time()
auto_tester = Validator(dataset, network, device)
auto_tester.step()
auto_tester.get_network_inference_time()
auto_tester.get_detector_inference_time()


def auto_train(dataset,
Expand Down Expand Up @@ -76,7 +77,7 @@ def auto_train(dataset,
stage = "merge_bn"

torch.save(network.state_dict(), "%s%s_epoch%d_loss%d.pkl" % (save_path, stage, epoch, int(epoch_loss)))
torch.save()
torch.save(dill.dumps(network), "%s%s_epoch%d_loss%d.pt" % (save_path, stage, epoch, int(epoch_loss)))

curr = datetime.now()
info = '{:02d}:{:02d}:{:02d} '.format(curr.hour, curr.minute, curr.minute)
Expand Down
71 changes: 44 additions & 27 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,55 +5,72 @@

from SPFCN_Light.slot_detector import Detector
from SPFCN import slot_network_training, slot_network_testing
from SPFCN.model import SlotNetwork, SlotDetector

if __name__ == "__main__":

# auto train
slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)
# slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)

# load trained model
encoder = [32, 44, 64, 92, 128]
net = SlotNetwork(encoder, device_id=0)
model_path = './parameters/merge_bn_epoch10_loss4.pt'
model = torch.load(model_path)
model = dill.loads(model)

# auto test
model_path = './SPFCN/'
slot_network_testing(model_path, device_id=0)
slot_network_testing(params_path='./parameters/merge_bn_epoch10_loss4', data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=0)

# Load detector
detector = LoadDetector()
config = {
'dim_encoder': encoder,
'parameter_path': model_path,
}
detector = SlotDetector(0, config)

# Visualize the merge image with result
current_frame = cv2.imread("demo.jpg")
inference_image = cv2.resize(current_frame, (224, 224))
inference_result = detector(inference_image)



### LIGHT VERSION ###
# Read image
current_frame = cv2.imread("demo.jpg")

# Initial model
detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)

# Start the detection
for frame_index in range(1000):
# Get the result
tic = time.time()
inference_image = cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY)
inference_result = detector(inference_image)
toc = time.time()
time_span = toc - tic
infer_fps = 1 / (time_span + 1e-5)
print("Frame:{:d}, Time used:{:.3f}, FPS:{:.3f}".format(frame_index, time_span * 1000, infer_fps), end='\r')

# Visualize the merge image with result
resolution = current_frame.shape[0]
for detect_result in inference_result:
pt0 = (int(detect_result[0][0] * resolution / 224), int(detect_result[0][1] * resolution / 224))
pt1 = (int(detect_result[1][0] * resolution / 224), int(detect_result[1][1] * resolution / 224))
pt2 = (int(detect_result[2][0] * resolution / 224), int(detect_result[2][1] * resolution / 224))
pt3 = (int(detect_result[3][0] * resolution / 224), int(detect_result[3][1] * resolution / 224))
cv2.line(current_frame, pt0, pt1, (0, 255, 0), thickness=2)
cv2.line(current_frame, pt0, pt2, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt0, pt3, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt1, pt2, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))

cv2.imwrite("result.jpg", current_frame)



### LIGHT VERSION ###
# detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)

# for frame_index in range(1000):
# tic = time.time()
# inference_image = cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY)
# inference_result = detector(inference_image)
# toc = time.time()
# time_span = toc - tic
# infer_fps = 1 / (time_span + 1e-5)
# print("Frame:{:d}, Time used:{:.3f}, FPS:{:.3f}".format(frame_index, time_span * 1000, infer_fps), end='\r')

# resolution = current_frame.shape[0]
# for detect_result in inference_result:
# pt0 = (int(detect_result[0][0] * resolution / 224), int(detect_result[0][1] * resolution / 224))
# pt1 = (int(detect_result[1][0] * resolution / 224), int(detect_result[1][1] * resolution / 224))
# pt2 = (int(detect_result[2][0] * resolution / 224), int(detect_result[2][1] * resolution / 224))
# pt3 = (int(detect_result[3][0] * resolution / 224), int(detect_result[3][1] * resolution / 224))
# cv2.line(current_frame, pt0, pt1, (0, 255, 0), thickness=2)
# cv2.line(current_frame, pt0, pt2, (0, 0, 255), thickness=2)
# cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
# cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
# cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
# cv2.imwrite("result.jpg", current_frame)

0 comments on commit 8632456

Please sign in to comment.