
Commit e7bb3a6

cpuhrsch authored and sunjiweiswift committed
More SAM2-fast server improvements (pytorch#1285)
1 parent 5772310 commit e7bb3a6

File tree

7 files changed: +298 -96 lines changed

examples/sam2_amg_server/README.md (+8 -6)

@@ -24,17 +24,19 @@ Experiments run on H100 and with batch size 1
 | mode | mIoU | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch |
 | -------------- | ----------------- | ------------------- | ------------------- | --------------------- | ---------- | ---------------- |
 | baseline | 1.0 | 0 | 863 | 4013MiB (4%) | 1 | 64 |
-| ao | 1.0 | 0 | 840 | 4350MiB (4%) | 1 | 64 |
-| fast | 0.9897813200950623 | 191 | 661 | 3916MiB (4%) | 1 | 64 |
-| fast | 0.9897371530532837 | 192 | 388 | 50787MiB (52%) | 16 | 1024 |
-| fast + furious | 0.974319338798523 | 209 | 461 | 3453MiB (3%) | 1 | 64 |
-| fast + furious | 0.9702069759368896 | 196 | 195 | 48298MiB (49%) | 16 | 1024 |
+| ao | 0.9999980926513672 | 6 | 586 | | 1 | 64 |
+| fast | 0.9937329888343811 | 191 | 333 | | 1 | 1024 |
+| fast | 0.9937219619750977 | 192 | 324 | | 16 | 1024 |
+| fast + furious | 0.9804400205612183 | 292 | 131 | | 1 | 1024 |
+| fast + furious | 0.9806423187255859 | 282 | 130 | | 16 | 1024 |
 
 mask count mismatch counts the number of requests where the number of masks differs from the baseline.
 For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
 We exclude these examples from the mIoU calculation.
+Differences in mask count seem to stem from even slight reorderings in compute, for example preprocessing on GPU instead of CPU.
+A more relaxed way of measuring mIoU might be useful here to take into account slight differences in the number of masks, which may be caused by additional or missing sub-divisions.
 
-The 'ao' mode is a copy of the baseline with modifications to make the code compile-able and improve the performance of fast.
+The 'ao' mode is a copy of the baseline with modifications to make the code more compile-able and to speed up run length encoding.
 
 ### 0. Download checkpoints and install requirements
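The relaxed mIoU measurement suggested above could, for instance, greedily pair each candidate mask with its best-matching reference mask instead of comparing area-sorted lists position by position. A minimal sketch, reusing the IoU reduction from compare_rle_lists.py on boolean mask tensors; greedy_miou is a hypothetical helper, not part of this PR:

import torch

def iou(mask1, mask2):
    # Same reduction as compare_rle_lists.py: overlap over union,
    # summed across the trailing two (spatial) dimensions.
    intersection = torch.logical_and(mask1, mask2)
    union = torch.logical_or(mask1, mask2)
    return intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2))

def greedy_miou(masks, ref_masks):
    # Pair each candidate with the best still-unmatched reference,
    # so one extra or missing sub-division does not shift every
    # subsequent comparison the way positional matching does.
    unmatched = list(ref_masks)
    total, count = 0.0, 0
    for m in masks:
        if not unmatched:
            break
        scores = [iou(m, r).item() for r in unmatched]
        best = max(range(len(scores)), key=scores.__getitem__)
        total += scores[best]
        count += 1
        unmatched.pop(best)
    return total / max(count, 1)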

examples/sam2_amg_server/compare_rle_lists.py (+14 -9)

@@ -16,8 +16,8 @@ def iou(mask1, mask2):
     union = torch.logical_or(mask1, mask2)
     return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))
 
+
 def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
-    from torchao._models.sam2.utils.amg import rle_to_mask
     v0_areas = []
     v1_areas = []
     v0_masks = []
@@ -40,17 +40,20 @@ def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
     v0_masks = sorted(v0_masks, key=(lambda x: x[1]), reverse=True)
     v1_masks = sorted(v1_masks, key=(lambda x: x[1]), reverse=True)
     miou_sum = 0.0
-    miou_count = 0
+    miou_count = 0.0
+    equal_count = 0
     for ((v0_mask, _), (v1_mask, _)) in zip(v0_masks, v1_masks):
         miou_sum += iou(v0_mask, v1_mask)
         miou_count += 1
+        equal_count += torch.equal(v0_mask, v1_mask)
         if verbose:
             print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")
 
-    return miou_sum, miou_count
+    return miou_sum / miou_count, equal_count
 
 
-def main(path0, path1):
+def main(path0, path1, strict=False):
+    # path0 are candidates and path1 the ground truth
     fail_count = 0
     miou_sum = 0.0
    miou_count = 0
@@ -59,11 +62,13 @@ def main(path0, path1):
        masks0 = json.loads(line0)
        masks1 = json.loads(line1)
        if masks0.keys() != masks1.keys():
-            fail_count += 1
-            continue
-        s, c = compare_masks(masks0, masks1, order_by_area=True)
-        miou_sum += s
-        miou_count += c
+            if strict:
+                fail_count += 1
+                continue
+
+        m, e = compare_masks(masks0, masks1, order_by_area=True)
+        miou_sum += m
+        miou_count += 1
 
    print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")
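Two details worth noting in compare_masks: torch.equal returns a plain Python bool, so equal_count += torch.equal(...) adds 0 or 1 per mask pair, and integer mask sums divide to a floating-point IoU. A self-contained toy check (the masks are made up for illustration):

import torch

v0 = torch.tensor([[True, True], [False, False]])
v1 = torch.tensor([[True, False], [False, False]])

# Intersection is 1 pixel, union is 2 pixels -> IoU 0.5.
print(torch.logical_and(v0, v1).sum(dim=(-1, -2))
      / torch.logical_or(v0, v1).sum(dim=(-1, -2)))  # tensor(0.5000)

# torch.equal demands exact elementwise equality -> False here.
print(torch.equal(v0, v1))  # False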

examples/sam2_amg_server/server.py (+45 -29)

@@ -34,6 +34,7 @@
 inductorconfig.coordinate_descent_check_all_directions = True
 inductorconfig.allow_buffer_reuse = False
 
+# torch._dynamo.config.capture_dynamic_output_shape_ops = True
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
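For context, capture_dynamic_output_shape_ops lets dynamo trace ops whose output shape depends on tensor values (torch.nonzero, torch.unique, boolean indexing), which would otherwise break fullgraph=True compilation. A standalone toy illustration, not taken from the server:

import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def keep_positive(x):
    # torch.nonzero's output shape depends on the data; the config
    # flag above is what allows capturing it without a graph break.
    idx = torch.nonzero(x > 0).squeeze(-1)
    return x[idx]

print(keep_positive(torch.tensor([-1.0, 2.0, 3.0])))  # tensor([2., 3.])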

@@ -173,7 +174,7 @@ def masks_to_rle_dict(masks):
 
 # Queue to hold incoming requests
 request_queue = asyncio.Queue()
-batch_interval = 0.1 # Time interval to wait before processing a batch
+batch_interval = 0.01 # Time interval to wait before processing a batch
 
 
 def process_batch(batch, mask_generator):
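Lowering batch_interval from 0.1s to 0.01s caps how long a request can sit in request_queue before its batch is flushed, trading batching opportunity for latency. A simplified sketch of an interval-based batcher of this shape (a stand-in for illustration, not the server's actual worker task):

import asyncio
import time

async def batch_worker(request_queue, batch_interval, process_batch):
    while True:
        # Block for the first request, then drain the queue until
        # the batching interval has elapsed.
        batch = [await request_queue.get()]
        deadline = time.monotonic() + batch_interval
        while (remaining := deadline - time.monotonic()) > 0:
            try:
                item = await asyncio.wait_for(request_queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(item)
        process_batch(batch)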
@@ -186,7 +187,7 @@ def process_batch(batch, mask_generator):
     print(f"Processing batch of len {len(batch)} using generate_batch")
     masks = mask_generator.generate_batch(image_tensors)
     print(f"Took avg. {(time.time() - t) / len(batch)}s per batch entry")
-    max_memory_allocated()
+    # max_memory_allocated()
     return masks
 
 
@@ -220,17 +221,17 @@ async def lifespan(app: FastAPI):
         task.cancel()
 
 
-def benchmark_fn(func, inp, mask_generator):
+def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10):
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
-    logging.info("Running 3 warumup iterations.")
-    for _ in range(3):
+    logging.info(f"Running {warmup} warmup iterations.")
+    for _ in range(warmup):
         func(inp, mask_generator)
-    logging.info("Running 10 benchmark iterations.")
+    logging.info(f"Running {runs} benchmark iterations.")
     t = time.time()
-    for _ in range(10):
+    for _ in range(runs):
         func(inp, mask_generator)
-    print(f"Benchmark took {(time.time() - t)/10.0}s per iteration.")
+    print(f"Benchmark took {(time.time() - t)/runs}s per iteration.")
     max_memory_allocated()
 
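benchmark_fn brackets its runs with torch.cuda.reset_peak_memory_stats() and a final max_memory_allocated() call, which is also where the MiB (%) columns in the README come from. The helper itself is defined in server.py (see the next hunk's header); a plausible sketch of what it does, not the verbatim implementation:

import torch

def max_memory_allocated():
    # Peak bytes allocated since the last reset_peak_memory_stats(),
    # reported as MiB and as a fraction of total device memory.
    peak = torch.cuda.max_memory_allocated()
    total = torch.cuda.get_device_properties(0).total_memory
    print(f"max_memory_allocated: {peak >> 20}MiB ({100 * peak / total:.0f}%)")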

@@ -244,11 +245,11 @@ def max_memory_allocated():
 
 def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
     from compare_rle_lists import compare_masks
-    miou_sum, miou_count = compare_masks(masks, ref_masks, order_by_area=order_by_area, verbose=verbose)
-    if miou_count == 0:
+    miou, equal_count = compare_masks(masks, ref_masks, order_by_area=order_by_area, verbose=verbose)
+    if equal_count == len(masks):
         print("Masks exactly match reference.")
     else:
-        print(f"mIoU is {miou_sum / miou_count}")
+        print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}")
 
 
 def main(checkpoint_path,
@@ -290,7 +291,7 @@ def main(checkpoint_path,
 
     logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
-
+
     logging.info(f"Using {points_per_batch} points_per_batch")
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
@@ -299,18 +300,31 @@
     # TODO: Using CUDA graphs can cause numerical differences?
     mask_generator.predictor.model.image_encoder = torch.compile(
         mask_generator.predictor.model.image_encoder,
-        # mode="max-autotune-no-cudagraphs",
         mode="max-autotune",
         fullgraph=True,
         dynamic=False,
     )
 
-    mask_generator._process_batch_fullgraph = torch.compile(
-        mask_generator._process_batch_fullgraph,
+    mask_generator.predictor.model.sam_prompt_encoder.forward = torch.compile(
+        mask_generator.predictor.model.sam_prompt_encoder.forward,
+        mode="max-autotune",
+        fullgraph=True,
+        dynamic=False,
+    )
+
+    mask_generator.predictor._predict_masks = torch.compile(
+        mask_generator.predictor._predict_masks,
+        mode="max-autotune",
         fullgraph=True,
-        dynamic=True,
+        dynamic=False,
     )
 
+    # mask_generator.predictor._predict_masks_postprocess = torch.compile(
+    #     mask_generator.predictor._predict_masks_postprocess,
+    #     fullgraph=True,
+    #     dynamic=True,
+    # )
+
     if furious:
         mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
         # NOTE: Not baseline feature
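The prompt encoder and _predict_masks now get the same compile recipe as the image encoder: mode="max-autotune" (autotuned kernels, possibly CUDA graphs), fullgraph=True (no graph breaks tolerated), and dynamic=False (specialize on static shapes, which a fixed points_per_batch makes viable). A toy standalone equivalent of that recipe, assuming a CUDA device is available:

import torch

class TinyEncoder(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x.transpose(-1, -2))

enc = TinyEncoder().cuda()
# Same settings as the PR: autotune, one graph, static shapes.
enc.forward = torch.compile(enc.forward, mode="max-autotune", fullgraph=True, dynamic=False)
out = enc(torch.randn(4, 8, 8, device="cuda"))  # compiles on first call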
@@ -340,27 +354,28 @@
         unittest_fn(masks, ref_masks, order_by_area=True, verbose=verbose)
 
     if benchmark:
+        print(f"batch size {batch_size} dog benchmark")
         if batch_size == 1:
-            print("batch size 1 test")
             benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
-            benchmark_fn(image_tensor_to_masks, torch.tensor(image_tensor).transpose(0, 1).numpy(), mask_generator)
         else:
-            print(f"batch size {batch_size} test")
             benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
 
-        print(f"batch size {batch_size} example shapes test")
-        random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in example_shapes()]
-        random_images = random_images[:batch_size]
-        benchmark_fn(image_tensors_to_masks, random_images, mask_generator)
+        for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
+            print(f"batch size {batch_size} example shapes {i} benchmark")
+            random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes]
 
-        print(f"batch size {batch_size} example shapes 2 test")
-        random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in example_shapes_2()]
-        random_images = random_images[:batch_size]
-        benchmark_fn(image_tensors_to_masks, random_images, mask_generator)
+            if batch_size == 1:
+                [benchmark_fn(image_tensor_to_masks, r, mask_generator) for r in random_images]
+            else:
+                random_images = random_images[:batch_size]
+                benchmark_fn(image_tensors_to_masks, random_images, mask_generator)
 
     if profile is not None:
         print(f"Saving profile under {profile}")
-        profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
+        if batch_size == 1:
+            profiler_runner(profile, image_tensor_to_masks, image_tensor, mask_generator)
+        else:
+            profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
 
     if dry:
         return
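The profiling path now mirrors the benchmarking split on batch_size. profiler_runner itself is not shown in this diff; a hypothetical wrapper of that shape built on torch.profiler, assuming the path argument names a Chrome trace file:

import torch

def profiler_runner(path, fn, *args, **kwargs):
    # Trace CPU and CUDA activity for one call and save a trace
    # viewable in chrome://tracing or Perfetto.
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
    ) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result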
@@ -406,7 +421,8 @@ async def upload_image(image: UploadFile = File(...)):
         return StreamingResponse(buf, media_type="image/png")
 
 
-    uvicorn.run(app, host=host, port=port, log_level="info")
+    # uvicorn.run(app, host=host, port=port, log_level="info")
+    uvicorn.run(app, host=host, port=port)
 
 if __name__ == "__main__":
     fire.Fire(main)
