diff --git a/SPFCN_Light/slot_detector.py b/SPFCN_Light/slot_detector.py
index c9044a2..518323f 100644
--- a/SPFCN_Light/slot_detector.py
+++ b/SPFCN_Light/slot_detector.py
@@ -4,11 +4,11 @@
 Example:
     model_parameter_path = "..."
-    detector = Detector(model_parameter_path)
+    detector = Detector(model_parameter_path, device_id)
     inference_result = detector(inference_image)
 
 Data form:
-    inference_image: Gray Image, torch.Tensor with Size([224, 224])
+    inference_image: Gray Image, numpy.array with Size([224, 224])
     inference_result: list of slot_points
 
 Requirements:
@@ -16,7 +16,7 @@
     numpy == 1.16.5
     opencv == 3.4.2
 
-2019.11.28
+2020.09.15
 """
 from cv2 import line
@@ -68,106 +68,98 @@ def forward(self, feature):
 
 
 class Detector(object):
-    def __init__(self, file_name):
+    def __init__(self, file_name, device_id=0):
         super().__init__()
-        self._network = Hourglass([1, 40, 56, 55, 60, 59, 61, 59, 60, 55, 56, 40, 3]).cuda()
+        self.device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
+        self._network = Hourglass([1, 40, 56, 55, 60, 59, 61, 59, 60, 55, 56, 40, 3]).to(self.device)
         self._network.load_state_dict(torch.load(file_name), strict=True)
         print("Success load file {}.".format(file_name))
         self._network.eval()
 
-    @staticmethod
-    def _mask_detection(output, threshold=0):
-        temp_h = torch.ones((1, 224)).cuda()
-        temp_w = torch.ones((224, 1)).cuda()
-
-        left = torch.cat((output[1:, :], temp_h), dim=0)
-        right = torch.cat((temp_h, output[:-1, :]), dim=0)
-        up = torch.cat((output[:, 1:], temp_w), dim=1)
-        down = torch.cat((temp_w, output[:, :-1]), dim=1)
-
-        mask = (output > threshold) * (output > left) * (output > right) * (output > up) * (output > down)
-        return mask.float(), torch.where(mask), mask.int().sum().item()
-
-    def __call__(self, inference_image, threshold=0.1):
-        inference_image = ((inference_image - torch.mean(inference_image))
-                           / torch.std(inference_image)).unsqueeze_(dim=0).unsqueeze_(dim=0)
-        with torch.no_grad():
-            outputs = self._network(inference_image)
-            mask, mask_points, mask_point_count = self._mask_detection(outputs[0, 0])
-            entry = outputs[0, 1].sigmoid_()
-            side = outputs[0, 2].sigmoid_()
-
-            result = []
-            for mask_point_index1 in range(mask_point_count):
-                x1 = mask_points[0][mask_point_index1].item()
-                y1 = mask_points[1][mask_point_index1].item()
-                if x1 < 16 or 208 < x1 or y1 < 16 or 208 < y1:
+        self.temp_h = torch.ones((1, 224)).to(self.device)
+        self.temp_w = torch.ones((224, 1)).to(self.device)
+
+    @torch.no_grad()
+    def __call__(self, inference_image):
+        inference_image = torch.from_numpy(inference_image).to(self.device).float()
+        inference_image = ((inference_image - torch.mean(inference_image)) / torch.std(inference_image))
+
+        outputs = self._network(inference_image.unsqueeze_(dim=0).unsqueeze_(dim=0))[0]
+
+        output = outputs[0]
+        left = torch.cat((output[1:, :], self.temp_h), dim=0)
+        right = torch.cat((self.temp_h, output[:-1, :]), dim=0)
+        up = torch.cat((output[:, 1:], self.temp_w), dim=1)
+        down = torch.cat((self.temp_w, output[:, :-1]), dim=1)
+        mask = (output > 0) * (output > left) * (output > right) * (output > up) * (output > down)
+        mask, mask_points, mask_point_count = mask.float(), torch.where(mask), mask.int().sum().item()
+
+        entry = outputs[1].sigmoid_()
+        side = outputs[2].sigmoid_()
+
+        result = []
+        for mask_point_index1 in range(mask_point_count - 1):
+            x1 = mask_points[0][mask_point_index1].item()
+            y1 = mask_points[1][mask_point_index1].item()
+            if x1 < 16 or 208 < x1 or y1 < 16 or 208 < y1:
+                continue
+
+            for mask_point_index2 in range(mask_point_index1 + 1, mask_point_count):
+                x2 = mask_points[0][mask_point_index2].item()
+                y2 = mask_points[1][mask_point_index2].item()
+                if x2 < 16 or 208 < x2 or y2 < 16 or 208 < y2:
                     continue
 
-                for mask_point_index2 in range(mask_point_index1 + 1, mask_point_count):
-                    x2 = mask_points[0][mask_point_index2].item()
-                    y2 = mask_points[1][mask_point_index2].item()
-
-                    if x2 < 16 or 208 < x2 or y2 < 16 or 208 < y2:
-                        continue
-
-                    distance = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
-                    # print("distance", distance)
-                    if distance < 56 or 72 < distance < 128 or 144 < distance:
-                        continue
-
-                    if y1 < y2:
-                        y_min = y1 - 1
-                        y_max = y2 + 2
-                        canvas = np.zeros([x2 - x1 + 3, y2 - y1 + 3])
-                        line(canvas, (1, 1), (y2 - y1 + 1, x2 - x1 + 1), 1, 2)
-                    else:
-                        y_min = y2 - 1
-                        y_max = y1 + 2
-                        canvas = np.zeros([x2 - x1 + 3, y1 - y2 + 3])
-                        line(canvas, (1, x2 - x1 + 1), (y1 - y2 + 1, 1), 1, 2)
-                    canvas = torch.from_numpy(canvas).float().cuda()
-                    if canvas.shape != torch.Size([x2 - x1 + 3, y_max - y_min]):
-                        continue
-
-                    if (canvas * mask[x1 - 1:x2 + 2, y_min:y_max]).sum() != 2:
-                        continue
-
-                    score_entry = (canvas * entry[x1 - 1:x2 + 2, y_min:y_max]).sum() / distance
-                    # print("score_entry", score_entry.item())
-                    if score_entry < threshold:
-                        continue
-
-                    direct = (side[x1 - 2:x1 + 3, y1 - 2:y1]).sum() > (side[x1 - 2:x1 + 3, y1:y1 + 3]).sum()
-                    if direct:
-                        direct_vector = (16 * (x1 - x2) / distance, 16 * (y2 - y1) / distance)
-                    else:
-                        direct_vector = (16 * (x2 - x1) / distance, 16 * (y1 - y2) / distance)
-
-                    vec_x = int(direct_vector[1])
-                    vec_y = int(direct_vector[0])
-                    side_score1 = side[x1 + vec_x - 2:x1 + vec_x + 3, y1 + vec_y - 2:y1 + vec_y + 3].sum()
-                    side_score2 = side[x2 + vec_x - 2:x2 + vec_x + 3, y2 + vec_y - 2:y2 + vec_y + 3].sum()
-                    # print("side score", side_score1, side_score2)
-                    if side_score1 < threshold or side_score2 < threshold:
-                        continue
-
-                    # print("score_final", score_entry * (side_score1 + side_score2))
-                    if score_entry * (side_score1 + side_score2) < threshold * 8:
-                        continue
-
-                    pt0 = (y1, x1)
-                    pt1 = (y2, x2)
-
-                    if distance < 85:
-                        pt2 = (pt0[0] + direct_vector[0] * 6, pt0[1] + direct_vector[1] * 6)
-                        pt3 = (pt1[0] + direct_vector[0] * 6, pt1[1] + direct_vector[1] * 6)
-                    else:
-                        pt2 = (pt0[0] + direct_vector[0] * 3, pt0[1] + direct_vector[1] * 3)
-                        pt3 = (pt1[0] + direct_vector[0] * 3, pt1[1] + direct_vector[1] * 3)
-
-                    if direct:
-                        result.append((pt0, pt1, pt2, pt3))
-                    else:
-                        result.append((pt1, pt0, pt3, pt2))
+                distance = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
+                if distance < 56 or 72 < distance < 128 or 144 < distance:
+                    continue
+
+                if y1 < y2:
+                    y_min = y1 - 1
+                    y_max = y2 + 2
+                    canvas = np.zeros([x2 - x1 + 3, y2 - y1 + 3])
+                    line(canvas, (1, 1), (y2 - y1 + 1, x2 - x1 + 1), 1, thickness=1)
+                else:
+                    y_min = y2 - 1
+                    y_max = y1 + 2
+                    canvas = np.zeros([x2 - x1 + 3, y1 - y2 + 3])
+                    line(canvas, (1, x2 - x1 + 1), (y1 - y2 + 1, 1), 1, thickness=1)
+                canvas = torch.from_numpy(canvas).float().to(self.device)
+                if (canvas * mask[x1 - 1:x2 + 2, y_min:y_max]).sum() != 2:
+                    continue
+
+                score_entry = (canvas * entry[x1 - 1:x2 + 2, y_min:y_max]).sum() / distance
+                if score_entry < 0.25:
+                    continue
+
+                direct = (side[x1 - 2:x1 + 3, y1 - 2:y1]).sum() > (side[x1 - 2:x1 + 3, y1:y1 + 3]).sum()
+                if direct:
+                    direct_vector_y = 16 * (x1 - x2) / distance
+                    direct_vector_x = -16 * (y1 - y2) / distance
+                else:
+                    direct_vector_y = -16 * (x1 - x2) / distance
+                    direct_vector_x = 16 * (y1 - y2) / distance
+
+                vec_x = int(direct_vector_x)
+                vec_y = int(direct_vector_y)
+                side_score1 = side[x1 + vec_x - 2:x1 + vec_x + 3, y1 + vec_y - 2:y1 + vec_y + 3].sum()
+                if side_score1 < 0.25:
+                    continue
+
+                side_score2 = side[x2 + vec_x - 2:x2 + vec_x + 3, y2 + vec_y - 2:y2 + vec_y + 3].sum()
+                if side_score2 < 0.25:
+                    continue
+
+                if score_entry * (side_score1 + side_score2) < 1:
+                    continue
+
+                pt0 = (y1, x1)
+                pt1 = (y2, x2)
+                pt2 = (y1 + direct_vector_y * 6, x1 + direct_vector_x * 6)
+                pt3 = (y2 + direct_vector_y * 6, x2 + direct_vector_x * 6)
+
+                if direct:
+                    result.append((pt0, pt1, pt2, pt3))
+                else:
+                    result.append((pt1, pt0, pt3, pt2))
         return result
diff --git a/main.py b/main.py
index a313a17..f6ce1a9 100644
--- a/main.py
+++ b/main.py
@@ -1,30 +1,29 @@
 import time
 
 import cv2
-import torch
 
 from SPFCN_Light.slot_detector import Detector
 
 if __name__ == "__main__":
     # Read image
     current_frame = cv2.imread("demo.jpg")
-    resolution = current_frame.shape[0]
 
     # Initial model
-    detector = Detector("SPFCN_Light/stable_parameter_0914.pkl")
+    detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)
 
     # Start the detection
     for frame_index in range(1000):
         # Get the result
         tic = time.time()
-        inference_image = torch.from_numpy(cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY))
-        inference_result = detector(inference_image.float().cuda(), threshold=0.1)
+        inference_image = cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY)
+        inference_result = detector(inference_image)
        toc = time.time()
         time_span = toc - tic
         infer_fps = 1 / (time_span + 1e-5)
         print("Frame:{:d}, Time used:{:.3f}, FPS:{:.3f}".format(frame_index, time_span * 1000, infer_fps), end='\r')
 
     # Visualize the merge image with result
+    resolution = current_frame.shape[0]
     for detect_result in inference_result:
         pt0 = (int(detect_result[0][0] * resolution / 224), int(detect_result[0][1] * resolution / 224))
         pt1 = (int(detect_result[1][0] * resolution / 224), int(detect_result[1][1] * resolution / 224))
@@ -35,6 +34,5 @@
         cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
         cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
     cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
-    cv2.imshow("result", current_frame)
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
+
+    cv2.imwrite("result.jpg", current_frame)
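
Side note for reviewers, not part of the patch: the refactor folds the old `_mask_detection` helper into `__call__`, but the underlying trick is unchanged. A marking point is any positive heatmap pixel that is strictly greater than its four axis-aligned neighbours, and the test is done with shifted copies of the whole heatmap instead of an explicit window scan. A minimal standalone sketch of that step; the names `local_maxima`, `heatmap`, `pad_row`, and `pad_col` are illustrative and not from the repo:

```python
import torch


def local_maxima(heatmap: torch.Tensor) -> torch.Tensor:
    """Bool mask of pixels that are positive and strictly greater than
    their four axis-aligned neighbours, as done on outputs[0] above."""
    # All-ones borders play the role of self.temp_h / self.temp_w: a border
    # pixel only survives if its score also exceeds 1.
    pad_row = torch.ones(1, heatmap.shape[1], device=heatmap.device)
    pad_col = torch.ones(heatmap.shape[0], 1, device=heatmap.device)
    # Each shifted copy lines every pixel up with one of its neighbours,
    # so a single elementwise comparison checks that neighbour everywhere.
    left = torch.cat((heatmap[1:, :], pad_row), dim=0)
    right = torch.cat((pad_row, heatmap[:-1, :]), dim=0)
    up = torch.cat((heatmap[:, 1:], pad_col), dim=1)
    down = torch.cat((pad_col, heatmap[:, :-1]), dim=1)
    return ((heatmap > 0) & (heatmap > left) & (heatmap > right)
            & (heatmap > up) & (heatmap > down))


if __name__ == "__main__":
    peaks = local_maxima(torch.randn(224, 224))
    xs, ys = torch.where(peaks)
    print(peaks.int().sum().item(), "candidate marking points")
```

Keeping the whole test as tensor ops means peak picking stays on whichever device the model runs on, which is presumably also why the patch pre-allocates `temp_h`/`temp_w` once in `__init__` rather than rebuilding them on every frame.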