
Commit 00a8d29: SAM2: More experimental data (#1468)
Parent: 9708538
13 files changed: +1521 / -242 lines
Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
from pathlib import Path
from tqdm import tqdm
import json
import fire
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
from datetime import datetime
from server import file_bytes_to_image_tensor
from server import show_anns
from server import model_type_to_paths
from server import MODEL_TYPES_TO_MODEL
from server import masks_to_rle_dict
from server import max_memory_allocated
from io import BytesIO
from torchao._models.sam2.utils.amg import rle_to_mask
from torchao._models.sam2.utils.amg import area_from_rle


def timestamped_print(*args, **kwargs):
    # Get the current timestamp
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    # Prepend the timestamp to the original print arguments
    print(f"[{timestamp}]", *args, **kwargs)


# From https://github.com/pytorch-labs/segment-anything-fast/blob/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/experiments/data.py
# All credit to vkuzo
def _get_center_point(mask):
    """
    This is a rudimentary version of https://arxiv.org/pdf/2304.02643.pdf,
    section D.1, Point Sampling

    From the paper: "The first point is chosen deterministically as the point
    farthest from the object boundary."

    The code below is an approximation of this.

    First, we try to calculate the center of mass. If it's inside the mask, we
    stop here.

    The centroid may be outside of the mask for some mask shapes. In this case
    we do a slow hack, specifically, we check for the
    minimum of the maximum distance from the boundary in four directions
    (up, right, down, left), and take the point with the maximum of these
    minimums. Note: this is not performant for large masks.

    Returns the center point in (x, y) format
    """

    # try the center of mass, keep it if it's inside the mask
    com_y, com_x = ndimage.center_of_mass(mask)
    com_y, com_x = int(round(com_y, 0)), int(round(com_x, 0))
    if mask[com_y][com_x]:
        return (com_x, com_y)

    # if center of mass didn't work, do the slow manual approximation

    # up, right, down, left
    # TODO(future): approximate better by adding more directions
    distances_to_check_deg = [0, 90, 180, 270]

    global_min_max_distance = float('-inf')
    global_coords = None
    # For now, terminate early to speed up the calculation as long as
    # the point sample is good enough. This sacrifices the quality of point
    # sampling for speed. In the future we can make this more accurate.
    DISTANCE_GOOD_ENOUGH_THRESHOLD = 20

    # Note: precalculating the bounding box could be somewhat
    # helpful, but checked the performance gain and it's not much
    # so leaving it out to keep the code simple.
    # Note: tried binary search instead of incrementing by one to
    # travel up/right/left/down, but that does not handle masks
    # with all shapes properly (there could be multiple boundaries).
    for row_idx in range(mask.shape[0]):
        for col_idx in range(mask.shape[1]):
            cur_point = mask[row_idx, col_idx]

            # skip points inside bounding box but outside mask
            if not cur_point:
                continue

            max_distances = []
            for direction in distances_to_check_deg:
                # TODO(future) binary search instead of brute forcing it if we
                # need a speedup, with a cache it doesn't really matter though
                if direction == 0:
                    # UP
                    cur_row_idx = row_idx

                    while cur_row_idx >= 0 and mask[cur_row_idx, col_idx]:
                        cur_row_idx = cur_row_idx - 1
                    cur_row_idx += 1
                    distance = row_idx - cur_row_idx
                    max_distances.append(distance)

                elif direction == 90:
                    # RIGHT
                    cur_col_idx = col_idx

                    while cur_col_idx <= mask.shape[1] - 1 and \
                            mask[row_idx, cur_col_idx]:
                        cur_col_idx += 1
                    cur_col_idx -= 1
                    distance = cur_col_idx - col_idx
                    max_distances.append(distance)

                elif direction == 180:
                    # DOWN
                    cur_row_idx = row_idx
                    while cur_row_idx <= mask.shape[0] - 1 and \
                            mask[cur_row_idx, col_idx]:
                        cur_row_idx = cur_row_idx + 1
                    cur_row_idx -= 1
                    distance = cur_row_idx - row_idx
                    max_distances.append(distance)

                elif direction == 270:
                    # LEFT
                    cur_col_idx = col_idx
                    while cur_col_idx >= 0 and mask[row_idx, cur_col_idx]:
                        cur_col_idx -= 1
                    cur_col_idx += 1
                    distance = col_idx - cur_col_idx
                    max_distances.append(distance)

            min_max_distance = min(max_distances)
            if min_max_distance > global_min_max_distance:
                global_min_max_distance = min_max_distance
                global_coords = (col_idx, row_idx)
            if global_min_max_distance >= DISTANCE_GOOD_ENOUGH_THRESHOLD:
                break

    return global_coords


# TODO: Create prompts
# Get prompts for each mask and a prompt for the largest mask
# Use those prompts as input for data generation

# Create 3 images for each task type
# amg: all masks without center point
# sps: one mask with center point
# mps: multiple masks with center points


def main_docstring():
    return f"""
    Args:
        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
        model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())}
        input_path (str): Path to input image
        output_path (str): Path to output image
    """


def main(
    checkpoint_path,
    model_type,
    input_paths,
    amg_mask_folder,
    output_folder,
    output_format="png",
    verbose=False,
    fast=False,
    furious=False,
    load_fast="",
    overwrite=False,
    store_image=False,
    baseline=False,
):
    # Input path validation
    input_paths = [
        Path(input_path.strip())
        for input_path in Path(input_paths).read_text().splitlines()
    ]
    # We include the parent folder to reduce possible duplicates
    filenames = [
        Path(input_path.parent.name) / Path(input_path.name)
        for input_path in input_paths
    ]
    if len(filenames) != len(set(filenames)):
        raise ValueError("Expected input_paths to have unique filenames.")
    if any(not input_path.is_file() for input_path in input_paths):
        raise ValueError("One of the input paths does not point to a file.")
    if not Path(amg_mask_folder).is_dir():
        raise ValueError(f"Expected {amg_mask_folder} to be a directory.")
    rle_json_paths = [
        Path(amg_mask_folder)
        / Path(filename.parent)
        / Path(filename.stem + "_masks.json")
        for filename in filenames
    ]
    for p in rle_json_paths:
        if not p.exists():
            raise ValueError(
                f"Expected mask {p} to exist."
            )

    # Output path validation
    if not Path(output_folder).is_dir():
        raise ValueError(f"Expected {output_folder} to be a directory.")

    output_image_paths = [
        (Path(output_folder) / filename).with_suffix("." + output_format)
        for filename in filenames
    ]
    if not overwrite and any(p.exists() for p in output_image_paths):
        raise ValueError(
            "Output image path already exists, but --overwrite was not specified."
        )

    output_json_paths = [
        Path(output_folder)
        / Path(filename.parent)
        / Path(filename.stem + "_meta.json")
        for filename in filenames
    ]
    if not overwrite and any(p.exists() for p in output_json_paths):
        raise ValueError(
            "Output json path already exists, but --overwrite was not specified."
        )

    for input_path, filename, output_image_path, rle_json_path, output_json_path in tqdm(
        zip(input_paths, filenames, output_image_paths, rle_json_paths, output_json_paths),
        total=len(input_paths),
    ):
        input_bytes = bytearray(open(input_path, "rb").read())
        image_tensor = file_bytes_to_image_tensor(input_bytes)
        if verbose:
            timestamped_print(f"Loading rle from {rle_json_path}")
        with open(rle_json_path, "r") as file:
            rle_dict = json.load(file)
        masks = {}
        for key in rle_dict:
            masks[key] = {'segmentation': rle_dict[key],
                          'area': area_from_rle(rle_dict[key]),
                          'center_point': _get_center_point(rle_to_mask(rle_dict[key]))}

        if verbose:
            timestamped_print(
                f"Generating mask annotations for input image {filename}."
            )
        plt.figure(
            figsize=(image_tensor.shape[1] / 100.0, image_tensor.shape[0] / 100.0),
            dpi=100,
        )
        plt.imshow(image_tensor)
        # seed for consistent coloring
        # Converts segmentation to binary mask for annotations
        show_anns(list(masks.values()), rle_to_mask, seed=42)
        plt.axis("off")
        plt.tight_layout()

        points = np.array([mask['center_point'] for mask in masks.values()])
        ax = plt.gca()
        marker_size = 375
        ax.scatter(points[:, 0], points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

        buf = BytesIO()
        plt.savefig(buf, format=output_format)
        buf.seek(0)
        output_bytes = buf.getvalue()
        output_image_path.parent.mkdir(parents=False, exist_ok=True)

        if verbose:
            timestamped_print(f"Storing result image under {output_image_path}")
        with open(output_image_path, "wb") as file:
            file.write(output_bytes)

        # Back to RLE representation
        for key in masks:
            masks[key]['segmentation'] = rle_dict[key]

        if verbose:
            timestamped_print(f"Storing meta under {output_json_path}")

        with open(output_json_path, "w") as file:
            file.write(json.dumps(masks, indent=4))

        plt.close()


main.__doc__ = main_docstring()
if __name__ == "__main__":
    fire.Fire(main)
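For context on the _get_center_point fallback above, here is a small sketch (not part of the commit) of a mask whose center of mass lands outside the mask, which is exactly the case where the slow four-direction boundary scan kicks in. The toy mask is hypothetical; only numpy and scipy.ndimage from the imports above are used.

import numpy as np
from scipy import ndimage

# A square mask with a rectangular hole carved out of its middle: the center of
# mass of the remaining True pixels falls inside the hole, i.e. outside the mask.
mask = np.ones((9, 9), dtype=bool)
mask[2:7, 3:7] = False

com_y, com_x = ndimage.center_of_mass(mask)
com_y, com_x = int(round(com_y, 0)), int(round(com_x, 0))
print(mask[com_y][com_x])  # False -> _get_center_point would fall back to the boundary scan

For a solid blob the centroid check succeeds immediately and the scan is skipped entirely, which is why the fast path is tried first.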

examples/sam2_amg_server/compare_rle_lists.py

Lines changed: 54 additions & 14 deletions
@@ -1,4 +1,5 @@
 import fire
+from pathlib import Path
 import torch
 import json
 from torchao._models.sam2.utils.amg import rle_to_mask
@@ -42,34 +43,73 @@ def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
     miou_sum = 0.0
     miou_count = 0.0
     equal_count = 0
-    for ((v0_mask, _), (v1_mask, _)) in zip(v0_masks, v1_masks):
+    for i, ((v0_mask, _), (v1_mask, _)) in enumerate(zip(v0_masks, v1_masks)):
         miou_sum += iou(v0_mask, v1_mask)
         miou_count += 1
         equal_count += torch.equal(v0_mask, v1_mask)
         if verbose:
-            print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")
+            # If sorted we don't map back to the original key
+            # TODO: Could recover the indices for this
+            if order_by_area:
+                print(f"IoU is {iou(v0_mask, v1_mask)}")
+            else:
+                print(f"mask {i} IoU is {iou(v0_mask, v1_mask)}")

-    return miou_sum / miou_count, equal_count
+    return float((miou_sum / miou_count).item()), equal_count


-def main(path0, path1, strict=False):
+def compare_masks_str(str0, str1, strict):
+    masks0 = json.loads(str0)
+    masks1 = json.loads(str1)
+    if masks0.keys() != masks1.keys():
+        if strict:
+            return None, None, True
+
+    # TODO: We might not want to order_by_area when comparing
+    # masks from specific input points.
+    m, e = compare_masks(masks0, masks1, order_by_area=True)
+    return m, e, False
+
+
+def compare(path0, path1, strict=False, compare_folders=False):
     # path0 are candidates and path1 the ground truth
     fail_count = 0
     miou_sum = 0.0
     miou_count = 0
-    with open(path0, 'r') as f0, open(path1, 'r') as f1:
-        for line0, line1 in zip(f0, f1):
-            masks0 = json.loads(line0)
-            masks1 = json.loads(line1)
-            if masks0.keys() != masks1.keys():
-                if strict:
+    if compare_folders:
+        path0, path1 = Path(path0), Path(path1)
+        assert path0.is_dir()
+        assert path1.is_dir()
+        mask_files0 = [f.relative_to(path0) for f in list(path0.rglob('*.json'))]
+        mask_files1 = [f.relative_to(path1) for f in list(path1.rglob('*.json'))]
+        assert all(m0 == m1 for (m0, m1) in zip(mask_files0, mask_files1))
+        for (m0, m1) in zip(mask_files0, mask_files1):
+            with open(path0 / m0, 'r') as f0, open(path1 / m1, 'r') as f1:
+                m, e, fail = compare_masks_str(f0.read(), f1.read(), strict)
+                if fail:
                     fail_count += 1
-                    continue
+                else:
+                    miou_sum += m
+                    miou_count += 1
+
+    else:
+        with open(path0, 'r') as f0, open(path1, 'r') as f1:
+            for line0, line1 in zip(f0, f1):
+                m, e, fail = compare_masks_str(line0, line1, strict)
+                if fail:
+                    fail_count += 1
+                else:
+                    miou_sum += m
+                    miou_count += 1
+
+    return miou_count, miou_sum, fail_count

-            m, e = compare_masks(masks0, masks1, order_by_area=True)
-            miou_sum += m
-            miou_count += 1

+def main(path0, path1, strict=False, compare_folders=False):
+    miou_count, miou_sum, fail_count = compare(path0,
+                                               path1,
+                                               strict=strict,
+                                               compare_folders=compare_folders)
     print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")
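As a usage sketch for the new compare_folders path in compare_rle_lists.py (not part of the commit; the import path and the two folder names are assumptions for illustration):

# Hypothetical example. Assumes compare_rle_lists.py is importable from the
# current directory; the folder names below are placeholders.
from compare_rle_lists import compare

# Walk both folders, pair up the *.json files by relative path, and accumulate
# mIoU over all masks; mismatched keys only count as failures when strict=True.
miou_count, miou_sum, fail_count = compare(
    "amg_baseline_output/",   # candidate masks
    "amg_reference_output/",  # ground-truth masks
    strict=False,
    compare_folders=True,
)
print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")

If the script also exposes main via fire (the bottom of the file is not shown in this diff), the same comparison can presumably be driven from the command line with the new --compare_folders flag.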