Skip to content

Commit

Permalink
add img-seq type input
Browse files Browse the repository at this point in the history
  • Loading branch information
yamy-cheng committed Apr 25, 2023
1 parent e4b61d4 commit a1f208e
Showing 1 changed file with 71 additions and 48 deletions.
119 changes: 71 additions & 48 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,26 @@ def get_meta_from_video(input_video):
_, first_frame = cap.read()
cap.release()

first_frame = cv2.cvtColor(first_frame,cv2.COLOR_BGR2RGB)
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

return first_frame, first_frame, first_frame

def get_meta_from_img_seq(input_img_seq):
print("get meta information of img seq")
import pdb; pdb.set_trace()
# Create dir
file_name = os.path.dirname(input_img_seq)
file_path = f'./assets/{file_name}'
if os.path.isfile(file_path):
os.system(f'rm -r {file_path}')
os.makedirs(file_path)
# Unzip file
os.system(f'unzip {input_img_seq} -d {file_path}')

imgs_path = [os.path.join(file_path, img_name) for img_name in os.listdir(file_path)]
first_frame = imgs_path[0]
first_frame = cv2.imread(first_frame)
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

return first_frame, first_frame, first_frame

Expand Down Expand Up @@ -205,7 +224,6 @@ def seg_track_app():
##########################################################
###################### Front-end ########################
##########################################################

app = gr.Blocks()

with app:
Expand Down Expand Up @@ -233,12 +251,16 @@ def seg_track_app():
with gr.Row():
# video input
with gr.Column(scale=0.5):
input_video = gr.Video(label='Input video').style(height=550)
# listen to the user action for play and pause input video
input_video.play(fn=play_video, inputs=play_state, outputs=play_state, scroll_to_output=True, show_progress=True)
input_video.pause(fn=pause_video, inputs=play_state, outputs=play_state)
with gr.Tab(label="Video type input"):
input_video = gr.Video(label='Input video').style(height=550)
# listen to the user action for play and pause input video
input_video.play(fn=play_video, inputs=play_state, outputs=play_state, scroll_to_output=True, show_progress=True)
input_video.pause(fn=pause_video, inputs=play_state, outputs=play_state)

input_video_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)
with gr.Tab(label="Image-Seq type input"):
input_img_seq = gr.File(label='Input Image-Seq').style(height=550)

input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)


tab_everything = gr.Tab(label="Everything")
Expand Down Expand Up @@ -362,10 +384,22 @@ def seg_track_app():
input_video
],
outputs=[
input_video_first_frame, origin_frame, drawing_board
input_first_frame, origin_frame, drawing_board
]
)

# listen to the input_img_seq to get the first frame of video
input_img_seq.change(
fn=get_meta_from_img_seq,
inputs=[
input_img_seq
],
outputs=[
input_first_frame, origin_frame, drawing_board
]
)


# listen to the tab to init SegTracker
tab_everything.select(
fn=init_SegTracker,
Expand All @@ -377,7 +411,7 @@ def seg_track_app():
origin_frame
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
],
queue=False,

Expand All @@ -393,7 +427,7 @@ def seg_track_app():
origin_frame
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
],
queue=False,
)
Expand All @@ -408,26 +442,11 @@ def seg_track_app():
origin_frame,
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state, drawing_board
Seg_Tracker, input_first_frame, click_state, drawing_board
],
queue=False,
)

# Init Seg-Tracker
# initial_seg_tracker.click(
# fn=init_SegTracker,
# inputs=[
# aot_model,
# sam_gap,
# max_obj_num,
# points_per_side,
# origin_frame
# ],
# outputs=[
# Seg_Tracker, input_video_first_frame, click_state
# ]
# )

# Use SAM to segment everything for the first frame of video
seg_every_first_frame.click(
fn=segment_everything,
Expand All @@ -442,12 +461,12 @@ def seg_track_app():
],
outputs=[
Seg_Tracker,
input_video_first_frame,
input_first_frame,
],
)

# Interactively modify the mask acc click
input_video_first_frame.select(
input_first_frame.select(
fn=sam_refine,
inputs=[
Seg_Tracker, origin_frame, point_prompt, click_state,
Expand All @@ -457,7 +476,7 @@ def seg_track_app():
points_per_side,
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
]
)

Expand All @@ -472,7 +491,7 @@ def seg_track_app():
points_per_side,
],
outputs=[
Seg_Tracker, input_video_first_frame, drawing_board
Seg_Tracker, input_first_frame, drawing_board
]
)

Expand Down Expand Up @@ -503,7 +522,7 @@ def seg_track_app():
origin_frame
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
],
queue=False,
show_progress=False
Expand All @@ -519,7 +538,7 @@ def seg_track_app():
origin_frame
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
],
queue=False,
show_progress=False
Expand All @@ -535,7 +554,7 @@ def seg_track_app():
origin_frame,
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state, drawing_board
Seg_Tracker, input_first_frame, click_state, drawing_board
],
queue=False,
show_progress=False
Expand All @@ -552,7 +571,7 @@ def seg_track_app():
points_per_side,
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
]
)

Expand All @@ -566,22 +585,26 @@ def seg_track_app():
points_per_side,
],
outputs=[
Seg_Tracker, input_video_first_frame, click_state
Seg_Tracker, input_first_frame, click_state
]
)

gr.Examples(
examples=[
# os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.mp4"),
os.path.join(os.path.dirname(__file__), "assets", "blackswan.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "Resized_cxk.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "bear.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "camel.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "skate-park.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "swing.mp4"),
],
inputs=[input_video],
)

with gr.Tab(label='Video example'):
gr.Examples(
examples=[
# os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.mp4"),
os.path.join(os.path.dirname(__file__), "assets", "blackswan.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "Resized_cxk.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "bear.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "camel.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "skate-park.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "swing.mp4"),
],
inputs=[input_video],
)

with gr.Tab(label='Image seq expamle'):
pass

app.queue(concurrency_count=1)
app.launch(debug=True, enable_queue=True, share=True)
Expand Down

0 comments on commit a1f208e

Please sign in to comment.