Speed up SPFCN Light to ~40FPS in CPU
LoyalBlanc committed Sep 15, 2020
1 parent e846fea commit fde82cb
Showing 2 changed files with 96 additions and 106 deletions.
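Taken together, the changes below replace the hard-coded .cuda() calls with a device_id argument (a negative id selects the CPU) and move tensor conversion inside the detector, so the caller supplies a plain grayscale numpy array. A minimal usage sketch of the new entry point, mirroring the updated main.py further down:

import cv2

from SPFCN_Light.slot_detector import Detector

# device_id=-1 selects the CPU; a non-negative id selects that CUDA device
detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)

# The detector now expects a 224x224 grayscale numpy array
image = cv2.cvtColor(cv2.resize(cv2.imread("demo.jpg"), (224, 224)), cv2.COLOR_BGR2GRAY)
slots = detector(image)  # list of (pt0, pt1, pt2, pt3) slot-corner tuples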
188 changes: 90 additions & 98 deletions SPFCN_Light/slot_detector.py
@@ -4,19 +4,19 @@
 Example:
     model_parameter_path = "..."
-    detector = Detector(model_parameter_path)
+    detector = Detector(model_parameter_path, device_id)
     inference_result = detector(inference_image)
 Data form:
-    inference_image: Gray Image, torch.Tensor with Size([224, 224])
+    inference_image: Gray Image, numpy.array with Size([224, 224])
     inference_result: list of slot_points
 Requirements:
     pytorch == 1.2.0
     numpy == 1.16.5
     opencv == 3.4.2
-2019.11.28
+2020.09.15
 """
 
 from cv2 import line
@@ -68,106 +68,98 @@ def forward(self, feature):
 
 
 class Detector(object):
-    def __init__(self, file_name):
+    def __init__(self, file_name, device_id=0):
         super().__init__()
-        self._network = Hourglass([1, 40, 56, 55, 60, 59, 61, 59, 60, 55, 56, 40, 3]).cuda()
+        self.device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
+        self._network = Hourglass([1, 40, 56, 55, 60, 59, 61, 59, 60, 55, 56, 40, 3]).to(self.device)
         self._network.load_state_dict(torch.load(file_name), strict=True)
         print("Success load file {}.".format(file_name))
         self._network.eval()
-
-    @staticmethod
-    def _mask_detection(output, threshold=0):
-        temp_h = torch.ones((1, 224)).cuda()
-        temp_w = torch.ones((224, 1)).cuda()
-
-        left = torch.cat((output[1:, :], temp_h), dim=0)
-        right = torch.cat((temp_h, output[:-1, :]), dim=0)
-        up = torch.cat((output[:, 1:], temp_w), dim=1)
-        down = torch.cat((temp_w, output[:, :-1]), dim=1)
-
-        mask = (output > threshold) * (output > left) * (output > right) * (output > up) * (output > down)
-        return mask.float(), torch.where(mask), mask.int().sum().item()
-
-    def __call__(self, inference_image, threshold=0.1):
-        inference_image = ((inference_image - torch.mean(inference_image))
-                           / torch.std(inference_image)).unsqueeze_(dim=0).unsqueeze_(dim=0)
-        with torch.no_grad():
-            outputs = self._network(inference_image)
-        mask, mask_points, mask_point_count = self._mask_detection(outputs[0, 0])
-        entry = outputs[0, 1].sigmoid_()
-        side = outputs[0, 2].sigmoid_()
-
-        result = []
-        for mask_point_index1 in range(mask_point_count):
-            x1 = mask_points[0][mask_point_index1].item()
-            y1 = mask_points[1][mask_point_index1].item()
-            if x1 < 16 or 208 < x1 or y1 < 16 or 208 < y1:
-                continue
+        self.temp_h = torch.ones((1, 224)).to(self.device)
+        self.temp_w = torch.ones((224, 1)).to(self.device)
+
+    @torch.no_grad()
+    def __call__(self, inference_image):
+        inference_image = torch.from_numpy(inference_image).to(self.device).float()
+        inference_image = ((inference_image - torch.mean(inference_image)) / torch.std(inference_image))
+
+        outputs = self._network(inference_image.unsqueeze_(dim=0).unsqueeze_(dim=0))[0]
+
+        output = outputs[0]
+        left = torch.cat((output[1:, :], self.temp_h), dim=0)
+        right = torch.cat((self.temp_h, output[:-1, :]), dim=0)
+        up = torch.cat((output[:, 1:], self.temp_w), dim=1)
+        down = torch.cat((self.temp_w, output[:, :-1]), dim=1)
+        mask = (output > 0) * (output > left) * (output > right) * (output > up) * (output > down)
+        mask, mask_points, mask_point_count = mask.float(), torch.where(mask), mask.int().sum().item()
+
+        entry = outputs[1].sigmoid_()
+        side = outputs[2].sigmoid_()
+
+        result = []
+        for mask_point_index1 in range(mask_point_count - 1):
+            x1 = mask_points[0][mask_point_index1].item()
+            y1 = mask_points[1][mask_point_index1].item()
+            if x1 < 16 or 208 < x1 or y1 < 16 or 208 < y1:
+                continue
 
             for mask_point_index2 in range(mask_point_index1 + 1, mask_point_count):
                 x2 = mask_points[0][mask_point_index2].item()
                 y2 = mask_points[1][mask_point_index2].item()
-
                 if x2 < 16 or 208 < x2 or y2 < 16 or 208 < y2:
                     continue
 
                 distance = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
-                # print("distance", distance)
                 if distance < 56 or 72 < distance < 128 or 144 < distance:
                     continue
 
                 if y1 < y2:
                     y_min = y1 - 1
                     y_max = y2 + 2
                     canvas = np.zeros([x2 - x1 + 3, y2 - y1 + 3])
-                    line(canvas, (1, 1), (y2 - y1 + 1, x2 - x1 + 1), 1, 2)
+                    line(canvas, (1, 1), (y2 - y1 + 1, x2 - x1 + 1), 1, thickness=1)
                 else:
                     y_min = y2 - 1
                     y_max = y1 + 2
                     canvas = np.zeros([x2 - x1 + 3, y1 - y2 + 3])
-                    line(canvas, (1, x2 - x1 + 1), (y1 - y2 + 1, 1), 1, 2)
-                canvas = torch.from_numpy(canvas).float().cuda()
-                if canvas.shape != torch.Size([x2 - x1 + 3, y_max - y_min]):
-                    continue
-
+                    line(canvas, (1, x2 - x1 + 1), (y1 - y2 + 1, 1), 1, thickness=1)
+                canvas = torch.from_numpy(canvas).float().to(self.device)
                 if (canvas * mask[x1 - 1:x2 + 2, y_min:y_max]).sum() != 2:
                     continue
 
                 score_entry = (canvas * entry[x1 - 1:x2 + 2, y_min:y_max]).sum() / distance
-                # print("score_entry", score_entry.item())
-                if score_entry < threshold:
+                if score_entry < 0.25:
                     continue
 
                 direct = (side[x1 - 2:x1 + 3, y1 - 2:y1]).sum() > (side[x1 - 2:x1 + 3, y1:y1 + 3]).sum()
                 if direct:
-                    direct_vector = (16 * (x1 - x2) / distance, 16 * (y2 - y1) / distance)
+                    direct_vector_y = 16 * (x1 - x2) / distance
+                    direct_vector_x = -16 * (y1 - y2) / distance
                 else:
-                    direct_vector = (16 * (x2 - x1) / distance, 16 * (y1 - y2) / distance)
+                    direct_vector_y = -16 * (x1 - x2) / distance
+                    direct_vector_x = 16 * (y1 - y2) / distance
 
-                vec_x = int(direct_vector[1])
-                vec_y = int(direct_vector[0])
+                vec_x = int(direct_vector_x)
+                vec_y = int(direct_vector_y)
                 side_score1 = side[x1 + vec_x - 2:x1 + vec_x + 3, y1 + vec_y - 2:y1 + vec_y + 3].sum()
+                if side_score1 < 0.25:
+                    continue
+
                 side_score2 = side[x2 + vec_x - 2:x2 + vec_x + 3, y2 + vec_y - 2:y2 + vec_y + 3].sum()
-                # print("side score", side_score1, side_score2)
-                if side_score1 < threshold or side_score2 < threshold:
+                if side_score2 < 0.25:
                     continue
 
-                # print("score_final", score_entry * (side_score1 + side_score2))
-                if score_entry * (side_score1 + side_score2) < threshold * 8:
+                if score_entry * (side_score1 + side_score2) < 1:
                     continue
 
                 pt0 = (y1, x1)
                 pt1 = (y2, x2)
-
-                if distance < 85:
-                    pt2 = (pt0[0] + direct_vector[0] * 6, pt0[1] + direct_vector[1] * 6)
-                    pt3 = (pt1[0] + direct_vector[0] * 6, pt1[1] + direct_vector[1] * 6)
-                else:
-                    pt2 = (pt0[0] + direct_vector[0] * 3, pt0[1] + direct_vector[1] * 3)
-                    pt3 = (pt1[0] + direct_vector[0] * 3, pt1[1] + direct_vector[1] * 3)
+                pt2 = (y1 + direct_vector_y * 6, x1 + direct_vector_x * 6)
+                pt3 = (y2 + direct_vector_y * 6, x2 + direct_vector_x * 6)
 
                 if direct:
                     result.append((pt0, pt1, pt2, pt3))
                 else:
                     result.append((pt1, pt0, pt3, pt2))
         return result
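Note: the peak-picking inlined above is the former _mask_detection helper folded into __call__, with the two border tensors hoisted into __init__ so they are allocated once per detector instead of once per frame. A standalone sketch of the same shifted-tensor local-maximum test, assuming a 224x224 heatmap on CPU (the random input is only illustrative):

import torch

heatmap = torch.randn(224, 224)  # stand-in for the marking-point channel
temp_h = torch.ones((1, 224))    # padding row for the border
temp_w = torch.ones((224, 1))    # padding column for the border

# Shift the whole map to compare every pixel with its four neighbours
left = torch.cat((heatmap[1:, :], temp_h), dim=0)
right = torch.cat((temp_h, heatmap[:-1, :]), dim=0)
up = torch.cat((heatmap[:, 1:], temp_w), dim=1)
down = torch.cat((temp_w, heatmap[:, :-1]), dim=1)

# Keep a pixel only if it is positive and a strict local maximum
mask = (heatmap > 0) * (heatmap > left) * (heatmap > right) * (heatmap > up) * (heatmap > down)
mask_points = torch.where(mask)
print(mask.int().sum().item(), "candidate marking points")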
14 changes: 6 additions & 8 deletions main.py
@@ -1,30 +1,29 @@
 import time
 
 import cv2
-import torch
 
 from SPFCN_Light.slot_detector import Detector
 
 if __name__ == "__main__":
     # Read image
     current_frame = cv2.imread("demo.jpg")
-    resolution = current_frame.shape[0]
 
     # Initial model
-    detector = Detector("SPFCN_Light/stable_parameter_0914.pkl")
+    detector = Detector("./SPFCN_Light/stable_parameter_0914.pkl", device_id=-1)
 
     # Start the detection
     for frame_index in range(1000):
         # Get the result
         tic = time.time()
-        inference_image = torch.from_numpy(cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY))
-        inference_result = detector(inference_image.float().cuda(), threshold=0.1)
+        inference_image = cv2.cvtColor(cv2.resize(current_frame, (224, 224)), cv2.COLOR_BGR2GRAY)
+        inference_result = detector(inference_image)
         toc = time.time()
         time_span = toc - tic
         infer_fps = 1 / (time_span + 1e-5)
         print("Frame:{:d}, Time used:{:.3f}, FPS:{:.3f}".format(frame_index, time_span * 1000, infer_fps), end='\r')
 
     # Visualize the merge image with result
+    resolution = current_frame.shape[0]
     for detect_result in inference_result:
         pt0 = (int(detect_result[0][0] * resolution / 224), int(detect_result[0][1] * resolution / 224))
         pt1 = (int(detect_result[1][0] * resolution / 224), int(detect_result[1][1] * resolution / 224))
@@ -35,6 +34,5 @@
         cv2.line(current_frame, pt1, pt3, (0, 0, 255), thickness=2)
         cv2.line(current_frame, pt2, pt3, (0, 0, 255), thickness=2)
     cv2.putText(current_frame, "%.2f fps" % infer_fps, (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255))
-    cv2.imshow("result", current_frame)
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
+
+    cv2.imwrite("result.jpg", current_frame)
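For reference, the direct_vector_x/direct_vector_y arithmetic in the detector above is a 16-pixel step perpendicular to the line joining two paired marking points. A standalone sketch under that reading (the function name and demo coordinates are illustrative, not from the repository):

import numpy as np

def perpendicular_step(x1, y1, x2, y2, step=16.0):
    # Rotate the pair direction by 90 degrees and scale it to `step` pixels,
    # matching the 'direct' branch of Detector.__call__
    distance = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
    direct_vector_y = step * (x1 - x2) / distance
    direct_vector_x = -step * (y1 - y2) / distance
    return direct_vector_x, direct_vector_y

print(perpendicular_step(50, 100, 110, 100))  # (0.0, -16.0), perpendicular to a horizontal pair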
