Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions sam3/model/position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(

def _encode_xy(self, x, y):
# The positions are expected to be normalized
assert len(x) == len(y) and x.ndim == y.ndim == 1
# torch._check(len(x) == len(y) and x.ndim == y.ndim == 1)
x_embed = x * self.scale
y_embed = y * self.scale

Expand All @@ -62,12 +62,8 @@ def _encode_xy(self, x, y):

pos_x = x_embed[:, None] / dim_t
pos_y = y_embed[:, None] / dim_t
pos_x = torch.stack(
(pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
).flatten(1)
pos_y = torch.stack(
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
).flatten(1)
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
return pos_x, pos_y

@torch.no_grad()
Expand Down
97 changes: 31 additions & 66 deletions scripts/benchmark_sam3_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,31 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor:

def _make_inputs(model, image: torch.Tensor, prompts):
device = image.device
num_prompts = len(prompts)
num_images = int(image.shape[0])

tokenizer = model.backbone.language_backbone.tokenizer
token_ids = tokenizer(prompts, context_length=32).to(device)

img_ids = torch.arange(num_images, device=device, dtype=torch.long)
img_ids = img_ids.repeat_interleave(num_prompts)
text_ids = torch.arange(num_prompts, device=device, dtype=torch.long)
text_ids = text_ids.repeat(num_images)

box_embeddings = torch.zeros(1, num_prompts, 4, device=device)
box_mask = torch.zeros(num_prompts, 1, device=device, dtype=torch.bool)
box_labels = torch.zeros(1, num_prompts, device=device, dtype=torch.long)

return (
image,
token_ids,
img_ids,
text_ids,
box_embeddings,
box_mask,
box_labels,
)


def _run_full_model(model, inputs):
(
images,
token_ids,
img_ids,
text_ids,
box_embeddings,
box_mask,
box_labels,
) = inputs
images, token_ids = inputs
num_images = images.shape[0]
num_prompts = token_ids.shape[0]
device = images.device
bs = num_images * num_prompts

img_ids = torch.arange(num_images, device=device, dtype=torch.long)
img_ids = img_ids.repeat_interleave(num_prompts)
text_ids = torch.arange(num_prompts, device=device, dtype=torch.long)
text_ids = text_ids.repeat(num_images)

box_embeddings = torch.zeros(1, bs, 4, device=device)
box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool)
box_labels = torch.zeros(1, bs, device=device, dtype=torch.long)
backbone_out = model.backbone.forward_image(images)
text_encoder = model.backbone.language_backbone
_, text_tokens = text_encoder.encoder(token_ids)
Expand Down Expand Up @@ -113,15 +102,20 @@ def _make_decoder_only_inputs_from_model(
text_attention_mask,
inputs,
):
(
images,
token_ids,
img_ids,
text_ids,
box_embeddings,
box_mask,
box_labels,
) = inputs
images, token_ids = inputs
num_images = images.shape[0]
num_prompts = token_ids.shape[0]
device = images.device
bs = num_images * num_prompts

img_ids = torch.arange(num_images, device=device, dtype=torch.long)
img_ids = img_ids.repeat_interleave(num_prompts)
text_ids = torch.arange(num_prompts, device=device, dtype=torch.long)
text_ids = text_ids.repeat(num_images)

box_embeddings = torch.zeros(1, bs, 4, device=device)
box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool)
box_labels = torch.zeros(1, bs, device=device, dtype=torch.long)
backbone_out = {
"backbone_fpn": backbone_fpn,
"vision_pos_enc": vision_pos_enc,
Expand All @@ -147,9 +141,7 @@ def _make_decoder_only_inputs_from_model(
prompt, prompt_mask, backbone_out = model._encode_prompt(
backbone_out, find_input, geometric_prompt
)
backbone_out, encoder_out, _ = model._run_encoder(
backbone_out, find_input, prompt, prompt_mask
)
backbone_out, encoder_out, _ = model._run_encoder(backbone_out, find_input, prompt, prompt_mask)
return (
backbone_out["backbone_fpn"],
img_ids,
Expand Down Expand Up @@ -286,32 +278,7 @@ def encoder_from_outputs(image_out, text_out):
def encoder_fn():
encoder_from_outputs(cached_image_out, cached_text_out)

(
pipeline_images,
pipeline_token_ids,
pipeline_img_ids,
pipeline_text_ids,
pipeline_box_embeddings,
pipeline_box_mask,
pipeline_box_labels,
) = inputs
if pipeline_token_ids.shape[0] < 2:
repeat = 2 // pipeline_token_ids.shape[0]
pipeline_token_ids = pipeline_token_ids.repeat(repeat, 1)
pipeline_img_ids = pipeline_img_ids.repeat(repeat)
pipeline_text_ids = pipeline_text_ids.repeat(repeat)
pipeline_box_embeddings = pipeline_box_embeddings.repeat(1, repeat, 1)
pipeline_box_mask = pipeline_box_mask.repeat(repeat, 1)
pipeline_box_labels = pipeline_box_labels.repeat(1, repeat)
pipeline_inputs = (
pipeline_images,
pipeline_token_ids,
pipeline_img_ids,
pipeline_text_ids,
pipeline_box_embeddings,
pipeline_box_mask,
pipeline_box_labels,
)
pipeline_inputs = inputs

def pipeline_fn():
pipeline_module(*pipeline_inputs)
Expand Down Expand Up @@ -346,9 +313,7 @@ def pipeline_fn():
decoder_prompt = decoder_prompt.repeat(1, repeat, 1)
decoder_prompt_mask = decoder_prompt_mask.repeat(repeat, 1)
decoder_valid_ratios = decoder_valid_ratios.repeat(repeat, 1, 1)
decoder_backbone_fpn = [
feat.repeat(repeat, 1, 1, 1) for feat in decoder_backbone_fpn
]
decoder_backbone_fpn = [feat.repeat(repeat, 1, 1, 1) for feat in decoder_backbone_fpn]
decoder_only_inputs = (
decoder_backbone_fpn,
decoder_img_ids,
Expand Down
6 changes: 2 additions & 4 deletions scripts/benchmark_sam3_export_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main() -> None:
model.eval()

image = _prepare_image(_load_image(args.image, device), size=1008)
inputs = _make_inputs(1, 1008, 1008, str(device), num_boxes=1)
inputs = _make_inputs(1, 1008, 1008, str(device))

decoder_inputs = None
decoder_inputs_error = None
Expand Down Expand Up @@ -120,9 +120,7 @@ def export_encoder_fusion():
EncoderFusionWrapper(model.transformer.encoder).to(img_feats.device).eval()
)
if args.num_feature_levels != 1:
raise RuntimeError(
"encoder_fusion export currently expects num_feature_levels=1"
)
raise RuntimeError("encoder_fusion export currently expects num_feature_levels=1")
torch.export.export(
encoder_wrapper,
(img_feats, img_pos, img_mask, text_memory, text_attention_mask),
Expand Down
17 changes: 17 additions & 0 deletions scripts/compile_sam3_aoti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Compile the exported SAM3 full pipeline into an AOTInductor package.

Loads the ``torch.export`` artifact produced by the export script and
packages it as an ahead-of-time compiled ``.pt2`` under ``artifacts/aoti/``.
"""
import argparse  # NOTE(review): imported but unused in this script — confirm or drop.
import torch
import torchvision.ops # noqa: F401
from torch._inductor import config as inductor_config


# Force attention onto the math SDP backend for the AOTI build by disabling
# the flash and memory-efficient kernels.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# NOTE(review): presumably split-reduction kernels caused problems for this
# export path — confirm why this Inductor option must be off.
inductor_config.split_reductions = False

# Load the exported program and AOT-compile + package it in one step.
# Paths are hard-coded to the repo's artifact layout.
exported = torch.export.load("artifacts/export/full_sam3_pipeline.pt2")
torch._inductor.aoti_compile_and_package(
    exported,
    package_path="artifacts/aoti/full_sam3_pipeline_aoti.pt2",
)
24 changes: 2 additions & 22 deletions scripts/export_sam3_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,13 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor:

def _make_inputs(model, image: torch.Tensor, prompts):
device = image.device
num_prompts = len(prompts)
num_images = int(image.shape[0])

tokenizer = model.backbone.language_backbone.tokenizer
token_ids = tokenizer(prompts, context_length=32).to(device)

img_ids = torch.arange(num_images, device=device, dtype=torch.long)
img_ids = img_ids.repeat_interleave(num_prompts)
text_ids = torch.arange(num_prompts, device=device, dtype=torch.long)
text_ids = text_ids.repeat(num_images)

box_embeddings = torch.zeros(1, num_prompts, 4, device=device)
box_mask = torch.zeros(num_prompts, 1, device=device, dtype=torch.bool)
box_labels = torch.zeros(1, num_prompts, device=device, dtype=torch.long)

return (
image,
token_ids,
img_ids,
text_ids,
box_embeddings,
box_mask,
box_labels,
)


Expand Down Expand Up @@ -101,9 +85,7 @@ def main() -> None:
if not prompts:
raise ValueError("Provide at least one prompt")

model = build_sam3_image_model(
device=args.device, eval_mode=True, enable_segmentation=True
)
model = build_sam3_image_model(device=args.device, eval_mode=True, enable_segmentation=True)
model.eval()

image = _load_image(args.image, torch.device(args.device))
Expand Down Expand Up @@ -139,9 +121,7 @@ def main() -> None:
img_pos = img_pos.repeat(prompt_batch, 1, 1, 1)
img_mask = img_mask.repeat(prompt_batch, 1, 1)

encoder_wrapper = (
EncoderFusionWrapper(model.transformer.encoder).to(img_feats.device).eval()
)
encoder_wrapper = EncoderFusionWrapper(model.transformer.encoder).to(img_feats.device).eval()
encoder = torch.export.export(
encoder_wrapper,
(img_feats, img_pos, img_mask, prompt, prompt_mask),
Expand Down
68 changes: 20 additions & 48 deletions scripts/export_sam3_full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from PIL import Image


REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

Expand All @@ -25,13 +26,22 @@ def forward(
self,
images: torch.Tensor,
token_ids: torch.Tensor,
img_ids: torch.Tensor,
text_ids: torch.Tensor,
box_embeddings: torch.Tensor,
box_mask: torch.Tensor,
box_labels: torch.Tensor,
):
model = cast(Any, self.model)
num_images = images.shape[0]
num_prompts = token_ids.shape[0]
device = images.device
bs = num_images * num_prompts

img_ids = torch.arange(num_images, device=device, dtype=torch.long)
img_ids = img_ids.repeat_interleave(num_prompts)
text_ids = torch.arange(num_prompts, device=device, dtype=torch.long)
text_ids = text_ids.repeat(num_images)

box_embeddings = torch.zeros(1, bs, 4, device=device)
box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool)
box_labels = torch.zeros(1, bs, device=device, dtype=torch.long)

backbone_out = model.backbone.forward_image(images)
text_encoder = model.backbone.language_backbone
_, text_tokens = text_encoder.encoder(token_ids)
Expand All @@ -48,12 +58,8 @@ def forward(
input_boxes=box_embeddings,
input_boxes_mask=box_mask,
input_boxes_label=box_labels,
input_points=torch.zeros(
0, int(token_ids.shape[0]), 2, device=images.device
),
input_points_mask=torch.zeros(
int(token_ids.shape[0]), 0, device=images.device, dtype=torch.bool
),
input_points=torch.zeros(0, bs, 2, device=device),
input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool),
)
geometric_prompt = Prompt(
box_embeddings=box_embeddings,
Expand Down Expand Up @@ -92,23 +98,10 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor:

def _make_inputs(model, image: torch.Tensor, prompts):
device = image.device
num_prompts = len(prompts)
num_images = int(image.shape[0])
token_ids = model.backbone.language_backbone.tokenizer(
prompts, context_length=32
).to(device)
img_ids = torch.arange(num_images, device=device, dtype=torch.long)
img_ids = img_ids.repeat_interleave(num_prompts)
text_ids = torch.arange(num_prompts, device=device, dtype=torch.long)
text_ids = text_ids.repeat(num_images)
token_ids = model.backbone.language_backbone.tokenizer(prompts, context_length=32).to(device)
return (
image,
token_ids,
img_ids,
text_ids,
torch.zeros(1, num_prompts, 4, device=device),
torch.zeros(num_prompts, 1, device=device, dtype=torch.bool),
torch.zeros(1, num_prompts, device=device, dtype=torch.long),
)


Expand Down Expand Up @@ -157,21 +150,14 @@ def main() -> None:
)
model.eval()

image = _prepare_image(
_load_image(args.image, torch.device(args.device)), size=1008
)
image = _prepare_image(_load_image(args.image, torch.device(args.device)), size=1008)
inputs = _make_inputs(model, image, prompts)
wrapper = FullSam3PipelineWrapper(model).to(image.device).eval()
if image.shape[0] < 2:
repeat = 2 // image.shape[0]
export_inputs = (
image.repeat(repeat, 1, 1, 1),
inputs[1].repeat(repeat, 1),
inputs[2].repeat(repeat),
inputs[3].repeat(repeat),
inputs[4].repeat(1, repeat, 1),
inputs[5].repeat(repeat, 1),
inputs[6].repeat(1, repeat),
)
else:
export_inputs = inputs
Expand All @@ -186,23 +172,9 @@ def main() -> None:
3: 1008,
},
"token_ids": {
0: torch.export.Dim.AUTO,
0: torch.export.Dim("num_prompts", min=1),
1: 32,
},
"img_ids": {0: torch.export.Dim.AUTO},
"text_ids": {0: torch.export.Dim.AUTO},
"box_embeddings": {
0: 1,
1: torch.export.Dim.AUTO,
},
"box_mask": {
0: torch.export.Dim.AUTO,
1: 1,
},
"box_labels": {
0: 1,
1: torch.export.Dim.AUTO,
},
},
strict=False,
prefer_deferred_runtime_asserts_over_guards=True,
Expand Down
Loading