Production-grade deployment of Meta's Segment Anything Models (SAM1, SAM2, SAM3) using NVIDIA Triton Inference Server.
- CLAUDE.md - Comprehensive architecture and deployment details
- SAM1 - Original Segment Anything Model (ViT-H, ViT-L, ViT-B)
- SAM2.1 - 40% faster inference with Hiera backbone (tiny, small, base_plus, large)
- SAM3 - SAM3 Tracker with multi-scale embeddings
- Enterprise-grade: Industry-standard KServe v2 inference protocol over HTTP and gRPC
- Performance: GPU-accelerated ONNX Runtime with dynamic batching support
- Scalability: Native multi-GPU support with load balancing
- Observability: Built-in Prometheus metrics
- Flexibility: Hot-reload models without downtime
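Dynamic batching and GPU placement are configured per model in its `config.pbtxt`. The sketch below is illustrative only (the field values are examples, not the shipped configuration; see `model_repository/*/config.pbtxt` for the real files):

```
name: "sam2_decoder"
platform: "onnxruntime_onnx"
max_batch_size: 8

dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 100
}

instance_group [
  { count: 1, kind: KIND_GPU }
]
```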
- NVIDIA GPU with CUDA support
- Docker with NVIDIA Container Toolkit
- Pixi - Modern Python package manager
# Test Docker can access GPUs
docker run --rm --gpus all nvcr.io/nvidia/tritonserver:25.01-py3 nvidia-smi

# 1. Install Pixi (if not already installed)
curl -fsSL https://pixi.sh/install.sh | bash
# 2. Install dependencies
pixi install
# 3. Choose your setup:
# Option A: Setup ALL models (SAM1 + SAM2 + SAM3) - ~5GB download
pixi run setup-all
# Option B: Setup only SAM2 (recommended for quick start)
pixi run setup-sam2
# Option C: Setup individual models
pixi run setup-sam1 # SAM1 only (~2.5GB)
pixi run setup-sam2 # SAM2 only (~350MB)
pixi run setup-sam3 # SAM3 only (~1GB, pre-exported ONNX)
# 4. Start Triton server
docker compose up -d
# 5. Verify deployment
curl http://localhost:8000/v2/models

# Test SAM2
pixi run test-sam2
# Test SAM3
pixi run test-sam3
# Speculative request stress test
pixi run test-speculative

| Model | Input Size | Embedding Shape | Best For |
|---|---|---|---|
| SAM1 | 1024x1024 | (1, 256, 64, 64) | Legacy compatibility, proven accuracy |
| SAM2 | 1024x1024 | (1, 256, 64, 64) | Production default, video support |
| SAM3 | 1008x1008 | 3 multi-scale | Latest features, text prompts |
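The encoders expect fixed square inputs (see the table above), while prompts are given in original image coordinates; the bundled clients handle the mapping internally. A minimal sketch of that scaling, assuming the reference longest-side resize preprocessing (the helper name is illustrative, not part of this repo's API):

```python
def scale_points(points, orig_hw, model_size=1024):
    """Map (x, y) prompts from original image space to model input space."""
    h, w = orig_hw
    scale = model_size / max(h, w)  # longest-side resize factor
    return [(x * scale, y * scale) for x, y in points]

# A 2048x1536 image is resized by 1024/2048 = 0.5, so prompts scale the same way
print(scale_points([(512, 512)], (1536, 2048)))  # -> [(256.0, 256.0)]
```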
| Model | Parameters | Memory | Speed | Use Case |
|---|---|---|---|---|
| tiny | 39M | 2GB | 91 FPS | Edge devices, real-time |
| small | 46M | 2.5GB | 85 FPS | Balanced |
| base_plus | 81M | 4GB | 64 FPS | Production default |
| large | 224M | 8GB | 40 FPS | Maximum quality |
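Choosing a variant is a trade-off between the memory and speed columns above. A hypothetical helper (not part of this repo) that picks the largest SAM2 variant fitting a GPU memory budget, using the table's figures:

```python
SAM2_VARIANTS = {  # name: (params_millions, memory_gb, fps)
    "tiny": (39, 2.0, 91),
    "small": (46, 2.5, 85),
    "base_plus": (81, 4.0, 64),
    "large": (224, 8.0, 40),
}

def pick_variant(memory_budget_gb: float) -> str:
    """Return the largest variant whose memory footprint fits the budget."""
    fitting = [name for name, (_, mem, _) in SAM2_VARIANTS.items()
               if mem <= memory_budget_gb]
    if not fitting:
        raise ValueError("No SAM2 variant fits this memory budget")
    return max(fitting, key=lambda name: SAM2_VARIANTS[name][0])

print(pick_variant(6.0))  # -> base_plus
```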
```
┌───────────────────────────────────────────────────────────────────┐
│                       NVIDIA Triton Server                        │
│                                                                   │
│  ┌─────────────────┐   ┌─────────────────┐   ┌─────────────────┐  │
│  │  SAM1 Encoder   │   │  SAM2 Encoder   │   │  SAM3 Encoder   │  │
│  │  (1024x1024)    │   │  (1024x1024)    │   │  (1008x1008)    │  │
│  │  → (256,64,64)  │   │  → (256,64,64)  │   │  → 3 embeddings │  │
│  └────────┬────────┘   └────────┬────────┘   └────────┬────────┘  │
│           │                     │                     │           │
│  ┌────────▼────────┐   ┌────────▼────────┐   ┌────────▼────────┐  │
│  │  SAM1 Decoder   │   │  SAM2 Decoder   │   │  SAM3 Decoder   │  │
│  │  + prompts      │   │  + prompts      │   │  + prompts      │  │
│  │  → masks        │   │  → masks        │   │  → masks        │  │
│  └─────────────────┘   └─────────────────┘   └─────────────────┘  │
│                                                                   │
│         Ports: HTTP (8000) | gRPC (8001) | Metrics (8002)         │
└───────────────────────────────────────────────────────────────────┘
```
All SAM models use a two-stage inference pipeline optimized for interactive segmentation:
- Stage 1: Image encoder
  - Run once per image
  - Generates reusable embeddings
  - ~200-800ms depending on model
- Stage 2: Mask decoder
  - Run many times per image
  - Uses cached embeddings + point prompts
  - ~10-30ms per prediction
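The payoff of caching embeddings shows up in a quick back-of-envelope calculation (using the midpoints of the latency ranges above; the function is illustrative, not part of this repo):

```python
ENCODE_MS = 500.0  # one-time image encoding cost (midpoint of ~200-800ms)
DECODE_MS = 20.0   # per-prompt decoding cost (midpoint of ~10-30ms)

def total_latency_ms(num_prompts: int, cache_embeddings: bool = True) -> float:
    """Total time to serve num_prompts mask predictions for one image."""
    if cache_embeddings:
        # Encode once, then reuse the embeddings for every prompt
        return ENCODE_MS + num_prompts * DECODE_MS
    # Naive alternative: re-encode the image for every prompt
    return num_prompts * (ENCODE_MS + DECODE_MS)

print(total_latency_ms(50))         # cached embeddings: 1500.0 ms
print(total_latency_ms(50, False))  # re-encoding each time: 26000.0 ms
```

For an interactive session with many prompts per image, caching turns a multi-second workload into near-real-time response.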
```
triton-sam2/
├── README.md                    # This file
├── CLAUDE.md                    # Detailed architecture docs
├── docker-compose.yml           # Triton server deployment
├── pyproject.toml               # Pixi tasks and dependencies
│
├── triton_sam/                  # Python client module
│   ├── client.py                # SAM2TritonClient (sync)
│   ├── speculative_client.py    # Async client with cancellation
│   └── tests/
│
├── scripts/
│   ├── download_sam1.sh         # Download SAM1 checkpoints
│   ├── download_sam2.sh         # Download SAM2 checkpoints
│   ├── download_sam3.sh         # Download SAM3 checkpoints
│   ├── download_sam3_onnx.py    # Download pre-exported SAM3 ONNX
│   ├── export_sam1_to_onnx.py   # Export SAM1 to ONNX
│   ├── export_sam2_to_onnx.py   # Export SAM2 to ONNX
│   └── export_sam3_to_onnx.py   # Export SAM3 to ONNX
│
├── model_repository/            # Triton model repository
│   ├── sam1_encoder/
│   │   ├── 1/model.onnx
│   │   └── config.pbtxt
│   ├── sam1_decoder/
│   │   ├── 1/model.onnx
│   │   └── config.pbtxt
│   ├── sam2_encoder/
│   │   ├── 1/model.onnx
│   │   └── config.pbtxt
│   ├── sam2_decoder/
│   │   ├── 1/model.onnx
│   │   └── config.pbtxt
│   ├── sam3_encoder/
│   │   ├── 1/vision_encoder.onnx
│   │   └── config.pbtxt
│   └── sam3_decoder/
│       ├── 1/prompt_encoder_mask_decoder.onnx
│       └── config.pbtxt
│
├── checkpoints/                 # Downloaded model weights
├── sam1_repo/                   # Cloned segment-anything repo
└── sam2_repo/                   # Cloned segment-anything-2 repo
```
# Complete setup (all models)
pixi run setup-all # SAM1 + SAM2 + SAM3
# Individual model setup
pixi run setup-sam1 # Download, clone repo, export SAM1
pixi run setup-sam2 # Download, clone repo, export SAM2 (alias: setup)
pixi run setup-sam3 # Download pre-exported SAM3 ONNX

# SAM1 checkpoints
pixi run download-sam1-h # ViT-Huge (2.5GB, recommended)
pixi run download-sam1-l # ViT-Large
pixi run download-sam1-b # ViT-Base
# SAM2 checkpoints
pixi run download-tiny # 38.9M params
pixi run download-small # 46M params
pixi run download-base # 80.8M params (recommended)
pixi run download-large # 224.4M params
# SAM3 ONNX models
pixi run download-sam3-onnx # Pre-exported from HuggingFace

pixi run export-sam1 # Export SAM1 to ONNX
pixi run export-sam2 # Export SAM2 to ONNX
pixi run export-onnx # Alias for export-sam2

pixi run test-sam2 # Basic SAM2 inference test
pixi run test-sam3 # SAM3 inference test
pixi run test-speculative # Stress test with cancellation

pixi run benchmark-sam2 # SAM2 performance benchmark
pixi run benchmark-sam3 # SAM3 performance benchmark

Basic synchronous usage:

```python
import numpy as np

from triton_sam import SAM2TritonClient

# Initialize client (supports sam2 or sam3)
client = SAM2TritonClient("localhost:8000", model_type="sam2")

# Encode image once (cached)
client.set_image("image.jpg")

# Predict masks from point prompts
masks, iou = client.predict(
    point_coords=[[512, 512]],  # (x, y) in original image space
    point_labels=[1],           # 1=foreground, 0=background
)

# Threshold logits at 0 for binary mask
binary_mask = (masks[0, 0] > 0).astype(np.uint8)
```

Speculative (async) usage with request cancellation:

```python
import asyncio

import numpy as np

from triton_sam import (
    SpeculativeSAM2Client,
    queue_multiple_requests,
    wait_for_latest_result,
)

async def interactive_segmentation():
    client = SpeculativeSAM2Client("localhost:8000")
    client.set_image("image.jpg")
    session_id = "user_session_1"

    # Queue many requests (simulating mouse movement)
    mouse_positions = [(100, 100), (120, 110), (140, 120)]  # example cursor path
    coords_list = [np.array([[x, y]]) for x, y in mouse_positions]
    labels_list = [np.array([1]) for _ in mouse_positions]
    tasks = await queue_multiple_requests(
        client, coords_list, labels_list, session_id
    )

    # Cancel intermediate requests when the user stops
    client.cancel_session_requests(session_id)

    # Get the final result
    return await wait_for_latest_result(tasks, client, session_id)

# asyncio.run(interactive_segmentation())
```

This Triton deployment integrates with the SAM Service FastAPI application. See the SAM_service/ directory for the API, which provides:
- `POST /embedded_model` - Generate embeddings (supports `model_version` param)
- `POST /from_model` - Generate masks from embeddings
- `GET /models` - List available models
- `GET /health` - Health check
# Server health
curl http://localhost:8000/v2/health/ready
# List models
curl http://localhost:8000/v2/models
# Prometheus metrics
curl http://localhost:8002/metrics | grep nv_inference
# Docker logs
docker compose logs -f

# Verify NVIDIA Container Toolkit
docker run --rm --gpus all nvidia/cuda:12.1.0-base-ubuntu22.04 nvidia-smi

# Check model files exist
ls -la model_repository/*/1/
# Check Triton logs
docker compose logs triton | grep -i error

- Use a smaller model (tiny or small for SAM2)
- Reduce instance count in config.pbtxt
- Check for other GPU processes with `nvidia-smi`
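For example, lowering `count` in a model's `instance_group` reduces its GPU memory footprint (an illustrative snippet; the actual files live under `model_repository/*/config.pbtxt`):

```
instance_group [
  {
    count: 1        # lower this if several models share one GPU
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]
```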
- Encoder: ~300ms per image
- Decoder: ~15ms per mask
- End-to-end (1 image, 10 masks): ~450ms
- SAM1 ViT-H: ~4GB
- SAM2 base_plus: ~4GB
- SAM3: ~6GB
Released under the Janelia Open-Source Software License.
- SAM Paper - Original Segment Anything
- SAM2 Paper - Segment Anything 2
- SAM2 GitHub
- NVIDIA Triton