Commit

MobileSAM

continue-revolution committed Jun 28, 2023
1 parent 780fc49 commit 11f2bf5
Showing 4 changed files with 685 additions and 13 deletions.
README.md: 5 changes (4 additions & 1 deletion)

@@ -19,7 +19,10 @@ This extension aims to connect [AUTOMATIC1111 Stable Diffusion WebUI](https:/
- `2023/05/29`: [v1.4.2](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.2) You may now run SAM inference on CPU by checking "Use CPU for SAM". This is for some Mac users who are unable to run SAM inference on GPU. I discourage other users from enabling this feature because it is significantly slower than CUDA.
- `2023/06/01`: [v1.5.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.0) You may now choose to use local GroundingDINO to bypass the C++ problem. See [FAQ](#faq) question 1 for more details.
- `2023/06/04`: [v1.5.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.1) `Upload Mask to ControlNet Inpainting` comes back in response to [ControlNet inpaint improvement](https://github.com/Mikubill/sd-webui-controlnet/discussions/1464). You should see a new tab beside `AutoSAM` after updating the extension. This feature will be removed again once the ControlNet extension has its own upload feature.
- `2023/06/13`: [v1.6.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.6.0) [SAM-HQ](https://github.com/SysCV/sam-hq) supported by [@SpenserCai](https://github.com/SpenserCai) and me. This is an "upgraded" SAM, created by researchers at ETH Zurich & HKUST. However, I cannot guarantee which one is better and you should make your own choice based on your own experiments. Go to [Installation](#installation) to get the link to the models.
- `2023/06/29`: [v1.6.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.6.1) [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) supported. This is a tiny version of SAM, created by researchers at Kyung Hee University. Visit [here](https://github.com/continue-revolution/sd-webui-segment-anything/issues/139) for more information.

Note that support for some other variations of SAM, such as [Matting-Anything](https://github.com/SHI-Labs/Matting-Anything) and [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), is still on the way. Unlike MobileSAM, supporting these models is non-trivial, especially FastSAM, which uses a completely different pipeline (ultralytics/YOLO). Introducing these new works would make the already ugly codebase even uglier, so they will be supported once I finish a major refactor of the current codebase.

## FAQ

sam_hq/build_sam_hq.py: 73 changes (62 additions & 11 deletions)

@@ -10,7 +10,8 @@

from .modeling.mask_decoder_hq import MaskDecoderHQ
from .modeling.image_encoder import ImageEncoderViTHQ
-from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer
from .modeling.tiny_vit import TinyViT
from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer, MaskDecoder
from segment_anything import build_sam_vit_h, build_sam_vit_l, build_sam_vit_b


@@ -44,16 +45,32 @@ def build_sam_hq_vit_b(checkpoint=None):
    )


def build_mobile_sam(checkpoint=None):
    return _build_mobile_sam(checkpoint)


sam_model_registry = {
    "sam_vit_h": build_sam_vit_h,
    "sam_vit_l": build_sam_vit_l,
    "sam_vit_b": build_sam_vit_b,
    "sam_hq_vit_h": build_sam_hq_vit_h,
    "sam_hq_vit_l": build_sam_hq_vit_l,
    "sam_hq_vit_b": build_sam_hq_vit_b,
    "mobile_sam": build_mobile_sam,
}


def _load_sam_checkpoint(sam: Sam, checkpoint=None):
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        # strict=False: HQ- and Mobile-specific checkpoints need not cover every key
        info = sam.load_state_dict(state_dict, strict=False)
        print(info)
        # the extension only runs inference, so freeze every parameter
        for _, p in sam.named_parameters():
            p.requires_grad = False
    return sam

def _build_sam_hq(
    encoder_embed_dim,
    encoder_depth,
@@ -102,14 +119,48 @@ def _build_sam_hq(
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
-    sam.eval()
-    if checkpoint is not None:
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
-        info = sam.load_state_dict(state_dict, strict=False)
-        print(info)
-        for n, p in sam.named_parameters():
-            if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
-                p.requires_grad = False
-    return sam
    return _load_sam_checkpoint(sam, checkpoint)

def _build_mobile_sam(checkpoint=None):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size  # 64x64 feature map
    mobile_sam = Sam(
        # MobileSAM swaps SAM's heavy ViT image encoder for a distilled TinyViT
        image_encoder=TinyViT(
            img_size=1024, in_chans=3, num_classes=1000,
            embed_dims=[64, 128, 160, 320],
            depths=[2, 2, 6, 2],
            num_heads=[2, 4, 5, 10],
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.,
            drop_rate=0.,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8
        ),
        # prompt encoder and mask decoder are identical to the original SAM
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return _load_sam_checkpoint(mobile_sam, checkpoint)
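
For orientation, here is a minimal sketch (not part of this commit) of how the new `mobile_sam` registry entry could be driven outside the WebUI. The checkpoint path and the test image are placeholder assumptions; the upstream MobileSAM repository ships its weights as `mobile_sam.pt`.

```python
import numpy as np
from segment_anything import SamPredictor
from sam_hq.build_sam_hq import sam_model_registry

# build MobileSAM and load its weights (path is an assumption)
sam = sam_model_registry["mobile_sam"](checkpoint="models/sam/mobile_sam.pt")
sam.to("cpu")  # MobileSAM is small enough that CPU inference is feasible

predictor = SamPredictor(sam)
image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real RGB image
predictor.set_image(image)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[256, 256]]),  # one positive point prompt
    point_labels=np.array([1]),
)
print(masks.shape, scores)  # (3, 512, 512) masks plus their IoU predictions
```

Because `build_mobile_sam` returns a standard `Sam` module, downstream consumers such as `SamPredictor` work unchanged; only the image encoder differs.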