|
| 1 | +from pathlib import Path |
| 2 | + |
| 3 | +import modal |
| 4 | + |
| 5 | +app = modal.App("torchao-sam-2-cli") |
| 6 | + |
| 7 | +TARGET = "/root/" |
| 8 | + |
| 9 | +image = ( |
| 10 | + modal.Image.debian_slim(python_version="3.12.7") |
| 11 | + .pip_install("numpy<3", "tqdm") |
| 12 | + .pip_install( |
| 13 | + "torch", |
| 14 | + pre=True, |
| 15 | + index_url="https://download.pytorch.org/whl/nightly/cu124", # tested with torch-2.6.0.dev20241120 |
| 16 | + ) |
| 17 | + .pip_install( |
| 18 | + "torchvision", |
| 19 | + pre=True, |
| 20 | + index_url="https://download.pytorch.org/whl/nightly/cu124", # tested with torch-2.6.0.dev20241120 |
| 21 | + ) |
| 22 | + .apt_install("git") |
| 23 | + .apt_install("libopencv-dev") |
| 24 | + .apt_install("python3-opencv") |
| 25 | + .run_commands(["git clone https://github.com/pytorch/ao.git /tmp/ao_src"]) |
| 26 | + .run_commands(["cd /tmp/ao_src; python setup.py develop"]) |
| 27 | + .pip_install( |
| 28 | + "gitpython", |
| 29 | + ) |
| 30 | + .apt_install("wget") |
| 31 | + .run_commands([f"wget https://raw.githubusercontent.com/pytorch/ao/refs/heads/main/examples/sam2_amg_server/requirements.txt"]) |
| 32 | + .pip_install_from_requirements( |
| 33 | + 'requirements.txt', |
| 34 | + ) |
| 35 | +) |
| 36 | + |
| 37 | +checkpoints = modal.Volume.from_name("checkpoints", create_if_missing=True) |
| 38 | + |
| 39 | +@app.function( |
| 40 | + image=image, |
| 41 | + gpu="H100", |
| 42 | + volumes={ |
| 43 | + TARGET + "checkpoints": checkpoints, |
| 44 | + # # mount the caches of torch.compile and friends |
| 45 | + # "/root/.nv": modal.Volume.from_name("torchao-sam-2-cli-nv-cache", create_if_missing=True), |
| 46 | + # "/root/.triton": modal.Volume.from_name( |
| 47 | + # "torchao-sam-2-cli-triton-cache", create_if_missing=True |
| 48 | + # ), |
| 49 | + # "/root/.inductor-cache": modal.Volume.from_name( |
| 50 | + # "torchao-sam-2-cli-inductor-cache", create_if_missing=True |
| 51 | + # ), |
| 52 | + }, |
| 53 | + timeout=60 * 60, |
| 54 | +) |
| 55 | +def eval(input_bytes, fast, furious): |
| 56 | + import torch |
| 57 | + import torchao |
| 58 | + import os |
| 59 | + |
| 60 | + import subprocess |
| 61 | + from pathlib import Path |
| 62 | + from git import Repo |
| 63 | + |
| 64 | + def download_file(url, filename): |
| 65 | + command = f"wget -O {filename} {url}" |
| 66 | + subprocess.run(command, shell=True, check=True) |
| 67 | + |
| 68 | + os.chdir(Path(TARGET)) |
| 69 | + download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/cli.py", "cli.py") |
| 70 | + download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/server.py", "server.py") |
| 71 | + # Create a Path object for the current directory |
| 72 | + current_directory = Path('.') |
| 73 | + |
| 74 | + with open('/tmp/dog.jpg', 'wb') as file: |
| 75 | + file.write(input_bytes) |
| 76 | + |
| 77 | + import sys |
| 78 | + sys.path.append(".") |
| 79 | + from cli import main as cli_main |
| 80 | + cli_main(Path(TARGET) / Path("checkpoints"), |
| 81 | + model_type="large", |
| 82 | + input_path="/tmp/dog.jpg", |
| 83 | + output_path="/tmp/dog_masked_2.png", |
| 84 | + verbose=True, |
| 85 | + fast=fast, |
| 86 | + furious=furious) |
| 87 | + |
| 88 | + return bytearray(open('/tmp/dog_masked_2.png', 'rb').read()) |
| 89 | + |
| 90 | +@app.local_entrypoint() |
| 91 | +def main(input_path, output_path, fast=False, furious=False): |
| 92 | + bytes = eval.remote(open(input_path, 'rb').read(), fast, furious) |
| 93 | + with open(output_path, "wb") as file: |
| 94 | + file.write(bytes) |
0 commit comments