Skip to content

Commit 615fb0e

Browse files
authored
SAM2 AMG cli.py on modal (#1349)
1 parent 6ff3904 commit 615fb0e

File tree

3 files changed

+136
-29
lines changed

3 files changed

+136
-29
lines changed

examples/sam2_amg_server/cli.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from server import show_anns
66
from server import model_type_to_paths
77
from server import MODEL_TYPES_TO_MODEL
8+
from server import set_fast
9+
from server import set_furious
810
from torchao._models.sam2.build_sam import build_sam2
911
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
1012
from torchao._models.sam2.utils.amg import rle_to_mask
@@ -19,13 +21,17 @@ def main_docstring():
1921
output_path (str): Path to output image
2022
"""
2123

22-
def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
24+
def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
2325
device = "cuda"
2426
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
2527
if verbose:
2628
print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
2729
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
2830
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
31+
if fast:
32+
set_fast(mask_generator)
33+
if furious:
34+
set_furious(mask_generator)
2935
image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read()))
3036
if verbose:
3137
print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from pathlib import Path
2+
3+
import modal
4+
5+
# Modal application under which the remote job and local entrypoint are registered.
app = modal.App("torchao-sam-2-cli")

# Directory inside the container that the remote job switches into; the
# checkpoint volume is mounted directly beneath it.
TARGET = "/root/"

# Wheel index used for both the torch and torchvision pre-release installs.
# tested with torch-2.6.0.dev20241120
_NIGHTLY_CU124_INDEX = "https://download.pytorch.org/whl/nightly/cu124"

# requirements.txt of the SAM2 AMG server example, downloaded during the image build.
_REQUIREMENTS_URL = "https://raw.githubusercontent.com/pytorch/ao/refs/heads/main/examples/sam2_amg_server/requirements.txt"

# Container image: Python 3.12.7 on Debian slim, pre-release torch/torchvision,
# OpenCV system packages, and torchao built from a fresh clone of the repository.
image = (
    modal.Image.debian_slim(python_version="3.12.7")
    .pip_install("numpy<3", "tqdm")
    .pip_install("torch", pre=True, index_url=_NIGHTLY_CU124_INDEX)
    .pip_install("torchvision", pre=True, index_url=_NIGHTLY_CU124_INDEX)
    .apt_install("git")
    .apt_install("libopencv-dev")
    .apt_install("python3-opencv")
    .run_commands(["git clone https://github.com/pytorch/ao.git /tmp/ao_src"])
    .run_commands(["cd /tmp/ao_src; python setup.py develop"])
    .pip_install("gitpython")
    .apt_install("wget")
    .run_commands(["wget " + _REQUIREMENTS_URL])
    # NOTE(review): confirm whether this resolves to the file downloaded by the
    # previous step or to a requirements.txt next to this script.
    .pip_install_from_requirements("requirements.txt")
)

# Named volume holding the model checkpoints; created on first use if absent.
checkpoints = modal.Volume.from_name("checkpoints", create_if_missing=True)
38+
39+
@app.function(
    image=image,
    gpu="H100",
    volumes={
        TARGET + "checkpoints": checkpoints,
        # # mount the caches of torch.compile and friends
        # "/root/.nv": modal.Volume.from_name("torchao-sam-2-cli-nv-cache", create_if_missing=True),
        # "/root/.triton": modal.Volume.from_name(
        #     "torchao-sam-2-cli-triton-cache", create_if_missing=True
        # ),
        # "/root/.inductor-cache": modal.Volume.from_name(
        #     "torchao-sam-2-cli-inductor-cache", create_if_missing=True
        # ),
    },
    timeout=60 * 60,
)
def eval(input_bytes, fast, furious):
    """Run the SAM2 automatic mask generator CLI on one image inside the container.

    The function downloads ``cli.py`` and ``server.py`` into ``TARGET``, writes
    the received image to a temporary file, calls ``cli.main`` with the
    ``large`` model type and returns the bytes of the image it produces.

    NOTE: the name shadows the ``eval`` builtin within this module; it is kept
    because the local entrypoint invokes it as ``eval.remote``.

    Args:
        input_bytes: Raw bytes of the input image file.
        fast: Forwarded to ``cli.main`` as ``fast``.
        furious: Forwarded to ``cli.main`` as ``furious``.

    Returns:
        bytearray: Contents of the output file written by ``cli.main``.

    Raises:
        subprocess.CalledProcessError: If a ``wget`` download exits non-zero.
    """
    # Kept from the original although unused here; presumably a fail-fast check
    # that the image provides both packages - TODO confirm.
    import torch  # noqa: F401
    import torchao  # noqa: F401

    import os
    import subprocess
    import sys
    from pathlib import Path

    def download_file(url, filename):
        """Download ``url`` to ``filename`` with wget, raising on failure."""
        # An argument list avoids handing a formatted string to a shell.
        subprocess.run(["wget", "-O", str(filename), url], check=True)

    work_dir = Path(TARGET)
    os.chdir(work_dir)

    # NOTE(review): these point at the `climodal1` feature branch while the image
    # build fetches requirements.txt from `main`; confirm the branch still exists.
    download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/cli.py", "cli.py")
    download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/server.py", "server.py")

    input_path = "/tmp/dog.jpg"
    output_path = "/tmp/dog_masked_2.png"
    Path(input_path).write_bytes(input_bytes)

    # The scripts were downloaded into the current directory; make them importable.
    sys.path.append(".")
    from cli import main as cli_main

    cli_main(
        work_dir / "checkpoints",
        model_type="large",
        input_path=input_path,
        output_path=output_path,
        verbose=True,
        fast=fast,
        furious=furious,
    )

    # read_bytes() closes the file handle, unlike a bare open(...).read().
    return bytearray(Path(output_path).read_bytes())
89+
90+
def _as_flag(value):
    """Interpret a command-line style value as a boolean.

    Strings are compared case-insensitively against the usual "off" spellings
    (``""``, ``"0"``, ``"false"``, ``"no"``, ``"off"``); any other string is
    true. Non-string values fall back to ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)


