Skip to content

Commit

Permalink
Update grasp_detect_multibox.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kmittle authored Jun 29, 2023
1 parent 174f60a commit 99ba83c
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions grasp_detect_multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,34 @@ def forward(self, input, confidence_threshold):
cx.unsqueeze(1), cy.unsqueeze(1), w.unsqueeze(1), h.unsqueeze(1), theta.unsqueeze(1)], dim=1)


# def draw_multi_box(img, box_coordinates):
# for i in range(box_coordinates.shape[0]):
# center = (box_coordinates[i, 1].item(), box_coordinates[i, 2].item())
# size = (box_coordinates[i, 3].item(), box_coordinates[i, 4].item())
# angle = box_coordinates[i, 5].item()
# box = cv2.boxPoints((center, size, angle))
# box = np.int64(box)
# cv2.drawContours(img, [box], -1, (0, 255, 0), 2)
# cv2.imshow("Image", img)
# cv2.waitKey(0)
# cv2.destroyAllWindows()


def draw_multi_box(img, box_coordinates):
point_color1 = (255, 255, 0) # BGR
point_color2 = (255, 0, 255) # BGR
thickness = 2
lineType = 4
for i in range(box_coordinates.shape[0]):
center = (box_coordinates[i, 1].item(), box_coordinates[i, 2].item())
size = (box_coordinates[i, 3].item(), box_coordinates[i, 4].item())
angle = box_coordinates[i, 5].item()
box = cv2.boxPoints((center, size, angle))
box = np.int64(box)
cv2.drawContours(img, [box], -1, (0, 255, 0), 2)
cv2.line(img, box[0], box[3], point_color1, thickness, lineType)
cv2.line(img, box[3], box[2], point_color2, thickness, lineType)
cv2.line(img, box[2], box[1], point_color1, thickness, lineType)
cv2.line(img, box[1], box[0], point_color2, thickness, lineType)
cv2.imshow("Image", img)
cv2.waitKey(0)
cv2.destroyAllWindows()
Expand All @@ -71,21 +91,21 @@ def draw_multi_box(img, box_coordinates):

weights_path = 'weights/epoch6_loss_8.045684943666645.pth'

img = cv2.imread(r'J:\experiment_data\MOS\img\000000r.png')
img = cv2.imread(r'J:\experiment_data\MOS\img\002243r.png')

transform = transforms.Compose([
transforms.ToTensor(),
])

inference_multi_image = DetectMultiImage(device=device, weights_path=weights_path)

img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转BGR格式为RGB格式
img2 = transform(img2).unsqueeze(dim=0).to(device)

boxes = inference_multi_image(img2, 0.9999)

print(boxes.shape)
print(boxes[:, 0].data[:5])

draw_multi_box(img, boxes.data)
draw_multi_box(img, boxes.data) # 此处传入的img是OpenCV的BGR格式的

0 comments on commit 99ba83c

Please sign in to comment.