Skip to content

SAM2 AMG cli.py on modal #1349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/sam2_amg_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from server import show_anns
from server import model_type_to_paths
from server import MODEL_TYPES_TO_MODEL
from server import set_fast
from server import set_furious
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from torchao._models.sam2.utils.amg import rle_to_mask
Expand All @@ -19,13 +21,17 @@ def main_docstring():
output_path (str): Path to output image
"""

def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
device = "cuda"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
if verbose:
print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
if fast:
set_fast(mask_generator)
if furious:
set_furious(mask_generator)
image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read()))
if verbose:
print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
Expand Down
94 changes: 94 additions & 0 deletions examples/sam2_amg_server/cli_on_modal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pathlib import Path

import modal

app = modal.App("torchao-sam-2-cli")

# Root directory inside the remote container; volumes are mounted under it.
TARGET = "/root/"

# Container image for the remote function:
# - nightly torch/torchvision from the cu124 index (tested with
#   torch-2.6.0.dev20241120),
# - OpenCV system packages for the SAM2 pipeline,
# - torchao built from source (develop install) out of /tmp/ao_src,
# - the sam2_amg_server example's requirements fetched from the main branch.
image = (
    modal.Image.debian_slim(python_version="3.12.7")
    .pip_install("numpy<3", "tqdm")
    .pip_install(
        "torch",
        pre=True,
        index_url="https://download.pytorch.org/whl/nightly/cu124",  # tested with torch-2.6.0.dev20241120
    )
    .pip_install(
        "torchvision",
        pre=True,
        index_url="https://download.pytorch.org/whl/nightly/cu124",  # tested with torch-2.6.0.dev20241120
    )
    .apt_install("git")
    .apt_install("libopencv-dev")
    .apt_install("python3-opencv")
    .run_commands(["git clone https://github.com/pytorch/ao.git /tmp/ao_src"])
    .run_commands(["cd /tmp/ao_src; python setup.py develop"])
    .pip_install(
        "gitpython",
    )
    .apt_install("wget")
    # Plain string literal: the previous f-string had no placeholders (F541).
    .run_commands(["wget https://raw.githubusercontent.com/pytorch/ao/refs/heads/main/examples/sam2_amg_server/requirements.txt"])
    .pip_install_from_requirements(
        'requirements.txt',
    )
)

# Persistent named volume so SAM2 checkpoints survive across runs.
checkpoints = modal.Volume.from_name("checkpoints", create_if_missing=True)

@app.function(
    image=image,
    gpu="H100",
    volumes={
        TARGET + "checkpoints": checkpoints,
        # # mount the caches of torch.compile and friends
        # "/root/.nv": modal.Volume.from_name("torchao-sam-2-cli-nv-cache", create_if_missing=True),
        # "/root/.triton": modal.Volume.from_name(
        #     "torchao-sam-2-cli-triton-cache", create_if_missing=True
        # ),
        # "/root/.inductor-cache": modal.Volume.from_name(
        #     "torchao-sam-2-cli-inductor-cache", create_if_missing=True
        # ),
    },
    timeout=60 * 60,
)
def eval(input_bytes, fast, furious):
    """Run SAM2 automatic mask generation on one image, remotely on an H100.

    NOTE(review): the name shadows the builtin ``eval``; kept because the
    local entrypoint calls ``eval.remote`` and Modal registers it by name.

    Args:
        input_bytes: Raw bytes of the input image (e.g. a JPEG file's content).
        fast: Forwarded to cli.main; enables torch.compile of the hot paths.
        furious: Forwarded to cli.main; enables float16 lowering.

    Returns:
        bytearray with the PNG bytes of the masked output image.
    """
    # Imported for an early availability check inside the container; cli.py
    # imports them again itself.
    import torch
    import torchao
    import os

    import subprocess
    from pathlib import Path

    def download_file(url, filename):
        # Fetch a single file to `filename`. List-form argv (shell=False)
        # avoids shell quoting issues. Previously the command interpolated a
        # garbled placeholder instead of `filename`, so the -O target was wrong.
        subprocess.run(["wget", "-O", filename, url], check=True)

    os.chdir(Path(TARGET))
    # Pull the CLI and server helpers from the PR branch so they match this script.
    download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/cli.py", "cli.py")
    download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/server.py", "server.py")

    # Stage the uploaded image where cli.main can read it.
    with open('/tmp/dog.jpg', 'wb') as file:
        file.write(input_bytes)

    import sys
    sys.path.append(".")
    from cli import main as cli_main
    cli_main(Path(TARGET) / Path("checkpoints"),
             model_type="large",
             input_path="/tmp/dog.jpg",
             output_path="/tmp/dog_masked_2.png",
             verbose=True,
             fast=fast,
             furious=furious)

    # Return the rendered result to the caller.
    with open('/tmp/dog_masked_2.png', 'rb') as f:
        return bytearray(f.read())

@app.local_entrypoint()
def main(input_path, output_path, fast=False, furious=False):
    """Local CLI: send the image at `input_path` to the remote `eval`
    function and write the returned masked PNG to `output_path`.
    """
    with open(input_path, 'rb') as f:
        input_bytes = f.read()
    # Renamed from `bytes` to avoid shadowing the builtin.
    output_bytes = eval.remote(input_bytes, fast, furious)
    with open(output_path, "wb") as file:
        file.write(output_bytes)
63 changes: 35 additions & 28 deletions examples/sam2_amg_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,39 @@ def model_type_to_paths(checkpoint_path, model_type):
model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
return sam2_checkpoint, model_cfg

def set_fast(mask_generator):
    """Replace the mask generator's hot paths with torch.compile'd versions.

    Compiles the SAM2 image encoder and the predictor's `_predict_masks`
    in place on `mask_generator`. Returns None.
    """
    predictor = mask_generator.predictor
    # TODO: Using CUDA graphs can cause numerical differences?
    compile_kwargs = dict(mode="max-autotune", fullgraph=True, dynamic=False)

    predictor.model.image_encoder = torch.compile(
        predictor.model.image_encoder, **compile_kwargs
    )
    predictor._predict_masks = torch.compile(
        predictor._predict_masks, **compile_kwargs
    )

    # _predict_masks_postprocess is intentionally left uncompiled:
    # mask_generator.predictor._predict_masks_postprocess = torch.compile(
    #     mask_generator.predictor._predict_masks_postprocess,
    #     fullgraph=True,
    #     dynamic=True,
    # )


def set_furious(mask_generator):
    """Lower the SAM2 image encoder and mask decoder to float16 in place.

    Also records the float16 dtype / transforms device on the predictor and
    raises float32 matmul precision globally. Returns None.
    """
    predictor = mask_generator.predictor
    model = predictor.model

    model.image_encoder = model.image_encoder.to(torch.float16)
    # NOTE: Not baseline feature
    predictor._image_dtype = torch.float16
    predictor._transforms_device = predictor.device
    torch.set_float32_matmul_precision('high')

    model.sam_mask_decoder = model.sam_mask_decoder.to(torch.float16)
    # NOTE: Not baseline feature
    model.sam_mask_decoder._src_dtype = torch.float16


def main(checkpoint_path,
model_type,
Expand Down Expand Up @@ -378,36 +411,10 @@ def main(checkpoint_path,

if fast:
assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
# TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
mask_generator.predictor.model.image_encoder,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)

mask_generator.predictor._predict_masks = torch.compile(
mask_generator.predictor._predict_masks,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)

# mask_generator.predictor._predict_masks_postprocess = torch.compile(
# mask_generator.predictor._predict_masks_postprocess,
# fullgraph=True,
# dynamic=True,
# )
set_fast(mask_generator)

if furious:
mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
# NOTE: Not baseline feature
mask_generator.predictor._image_dtype = torch.float16
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision('high')
mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16)
# NOTE: Not baseline feature
mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
set_furious(mask_generator)

with open('dog.jpg', 'rb') as f:
image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
Expand Down
Loading