
Inference on video without extracting images #641


Draft: wants to merge 24 commits into main
Conversation

cjaverliat

This PR proposes SAM2Generic, an alternative to the original SAM2Base that provides new APIs. Additionally, I added SAM2GenericVideoPredictor, a re-implementation of the video predictor with configurable strategies for memorizing and removing past memories (cf. here for an example), which solves the issue of keeping everything in VRAM.
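To illustrate the idea of a pluggable removal strategy, here is a rough, hypothetical sketch of a sliding-window eviction policy. The class name, method name, and signature below are illustrative assumptions, not the actual interface introduced by this PR:

# Hypothetical sketch: the strategy interface shown here is an assumption
# made for illustration, not the API defined in this PR.
class SlidingWindowRemovalStrategy:
    def __init__(self, max_memories: int = 7):
        self.max_memories = max_memories

    def select_memories_to_remove(self, memorized_frame_indices: list[int]) -> list[int]:
        # Keep only the most recent `max_memories` memories; return the
        # frame indices whose memory features should be freed from VRAM.
        if len(memorized_frame_indices) <= self.max_memories:
            return []
        ordered = sorted(memorized_frame_indices)
        return ordered[: len(ordered) - self.max_memories]

Whatever the exact interface, the point is that memory eviction becomes a policy the caller controls instead of the predictor accumulating every past frame's memory features.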

More importantly, this makes it possible to run prediction on videos without having to extract the frames to JPEG files beforehand:

import cv2
import torch
from tqdm import tqdm
from sam2.sam2_generic_video_predictor import Prompt
from sam2.build_sam import build_sam2_generic_video_predictor

sam2_checkpoint = "../checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

predictor = build_sam2_generic_video_predictor(model_cfg, sam2_checkpoint, device=device)

cap = cv2.VideoCapture("./videos/bedroom.mp4")
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
orig_hw = (height, width)

def read_frame(cap) -> torch.Tensor | None:
    ret, frame = cap.read()
    if not ret:
        return None
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR
    frame = torch.as_tensor(frame).permute(2, 0, 1).to(device)  # HWC -> CHW
    frame = frame / 255.0  # normalize to [0, 1]
    return frame
 
# Add a prompt on the first frame:
# one foreground point (label 1) at pixel coordinates (x=400, y=150) for object id 0.
initial_frame = read_frame(cap)
points_coords = torch.tensor([400.0, 150.0], device=device).reshape((1, 1, 2))
points_labels = torch.tensor([1], device=device).reshape((1, 1))
prompt = Prompt(obj_id=0, points_coords=points_coords, points_labels=points_labels)
results = predictor.forward(frame=initial_frame, object_prompts=[prompt])

for f in tqdm(range(1, n_frames)):
    frame = read_frame(cap)

    if frame is None:
        break

    results = predictor.forward(frame=frame)
    
    # Do something with the result, for example:
    #     show_mask((results[0].best_mask_logits > 0), plt.gca(), obj_id=0)

cap.release()

The full usage example is available in the generic_video_predictor_example.ipynb notebook.
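For reference, here is a minimal sketch of post-processing the returned results into a binary mask image. It assumes, as the commented show_mask call above suggests, that results is indexed by object and that best_mask_logits is a tensor covering the original frame; the exact shape and any fields beyond best_mask_logits are assumptions:

import numpy as np
import torch

def mask_to_image(results, obj_idx: int = 0) -> np.ndarray:
    # Threshold the logits at 0 to get a binary mask, then convert it to a
    # uint8 image (0 or 255) that can be saved or overlaid with OpenCV.
    logits = results[obj_idx].best_mask_logits
    mask = (logits > 0).squeeze().to(torch.uint8).cpu().numpy() * 255
    return mask

# e.g. inside the loop above: cv2.imwrite(f"mask_{f:05d}.png", mask_to_image(results))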

@cjaverliat marked this pull request as draft on May 3, 2025, 17:11