Skip to content

Commit 9b748de

Browse files
authored
SAM2.1 and example README (#1048)
* Examples readme * Example image * curl command * checkpoint_path argument * SAM 2.1
1 parent ecc53bf commit 9b748de

File tree

6 files changed

+13
-128
lines changed

6 files changed

+13
-128
lines changed

examples/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Examples
2+
3+
Various example scripts and applications of torchao and PyTorch in general.

examples/sam2_amg_server/README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
To run this example you need to download the vit_h checkpoint and put it into a local folder named checkpoints
2-
3-
You can find the checkpoint for vit_h here: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
4-
5-
To read the image you also need to install opencv-python: https://pypi.org/project/opencv-python/
1+
## Example curl command
2+
```
3+
curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output path/to/output.png
4+
```

examples/sam2_amg_server/dog.jpg

97.5 KB
Loading

examples/sam2_amg_server/example.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
!DOCTYPE html>
1+
<!DOCTYPE html>
22
<html lang="en">
33
<head>
44
<meta charset="UTF-8">

examples/sam2_amg_server/sam2_hiera_l.yaml

Lines changed: 0 additions & 117 deletions
This file was deleted.

examples/sam2_amg_server/server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ class GenerateRequest(BaseModel):
5353
num_steps: Optional[int] = 30
5454
seed: Optional[int] = 42
5555

56-
def main():
56+
def main(checkpoint_path):
5757

5858
from sam2.build_sam import build_sam2
5959
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
6060

6161
device = "cuda"
62-
sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
63-
model_cfg = "sam2_hiera_l.yaml"
62+
from pathlib import Path
63+
sam2_checkpoint = Path(checkpoint_path) / Path("sam2.1_hiera_large.pt")
64+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
6465
logging.basicConfig(level=logging.INFO)
6566

66-
logging.info(f"Loading model: {sam2_checkpoint}")
67+
logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
6768
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
68-
sam2.to(device=device)
6969

7070
mask_generator = SAM2AutomaticMaskGenerator(sam2) #, points_per_batch=None)
7171

0 commit comments

Comments
 (0)