Skip to content

Commit

Permalink
modify validate & test & inference
Browse files Browse the repository at this point in the history
  • Loading branch information
yjyjy131 committed Jul 21, 2022
1 parent ede38f1 commit 70c987d
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 197 deletions.
2 changes: 1 addition & 1 deletion SPFCN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def slot_network_training(data_num, batch_size, valid_data_num, valid_batch_size
epoch_limit=epoch, save_path="parameters/")


def slot_network_testing(parameter_path, data_num, batch_size, input_res, device_id=0, num_workers=0):
def slot_network_testing(parameter_path, data_num, batch_size, input_res, device_id=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id)
Expand Down
6 changes: 3 additions & 3 deletions SPFCN/model/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def __init__(self, device_id: int, **kwargs):
self.network.load_state_dict(torch.load(self.config['parameter_path'], map_location=self.device))
except RuntimeError:
net_path = self.config['parameter_path'].replace('.pkl', '.pt')
network = torch.load(net_path, map_location=self.device)
self.network= dill.loads(network)
network = torch.load(self.config['parameter_path'], map_location=self.device)
self.network = dill.loads(network)
self.network.eval()

def update_config(self, **kwargs):
Expand Down Expand Up @@ -82,5 +82,5 @@ def __call__(self, bev_image):
(mark_map[i, 1] + delta_x, mark_map[i, 0] + delta_y)))
break

print(f'mark : {mark.shape} / mark_prediction : {mark_prediction.shape} / mark_map : {mark_map.shape} / direction : {direction.shape} / item : {item.shape} / distance_map : {distance_map.shape} / slot_list : {slot_list.shape}')
print(f'slot_list : {slot_list}')
return slot_list
26 changes: 13 additions & 13 deletions SPFCN/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@


@torch.no_grad()
def auto_test(dataset,
network,
device_id: int = 0,
load_path: str = None):
def auto_test(dataset,
network,
device_id: int = 0,
load_path: str = None):
device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
try:
assert os.path.exists(net_path)
network.load_state_dict(torch.load(net_path, map_location=device))
except RuntimeError:
net_path = load_path.replace('pkl', '.pt')

try:
assert os.path.exists(load_path)
network.load_state_dict(torch.load(load_path, map_location=device))
except RuntimeError:
net_path = load_path.replace('pkl', 'pt')
assert os.path.exists(net_path)
network = torch.load(net_path, map_location=device)
network= dill.loads(network)
network = torch.load(net_path, map_location=device)
network = dill.loads(network)

network.eval()

auto_tester = Tester(dataset, network, device)
Expand Down
8 changes: 4 additions & 4 deletions SPFCN/test/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def step(self):
print("ZeroDivisionError at slot_gt_count")

print("\rIndex: {}, Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(index, mark_precision, mark_recall, slot_precision, slot_recall), end='')
print('\r' + ' ' * 50, end="")
print("Total score - Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(index, mark_precision, mark_recall, slot_precision, slot_recall))
# print('\r' + ' ' * 50, end="")
print("Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(mark_precision, mark_recall, slot_precision, slot_recall))

def get_network_inference_time(self):
Expand Down Expand Up @@ -166,5 +166,5 @@ def get_inference_time(self, foo):
testing_image, _ = self.dataset.next()
index += 1
print("\rIndex: {}, Inference Time: {:.1f}ms".format(index, 1e3 * time_step / index), end="")
print('\r' + ' ' * 40, end="")
# print('\r' + ' ' * 40, end="")
return "Inference Time: {:.1f}ms".format(1e3 * time_step / index)
168 changes: 0 additions & 168 deletions SPFCN/train/validator

This file was deleted.

1 change: 0 additions & 1 deletion SPFCN_Light/slot_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def __init__(self, file_name, device_id=0):
self.device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
self._network = Hourglass([1, 40, 56, 55, 60, 59, 61, 59, 60, 55, 56, 40, 3]).to(self.device)
self._network.load_state_dict(torch.load(file_name), strict=True)
print("Success load file {}.".format(file_name))
self._network.eval()

self.temp_h = torch.ones((1, 224)).to(self.device)
Expand Down
Binary file added light_result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 5 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@

### ORIGINAL VERSION ###
# Train model
# data_num=6500, batch_size=32, valid_data_num=1000, valid_batch_size=32,
# data_num=1000, batch_size=32, valid_data_num=100, valid_batch_size=32
slot_network_training(data_num=1000, batch_size=32, valid_data_num=100, valid_batch_size=32, epoch=80, input_res=224, device_id=0, num_workers=1)
slot_network_training(data_num=6500, batch_size=50, valid_data_num=1500, valid_batch_size=32, epoch=100, input_res=224, device_id=0, num_workers=4)

# Test model
params_path = './parameters/merge_bn_epoch10_loss4.pkl'
slot_network_testing(parameter_path=params_path, data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=1)
params_path = './parameters/merge_bn_epoch80_loss4.pkl'
slot_network_testing(parameter_path=params_path, data_num=1500, batch_size=50, input_res=224, device_id=0)

# Load detector
detector = SlotDetector(device_id=0, dim_encoder=[32, 44, 64, 92, 128], parameter_path=params_path)
Expand All @@ -37,7 +35,7 @@
cv2.line(current_frame, pt0, pt3, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt1, pt2, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
cv2.imwrite("result.jpg", current_frame)
cv2.imwrite("original_result.jpg", current_frame)


### LIGHT VERSION ###
Expand All @@ -63,4 +61,4 @@
cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
cv2.imwrite("result.jpg", current_frame)
cv2.imwrite("light_result.jpg", current_frame)
Binary file added original_result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 70c987d

Please sign in to comment.