Skip to content

Commit

Permalink
support SAM-HQ
Browse files Browse the repository at this point in the history
  • Loading branch information
SpenserCai committed Jun 12, 2023
1 parent 6197446 commit e076da1
Show file tree
Hide file tree
Showing 19 changed files with 2,995 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ Choose one or more of the models below and put them to `${sd-webui}/models/sam`

Three types of SAM models are available. [vit_h](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) is 2.56GB, [vit_l](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) is 1.25GB, [vit_b](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) is 375MB. I myself tested vit_h on NVIDIA 3090 Ti which is good. If you encounter VRAM problem, you should switch to smaller models.

If you want use SAM-HQ,[hq_vit_h](https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing),[hq_vit_l](https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing),[hq_vit_b](https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing)

GroundingDINO packages, GroundingDINO models and ControlNet annotator models will be automatically installed the first time you use them.

## GroundingDINO
Expand Down
20 changes: 18 additions & 2 deletions scripts/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing
from modules.devices import device, torch_gc, cpu
from modules.paths import models_path
from segment_anything import SamPredictor, sam_model_registry
from segment_anything import SamPredictor as SamPredictorBase, sam_model_registry
from segment_anything_hq import SamPredictor as SamPredictorHQ, sam_model_registry as sam_model_registry_hq
from scripts.dino import dino_model_list, dino_predict_internal, show_boxes, clear_dino_cache, dino_install_issue_text
from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image
from scripts.process_params import SAMProcessUnit, max_cn_num
Expand All @@ -29,6 +30,8 @@
sam_model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']
sam_device = device

is_hq = False


txt2img_width: gr.Slider = None
txt2img_height: gr.Slider = None
Expand All @@ -54,6 +57,12 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5):
image[mask] = image[mask] * (1 - alpha) + 255 * color.reshape(1, 1, -1) * alpha
return image.astype(np.uint8)

def SamPredictor(sam_model):
if is_hq:
return SamPredictorHQ(sam_model)
else:
return SamPredictorBase(sam_model)


def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
print("Dilation Amount: ", dilation_amt)
Expand All @@ -71,10 +80,17 @@ def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):


def load_sam_model(sam_checkpoint):
global is_hq
model_type = '_'.join(sam_checkpoint.split('_')[1:-1])
sam_checkpoint = os.path.join(sam_model_dir, sam_checkpoint)
torch.load = unsafe_torch_load
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# 如果包含hq,则使用hq版本的sam
if 'hq' in sam_checkpoint:
sam = sam_model_registry_hq[model_type.replace("hq_","")](checkpoint=sam_checkpoint)
is_hq = True
else:
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
is_hq = False
sam.to(device=sam_device)
sam.eval()
torch.load = load
Expand Down
16 changes: 16 additions & 0 deletions segment_anything_hq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_sam import (
build_sam,
build_sam_vit_h,
build_sam_vit_l,
build_sam_vit_b,
sam_model_registry,
)
from .build_sam_baseline import sam_model_registry_baseline
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
Loading

0 comments on commit e076da1

Please sign in to comment.