This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit f3a2267

Segment Anything Fast example (#2802)
* Segment Anything Fast example
* Segment Anything Fast example
* Changes to make model inference faster
* addressed review comments
* code cleanup
* review comments
* added missing instruction
* added python 3.10 dependency
1 parent 709e743 commit f3a2267

8 files changed: +251 -0 lines changed

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@

## Segment Anything Fast

[Segment Anything Fast](https://github.com/pytorch-labs/segment-anything-fast) is an optimized version of [Segment Anything](https://github.com/facebookresearch/segment-anything) with an 8x performance improvement over the original implementation. The improvements were achieved using native PyTorch.

The speedup is achieved using:
- `torch.compile`: A compiler for PyTorch models
- GPU quantization: Accelerate models with reduced precision operations
- Scaled Dot Product Attention (SDPA): Memory efficient attention implementations
- Semi-Structured (2:4) Sparsity: A GPU optimized sparse memory format
- Nested Tensor: Batch together non-uniformly sized data into a single Tensor, such as images of different sizes.
- Custom operators with Triton: Write GPU operations using Triton Python DSL and easily integrate it into PyTorch’s various components with custom operator registration.

Details on how this is achieved can be found in this [blog](https://pytorch.org/blog/accelerating-generative-ai/).

#### Pre-requisites

Python 3.10 is required.

`cd` to the example folder `examples/large_models/segment_anything_fast`

Install `Segment Anything Fast` by running
```
chmod +x install_segment_anything_fast.sh
source install_segment_anything_fast.sh
```
Segment Anything Fast needs the nightly version of PyTorch, so the script uninstalls PyTorch and its domain libraries and then installs the nightly build.

### Step 1: Download the weights

```
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```
Set `sam_checkpoint` in `model-config.yaml` to the path of the downloaded file.

If you are not using an A100 GPU for inference, turn off the A100-specific optimization using
```
export SEGMENT_ANYTHING_FAST_USE_FLASH_4=0
```

Depending on the available GPU memory, you may need to edit the value of `process_batch_size` in `model-config.yaml`.
`process_batch_size` is the batch size for the decoding step. A smaller value lowers the memory footprint;
a higher value results in faster inference. The following values were tested.

Example:
- For `A10G` : `process_batch_size=8`
- For `A100` : `process_batch_size=16`
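
The tested values above could be encoded in a small lookup, sketched here under some assumptions: the helper name `pick_process_batch_size`, the substring matching on the GPU name (for example the string returned by `torch.cuda.get_device_name`), and the fallback value of 8 are all hypothetical and not part of this example. Only the `A10G` and `A100` values come from the list above.

```python
# Values reported as tested in this README; anything else is an assumption.
TESTED_BATCH_SIZES = {"A10G": 8, "A100": 16}


def pick_process_batch_size(gpu_name, default=8):
    """Return a tested process_batch_size for a GPU name, else a fallback."""
    for model, batch_size in TESTED_BATCH_SIZES.items():
        if model.lower() in gpu_name.lower():
            return batch_size
    # Hypothetical fallback for untested GPUs; lower it if memory runs out.
    return default


print(pick_process_batch_size("NVIDIA A100-SXM4-40GB"))
print(pick_process_batch_size("NVIDIA A10G"))
```

The returned value would still have to be written into `model-config.yaml` by hand, since the handler reads `process_batch_size` from that file.
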

### Step 2: Generate the mar or tgz file

```
torch-model-archiver --model-name sam-fast --version 1.0 --handler custom_handler.py --config-file model-config.yaml --archive-format tgz
```

### Step 3: Add the tgz file to the model store

```
mkdir model_store
mv sam-fast.tar.gz model_store
```

### Step 4: Start TorchServe

```
torchserve --start --ncs --model-store model_store --models sam-fast.tar.gz
```

### Step 5: Run inference

```
python inference.py
```

This produces the following result:

![kitten_mask_sam_fast](./kitten_mask_fast.png)
examples/large_models/segment_anything_fast/custom_handler.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import base64
import io
import logging
import pickle

import cv2
import numpy as np
import torch
from PIL import Image
from segment_anything_fast import SamAutomaticMaskGenerator, sam_model_fast_registry

from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class SegmentAnythingFastHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self.mask_generator = None
        self.initialized = False

    def initialize(self, ctx):
        properties = ctx.system_properties
        # Default to CPU; use the assigned GPU when one is available
        self.device = "cpu"
        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
            self.map_location = "cuda"
            self.device = torch.device(
                self.map_location + ":" + str(properties.get("gpu_id"))
            )

        # Handler settings come from the "handler" section of model-config.yaml
        model_type = ctx.model_yaml_config["handler"]["model_type"]
        sam_checkpoint = ctx.model_yaml_config["handler"]["sam_checkpoint"]
        process_batch_size = ctx.model_yaml_config["handler"]["process_batch_size"]

        self.model = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
        self.model.to(self.device)

        self.mask_generator = SamAutomaticMaskGenerator(
            self.model, process_batch_size=process_batch_size, output_mode="coco_rle"
        )

        logger.info(
            f"Model weights {sam_checkpoint} for {model_type} loaded successfully with process batch size {process_batch_size}"
        )
        self.initialized = True

    @timed
    def preprocess(self, data):
        images = []
        for row in data:
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # The image is a Base64-encoded string; decode it to bytes
                image = base64.b64decode(image)

            # The image is sent as raw bytes
            if isinstance(image, (bytearray, bytes)):
                image = Image.open(io.BytesIO(image))
            else:
                # The image is a list
                image = torch.FloatTensor(image)

            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            images.append(image)

        return images

    @timed
    def inference(self, data):
        assert (
            len(data) == 1
        ), "SamAutomaticMaskGenerator currently supports a batch size of 1"
        return self.mask_generator.generate(data[0])

    @timed
    def postprocess(self, data):
        # Serialize the output using Pickle
        serialized_data = pickle.dumps(data)

        # Encode the serialized data as Base64
        base64_encoded_data = base64.b64encode(serialized_data).decode("utf-8")

        return [base64_encoded_data]
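
The `postprocess` step above and the client's decoding step form a round trip: pickle, Base64-encode to text, then reverse both on the other side. The following is a minimal sketch of that round trip using only the standard library. The helper names `encode_masks` and `decode_masks` are hypothetical, and the annotation dict is a made-up stand-in, not real model output.

```python
import base64
import pickle


def encode_masks(masks):
    """Mirror of the handler's postprocess: pickle, then Base64, then text."""
    return base64.b64encode(pickle.dumps(masks)).decode("utf-8")


def decode_masks(payload):
    """Mirror of the client side: Base64-decode, then unpickle."""
    # pickle.loads can execute arbitrary code, so only use it on trusted data.
    return pickle.loads(base64.b64decode(payload))


# Stand-in for one generated annotation; the field values are made up.
fake_masks = [{"segmentation": {"size": [2, 2], "counts": "04"}, "area": 4}]
payload = encode_masks(fake_masks)
restored = decode_masks(payload)
print(restored == fake_masks)
```

Because unpickling is only safe for trusted sources, this scheme fits a client that talks to its own TorchServe instance, as in this example.
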
examples/large_models/segment_anything_fast/inference.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import base64
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
from pycocotools import mask as coco_mask

url = "http://localhost:8080/predictions/sam-fast"
image_path = "./kitten.jpg"


def show_anns(anns):
    if len(anns) == 0:
        return
    # Decode each COCO RLE segmentation into a binary mask
    for i in range(len(anns)):
        anns[i]["segmentation"] = coco_mask.decode(anns[i]["segmentation"])
    # Draw larger masks first so that smaller ones stay visible on top
    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # RGBA overlay, fully transparent until a mask is painted
    img = np.ones(
        (
            sorted_anns[0]["segmentation"].shape[0],
            sorted_anns[0]["segmentation"].shape[1],
            4,
        )
    )
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann["segmentation"].astype(bool)
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)


# Send the inference request to TorchServe
with open(image_path, "rb") as f:
    res = requests.post(url, files={"body": f})

# Decode the Base64-encoded response
decoded_data = base64.b64decode(res.text)

# Deserialize the data using Pickle
masks = pickle.loads(decoded_data)


# Plot the segmentation masks on the image
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(image.shape[1] / 100.0, image.shape[0] / 100.0), dpi=100)
plt.imshow(image)
show_anns(masks)
plt.axis("off")
plt.tight_layout()
plt.savefig("kitten_mask_fast.png", format="png")
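
`show_anns` sorts the annotations by descending `area` before painting, so a small mask drawn later overwrites the part of a large mask beneath it. The sketch below illustrates that ordering with plain nested lists instead of NumPy arrays. The function `paint_masks` and the `label` field are hypothetical additions for illustration only; real annotations carry a decoded mask and an `area`, not a label.

```python
def paint_masks(anns, height, width):
    """Paint mask labels onto a grid, largest area first."""
    canvas = [[None] * width for _ in range(height)]
    # Same ordering as show_anns: descending area, so small masks are drawn last.
    for ann in sorted(anns, key=lambda a: a["area"], reverse=True):
        for r, row in enumerate(ann["segmentation"]):
            for c, inside in enumerate(row):
                if inside:
                    canvas[r][c] = ann["label"]
    return canvas


# Two toy masks on a 2x2 grid; "label" exists only for this illustration.
big = {"label": "big", "area": 4, "segmentation": [[1, 1], [1, 1]]}
small = {"label": "small", "area": 1, "segmentation": [[1, 0], [0, 0]]}
canvas = paint_masks([small, big], height=2, width=2)
print(canvas)
```

Without the sort, the large mask could be painted last and hide the small one entirely.
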
examples/large_models/segment_anything_fast/install_segment_anything_fast.sh

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
#!/bin/bash

# Uninstall torchtext, torchdata, torch, torchvision, and torchaudio
pip uninstall torchtext torchdata torch torchvision torchaudio -y

# Install nightly PyTorch, torchvision, and torchaudio from the specified index URL
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed

# Optional: Display the installed PyTorch and torchvision versions
python -c "import torch; print('PyTorch version:', torch.__version__)"
python -c "import torchvision; print('torchvision version:', torchvision.__version__)"

echo "PyTorch and torchvision updated successfully!"

# Install the segment-anything-fast package from GitHub
pip install git+https://github.com/pytorch-labs/segment-anything-fast.git

echo "Segment Anything Fast installed successfully!"

echo "Installing other dependencies"
pip install opencv-python matplotlib pycocotools
Binary files (108 KB and 573 KB), previews not shown.
examples/large_models/segment_anything_fast/model-config.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
responseTimeout: 300
handler:
    profile: true
    model_type: "vit_h"
    sam_checkpoint: "/home/ubuntu/serve/examples/large_models/segment_anything_fast/sam_vit_h_4b8939.pth"
    process_batch_size: 8
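
The `sam_checkpoint` value above is an absolute path, so it has to be adjusted for any other machine. A small pre-flight check along the following lines could catch a wrong path or a missing key before the archive is built. This is only a sketch: `check_handler_config` is a hypothetical helper, and it assumes the YAML has already been parsed into a dict (for example with PyYAML's `yaml.safe_load`).

```python
import os

# Keys the handler reads from the "handler" section of model-config.yaml.
REQUIRED_HANDLER_KEYS = ("model_type", "sam_checkpoint", "process_batch_size")


def check_handler_config(config):
    """Return a list of problems found in the parsed model-config.yaml dict."""
    handler = config.get("handler", {})
    problems = [
        f"missing handler key: {key}"
        for key in REQUIRED_HANDLER_KEYS
        if key not in handler
    ]
    checkpoint = handler.get("sam_checkpoint")
    if checkpoint is not None and not os.path.isfile(checkpoint):
        problems.append(f"checkpoint not found: {checkpoint}")
    batch_size = handler.get("process_batch_size")
    if batch_size is not None and (not isinstance(batch_size, int) or batch_size < 1):
        problems.append("process_batch_size must be a positive integer")
    return problems


# Example dict mirroring the config above, with a path that does not exist.
config = {
    "responseTimeout": 300,
    "handler": {
        "model_type": "vit_h",
        "sam_checkpoint": "/nonexistent/sam_vit_h_4b8939.pth",
        "process_batch_size": 8,
    },
}
print(check_handler_config(config))
```

An empty list means the three keys are present, the checkpoint file exists, and the batch size is a positive integer.
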

ts_scripts/spellcheck_conf/wordlist.txt

Lines changed: 5 additions & 0 deletions
@@ -1138,3 +1138,8 @@ FlashAttention
 GenAI
 prem
 CachingMetric
+DSL
+SDPA
+sam
+zlib
+
