Skip to content

Commit

Permalink
Merge branch 'yjyjy131_dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
yjyjy131 committed Jun 26, 2022
2 parents b870cb7 + b7047ff commit 3a2d6bf
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 36 deletions.
4 changes: 2 additions & 2 deletions SPFCN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def slot_network_training(data_num, batch_size, epoch, input_res, device_id=0,
epoch_limit=epoch, save_path="parameters/")


def slot_network_testing(params_path, data_num, batch_size, input_res, device_id=0, num_workers=0):
def slot_network_testing(parameter_path, data_num, batch_size, input_res, device_id=0, num_workers=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id)

# Test
auto_test(get_testing_set(data_num, batch_size, input_res, device_id, num_workers=0), net, device_id, params_path)
auto_test(get_testing_set(data_num, batch_size, input_res, device_id, num_workers=0), net, device_id, parameter_path)
4 changes: 2 additions & 2 deletions SPFCN/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ def auto_test(dataset,
device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)

try:
net_path = load_path + '.pkl'
assert os.path.exists(net_path)
network.load_state_dict(torch.load(net_path, map_location=device))
except RuntimeError:
net_path = load_path + '.pt'
net_path = load_path.replace('pkl', '.pt')
assert os.path.exists(net_path)
network = torch.load(net_path, map_location=device)
network= dill.loads(network)

network.eval()

auto_tester = Tester(dataset, network, device)
Expand Down
61 changes: 29 additions & 32 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,16 @@

if __name__ == "__main__":

# auto train
# slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)

# load trained model
model_path = './parameters/merge_bn_epoch10_loss4.pkl'


# auto test
# slot_network_testing(params_path='./parameters/merge_bn_epoch10_loss4', data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=0)
### ORIGINAL VERSION ###
# Train model
slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)

# Test model
params_path = './parameters/merge_bn_epoch10_loss4.pkl'
slot_network_testing(parameter_path=params_path, data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=0)

# Load detector
detector = SlotDetector(device_id=0, dim_encoder=[32, 44, 64, 92, 128], parameter_path=model_path)
detector = SlotDetector(device_id=0, dim_encoder=[32, 44, 64, 92, 128], parameter_path=params_path)

# Visualize the merge image with result
current_frame = cv2.imread("demo.jpg")
Expand All @@ -40,28 +38,27 @@
cv2.imwrite("result.jpg", current_frame)



### LIGHT VERSION ###
# detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)
detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)

# for frame_index in range(1000):
# tic = time.time()
# inference_image = cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY)
# inference_result = detector(inference_image)
# toc = time.time()
# time_span = toc - tic
# infer_fps = 1 / (time_span + 1e-5)
# print("Frame:{:d}, Time used:{:.3f}, FPS:{:.3f}".format(frame_index, time_span * 1000, infer_fps), end='\r')
for frame_index in range(1000):
tic = time.time()
inference_image = cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY)
inference_result = detector(inference_image)
toc = time.time()
time_span = toc - tic
infer_fps = 1 / (time_span + 1e-5)
print("Frame:{:d}, Time used:{:.3f}, FPS:{:.3f}".format(frame_index, time_span * 1000, infer_fps), end='\r')

# resolution = current_frame.shape[0]
# for detect_result in inference_result:
# pt0 = (int(detect_result[0][0] * resolution / 224), int(detect_result[0][1] * resolution / 224))
# pt1 = (int(detect_result[1][0] * resolution / 224), int(detect_result[1][1] * resolution / 224))
# pt2 = (int(detect_result[2][0] * resolution / 224), int(detect_result[2][1] * resolution / 224))
# pt3 = (int(detect_result[3][0] * resolution / 224), int(detect_result[3][1] * resolution / 224))
# cv2.line(current_frame, pt0, pt1, (0, 255, 0), thickness=2)
# cv2.line(current_frame, pt0, pt2, (0, 0, 255), thickness=2)
# cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
# cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
# cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
# cv2.imwrite("result.jpg", current_frame)
resolution = current_frame.shape[0]
for detect_result in inference_result:
pt0 = (int(detect_result[0][0] * resolution / 224), int(detect_result[0][1] * resolution / 224))
pt1 = (int(detect_result[1][0] * resolution / 224), int(detect_result[1][1] * resolution / 224))
pt2 = (int(detect_result[2][0] * resolution / 224), int(detect_result[2][1] * resolution / 224))
pt3 = (int(detect_result[3][0] * resolution / 224), int(detect_result[3][1] * resolution / 224))
cv2.line(current_frame, pt0, pt1, (0, 255, 0), thickness=2)
cv2.line(current_frame, pt0, pt2, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
cv2.imwrite("result.jpg", current_frame)

0 comments on commit 3a2d6bf

Please sign in to comment.