Skip to content

Commit

Permalink
add interactive refine: brush
Browse files Browse the repository at this point in the history
  • Loading branch information
yamy-cheng committed Apr 21, 2023
1 parent d9161be commit dc52930
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 19 deletions.
26 changes: 21 additions & 5 deletions SegTracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __init__(self,segtracker_args, sam_args,aot_args) -> None:
self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
self.reference_objs_list = []
self.object_idx = 1
self.origin_merged_mask = None
self.refined_merged_mask = None
self.origin_merged_mask = None # init with 0 or segment-everthing
self.refined_merged_mask = None # interactively refine by user

# debug
self.everything_points = []
Expand Down Expand Up @@ -147,6 +147,25 @@ def find_new_objs(self, track_mask, seg_mask):
def restart_tracker(self):
    """Restart the wrapped tracker by delegating to ``self.tracker.restart()``."""
    tracker = self.tracker
    tracker.restart()

def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray,):
    '''
    Segment according to a bounding box and merge the result into the refined mask.

    parameters:
        origin_frame: H, W, C
        bbox: [[x0, y0], [x1, y1]]
    returns:
        (refined_merged_mask, masked_frame)
        refined_merged_mask: value returned by self.add_mask, also stored on
            self.refined_merged_mask
        masked_frame: output of draw_mask on a copy of origin_frame, with the
            box outline added by cv2.rectangle
    '''

    # get interactive_mask (only the first entry of the segmentor output is used)
    interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
    self.refined_merged_mask = self.add_mask(interactive_mask)

    # draw mask; a copy of the frame is handed to draw_mask, not the frame itself
    masked_frame = draw_mask(origin_frame.copy(), self.refined_merged_mask)

    # draw bbox
    # Corners are converted to tuples of built-in ints before reaching OpenCV.
    # NOTE(review): numpy int64 rows are reported not to parse as points on
    # some OpenCV Python builds; plain ints avoid that.
    top_left = tuple(int(v) for v in bbox[0])
    bottom_right = tuple(int(v) for v in bbox[1])
    masked_frame = cv2.rectangle(masked_frame, top_left, bottom_right, (0, 0, 255))

    return self.refined_merged_mask, masked_frame