@app.local_entrypoint()
def main(input_path, output_path, fast=False, furious=False):
    """Send a local image to the remote job and store the returned image.

    Args:
        input_path: Path of the image file to read and upload.
        output_path: Path the returned image bytes are written to.
        fast: Forwarded to the remote job as ``fast``.
        furious: Forwarded to the remote job as ``furious``.

    The flags are normalised with ``_as_flag`` first. The parameters carry no
    type annotations, so the CLI layer may deliver them as strings (to be
    confirmed against the Modal documentation); a literal ``"False"`` would
    otherwise count as true.
    """
    with open(input_path, "rb") as source:
        image_bytes = source.read()

    result = eval.remote(image_bytes, _as_flag(fast), _as_flag(furious))

    with open(output_path, "wb") as file:
        file.write(result)

examples/sam2_amg_server/server.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,39 @@ def model_type_to_paths(checkpoint_path, model_type):
332332
model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
333333
return sam2_checkpoint, model_cfg
334334

335+
def set_fast(mask_generator):
    """Wrap the predictor's image encoder and mask prediction in ``torch.compile``.

    Both ``predictor.model.image_encoder`` and ``predictor._predict_masks`` are
    replaced in place by their compiled counterparts, using the same options
    for each. Nothing is returned.

    Args:
        mask_generator: Object exposing ``predictor.model.image_encoder`` and
            ``predictor._predict_masks``.
    """
    # TODO: Using CUDA graphs can cause numerical differences?
    compile_options = {
        "mode": "max-autotune",
        "fullgraph": True,
        "dynamic": False,
    }
    predictor = mask_generator.predictor

    predictor.model.image_encoder = torch.compile(
        predictor.model.image_encoder, **compile_options
    )
    predictor._predict_masks = torch.compile(
        predictor._predict_masks, **compile_options
    )

    # Compiling predictor._predict_masks_postprocess (fullgraph=True,
    # dynamic=True) is currently left disabled.
356+
357+
358+
def set_furious(mask_generator):
    """Switch the image encoder and mask decoder to float16.

    Besides converting both sub-modules, this records float16 as the
    predictor's image dtype and the decoder's source dtype, points the
    predictor's transforms device at the predictor's own device, and sets the
    global float32 matmul precision to ``'high'``. Nothing is returned.

    Args:
        mask_generator: Object exposing ``predictor.device``,
            ``predictor.model.image_encoder`` and
            ``predictor.model.sam_mask_decoder``.
    """
    half = torch.float16
    predictor = mask_generator.predictor
    model = predictor.model

    model.image_encoder = model.image_encoder.to(half)
    # NOTE: Not baseline feature
    predictor._image_dtype = half
    predictor._transforms_device = predictor.device

    torch.set_float32_matmul_precision('high')

    model.sam_mask_decoder = model.sam_mask_decoder.to(half)
    # NOTE: Not baseline feature
    model.sam_mask_decoder._src_dtype = half
367+
335368

336369
def main(checkpoint_path,
337370
model_type,
@@ -378,36 +411,10 @@ def main(checkpoint_path,
378411

379412
if fast:
380413
assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
381-
# TODO: Using CUDA graphs can cause numerical differences?
382-
mask_generator.predictor.model.image_encoder = torch.compile(
383-
mask_generator.predictor.model.image_encoder,
384-
mode="max-autotune",
385-
fullgraph=True,
386-
dynamic=False,
387-
)
388-
389-
mask_generator.predictor._predict_masks = torch.compile(
390-
mask_generator.predictor._predict_masks,
391-
mode="max-autotune",
392-
fullgraph=True,
393-
dynamic=False,
394-
)
395-
396-
# mask_generator.predictor._predict_masks_postprocess = torch.compile(
397-
# mask_generator.predictor._predict_masks_postprocess,
398-
# fullgraph=True,
399-
# dynamic=True,
400-
# )
414+
set_fast(mask_generator)
401415

402416
if furious:
403-
mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
404-
# NOTE: Not baseline feature
405-
mask_generator.predictor._image_dtype = torch.float16
406-
mask_generator.predictor._transforms_device = mask_generator.predictor.device
407-
torch.set_float32_matmul_precision('high')
408-
mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16)
409-
# NOTE: Not baseline feature
410-
mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
417+
set_furious(mask_generator)
411418

412419
with open('dog.jpg', 'rb') as f:
413420
image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))

0 commit comments

Comments
 (0)