def refine_first_frame_click(self, origin_frame: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
'''
it is used in first frame in video
Expand All @@ -155,8 +174,6 @@ def refine_first_frame_click(self, origin_frame: np.ndarray, points:np.ndarray,
# get interactive_mask
interactive_mask, logit, outline = self.sam.segment_with_click(origin_frame, points, labels, multimask)

# cv2.imwrite('./debug/interactive_mask.png', interactive_mask * 255)

self.refined_merged_mask = self.add_mask(interactive_mask)

# draw mask
Expand All @@ -174,7 +191,6 @@ def refine_first_frame_click(self, origin_frame: np.ndarray, points:np.ndarray,

return self.refined_merged_mask, masked_frame


def add_mask(self, interactive_mask, cover_origin_objects=True, single_object=True):
# if cover_origin_objects == True: interactive_mask will cover the original objects
# if single_object == True: the added mask belongs to a single object
Expand Down
34 changes: 21 additions & 13 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import argparse
import torch
import time
from seg_track_anything import seg_track_anything, aot_model2ckpt, colorize_mask, tracking_objects_in_video, draw_mask
from seg_track_anything import aot_model2ckpt, tracking_objects_in_video, draw_mask
import gc
import numpy as np
import json
from tool.transfer_tools import mask2bbox

def pause_video(play_state):
print("user pause_video")
Expand Down Expand Up @@ -174,9 +175,16 @@ def sam_brush(Seg_Tracker, origin_frame, brush_plane, aot_model, sam_gap, max_ob
Seg_Tracker, _ , _ = init_SegTracker(aot_model, sam_gap, max_obj_num, points_per_side, origin_frame)

mask = brush_plane["mask"]
import pdb; pdb.set_trace()
bbox = mask2bbox(mask[:, :, 0]) # bbox: [[x0, y0], [x1, y1]]
predicted_mask, masked_frame = Seg_Tracker.seg_acc_bbox(origin_frame, bbox)

return Seg_Tracker, origin_frame, origin_frame
with torch.cuda.amp.autocast():
# Reset the first frame's mask
frame_idx = 0
Seg_Tracker.restart_tracker()
Seg_Tracker.add_reference(origin_frame, predicted_mask, frame_idx)

return Seg_Tracker, masked_frame, origin_frame

def segment_everything(Seg_Tracker, aot_model, origin_frame, sam_gap, max_obj_num, points_per_side):

Expand Down Expand Up @@ -309,16 +317,6 @@ def seg_track_app():
interactive=True,
)

max_obj_num = gr.Slider(
label='max_obj_num',
minimum = 50,
step=1,
maximum = 300,
value=255,
interactive=True
)

with gr.Column(scale=0.5):
points_per_side = gr.Slider(
label = "points_per_side",
minimum= 1,
Expand All @@ -328,6 +326,8 @@ def seg_track_app():
interactive=True
)


with gr.Column(scale=0.5):
sam_gap = gr.Slider(
label='sam_gap',
minimum = 1,
Expand All @@ -337,6 +337,14 @@ def seg_track_app():
interactive=True,
)

max_obj_num = gr.Slider(
label='max_obj_num',
minimum = 50,
step=1,
maximum = 300,
value=255,
interactive=True
)
track_for_video = gr.Button(
value="Start Tracking",
interactive=True
Expand Down
Binary file modified assets/blackswan_seg.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 10 additions & 1 deletion tool/segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
contour_width = 5



class Segmentor:
def __init__(self, sam_args):
"""
Expand Down Expand Up @@ -134,7 +133,17 @@ def segment_with_click(self, origin_frame: np.ndarray, points:np.ndarray, labels
# painted_image = Image.fromarray(painted_image)
return mask.astype(np.uint8), logit, outline

def segment_with_box(self, origin_frame, bbox):
    """Run the interactive predictor on ``origin_frame`` with a box prompt.

    ``bbox`` is indexed as ``[[x0, y0], [x1, y1]]`` and flattened into a
    single-row array ``[[x0, y0, x1, y1]]`` for the predictor. No point
    prompts are supplied and multi-mask output is disabled. Only the first
    of the three values returned by ``predict`` is handed back.
    """
    # The frame is registered first, before the box is read.
    self.set_image(origin_frame)

    top_left, bottom_right = bbox[0], bbox[1]
    box_prompt = np.array(
        [[top_left[0], top_left[1], bottom_right[0], bottom_right[1]]]
    )

    prediction = self.interactive_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_prompt,
        multimask_output=False,
    )
    masks, _, _ = prediction
    return masks

# if __name__ == "__main__":
# points = np.array([[500, 375], [1125, 625]])
Expand Down
25 changes: 25 additions & 0 deletions tool/transfer_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import cv2
import numpy as np

def mask2bbox(mask):
    """Return the tight bounding box around the positive entries of ``mask``.

    Parameters:
        mask: array with at least two axes; entries greater than zero count
            as foreground. Axis 0 is treated as y and axis 1 as x.

    Returns:
        ``np.int64`` array ``[[x0, y0], [x1, y1]]`` with inclusive corners.
        When no entry is positive, prints ``not mask`` and returns all zeros.
    """
    foreground = np.asarray(mask) > 0
    if not foreground.any():
        print('not mask')
        return np.array([[0, 0], [0, 0]]).astype(np.int64)

    # Extents come from the thresholded mask, the same test as the emptiness
    # check. Summing raw values can cancel to zero when negative entries are
    # present, which left the min/max below with nothing to reduce over.
    x_hits = np.nonzero(foreground.any(axis=0))[0]
    y_hits = np.nonzero(foreground.any(axis=1))[0]

    x0, x1 = x_hits.min(), x_hits.max()
    y0, y1 = y_hits.min(), y_hits.max()

    return np.array([[x0, y0], [x1, y1]]).astype(np.int64)



if __name__ == '__main__':
    # Manual smoke test: load a debug mask, compute its bounding box and
    # write the mask with the box outline drawn on it.
    mask_path = './debug/painter_input_mask.jpg'
    mask = cv2.imread(mask_path, -1)
    if mask is None:
        # cv2.imread signals a missing or unreadable file by returning None;
        # fail here with a clear message instead of a TypeError on the slice.
        raise FileNotFoundError(f'cannot read mask image: {mask_path}')

    # Drop the first two rows and columns before computing the box.
    mask = mask[2:, 2:]
    bbox = mask2bbox(mask)
    draw_0 = cv2.rectangle(mask, bbox[0], bbox[1], (0, 0, 255))
    cv2.imwrite('./debug/rect.png', draw_0)

0 comments on commit dc52930

Please sign in to comment.