local groundingdino
continue-revolution committed Jun 1, 2023
1 parent e20d3c6 commit df2eee1
Showing 26 changed files with 2,840 additions and 42 deletions.
27 changes: 16 additions & 11 deletions README.md
@@ -4,20 +4,21 @@ This extension aims to connect [AUTOMATIC1111 Stable Diffusion WebUI](https:/

## News

- `2023/04/10`: [v1.0.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.0.0) SAM extension released! You can click on the image to generate segmentation masks.
- `2023/04/12`: [v1.0.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.0.1) Mask expansion released by [@jordan-barrett-jm](https://github.com/jordan-barrett-jm)! You can expand masks to overcome edge problems of SAM.
- `2023/04/15`: [v1.1.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.1.0) [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) support released! You can enter text prompts to generate bounding boxes and segmentation masks.
- `2023/04/15`: [v1.2.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.2.0) API support released by [@jordan-barrett-jm](https://github.com/jordan-barrett-jm)!
- `2023/04/18`: [v1.3.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.3.0) [ControlNet V1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly) inpainting support released! You can copy SAM-generated masks to ControlNet to do inpainting. Note that you **must** update the [ControlNet extension](https://github.com/Mikubill/sd-webui-controlnet) to use it. ControlNet inpainting performs far better than general-purpose inpainting models, and you no longer need to download inpainting-specific models.
- `2023/04/24`: [v1.4.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.0) Automatic segmentation support released! Functionalities with * require you to have [ControlNet extension](https://github.com/Mikubill/sd-webui-controlnet) installed. Last commit: `724b4db`. This update includes support for
- *[ControlNet V1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly) semantic segmentation
- [EditAnything](https://github.com/sail-sg/EditAnything) un-semantic segmentation
- Image layout generation (single image + batch process)
- *Image masking with categories (single image + batch process)
- *Inpaint not masked for ControlNet inpainting on txt2img panel
- `2023/04/29`: [v1.4.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.1) The API has been completely refactored. You can access all features for **single image process** through the API. API documentation has been moved to the [wiki](https://github.com/continue-revolution/sd-webui-segment-anything/wiki/API).
- `2023/05/22`: [v1.4.2](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.2) [EditAnything](https://github.com/sail-sg/EditAnything) is ready to use! You can generate random segmentation and copy the output to EditAnything ControlNet.
- `2023/05/29`: [v1.4.3](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.3) You may now run SAM inference on CPU by checking "Use CPU for SAM". This is for some Mac users who are unable to run SAM inference on GPU. I discourage other users from using this feature because it is significantly slower than CUDA. Last commit: `89a2213`.
- `2023/06/01`: [v1.5.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.0) You may now choose to use local GroundingDINO to bypass the C++ problem. See [FAQ](#faq)-1 for more details.

## TODO

@@ -36,7 +37,11 @@ There are already at least two great tutorials on how to use this extension. Check

You should know the following before submitting an issue.

1. Due to the overwhelming complaints about GroundingDINO installation and the lack of a comparable high-performance text-to-bounding-box library, I have decided to modify the source code of GroundingDINO and push it to this repository. Starting from [v1.5.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.0), you can choose to use local GroundingDINO by checking `Use local groundingdino to bypass C++ problem` on `Settings/Segment Anything`. This change should resolve all problems with ninja, pycocotools, `_C`, and anything else related to C++/CUDA compilation.

   If you did not enable the setting described above, this extension will first try to install GroundingDINO and check whether your environment has successfully built the C++ dynamic library (the annoying `_C`). If so, it will use the official implementation of GroundingDINO, out of respect for its authors. If the installation fails, it will fall back to local GroundingDINO instead (see the sketch after this list).

   If you would still like to fix the GroundingDINO installation itself, here are some common problems I have observed for Windows users:
   - `pycocotools`: [here](https://github.com/cocodataset/cocoapi/issues/415#issuecomment-627313816).
   - `_C`: [here](https://github.com/continue-revolution/sd-webui-segment-anything/issues/32#issuecomment-1513873296). DO NOT skip steps.
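
A minimal sketch of the fallback logic described above, with a hypothetical helper name (this is not the extension's actual code):

```python
# Hedged sketch: prefer the officially installed GroundingDINO and fall back
# to the vendored local copy when the C++ extension (_C) is unavailable.
def load_groundingdino():
    try:
        from groundingdino import _C  # noqa: F401 -- fails if the C++ build is broken
        import groundingdino
        return groundingdino  # official implementation, respecting the original authors
    except ImportError:
        import local_groundingdino  # pure-Python copy shipped with this extension
        return local_groundingdino
```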

@@ -80,7 +85,7 @@ GroundingDINO has been supported in this extension. It has the following functionalities:

However, there are some existing problems with GroundingDINO:
1. GroundingDINO will be installed the first time you use GroundingDINO features, not when you launch the WebUI. Make sure that your terminal has access to GitHub; otherwise you will have to install GroundingDINO manually. GroundingDINO models will be automatically downloaded from [HuggingFace](https://huggingface.co/ShilongLiu/GroundingDINO/tree/main). If your terminal cannot reach HuggingFace, please download the model manually and put it under `${sd-webui-sam}/models/grounding-dino` (see the sketch after this list).
2. **If you want to use local groundingdino to bypass ALL the painful C++/CUDA/ninja/pycocotools problems, please read [FAQ](#faq)-1.** GroundingDINO requires your device to compile C++, which might take a long time and throw tons of exceptions. If you encounter the `_C` problem, it is most probably because you did not install CUDA Toolkit. Follow the steps described [here](https://github.com/continue-revolution/sd-webui-segment-anything/issues/32#issuecomment-1513873296). DO NOT skip steps. Otherwise, please go to the [Grounded-SAM Issue Page](https://github.com/IDEA-Research/Grounded-Segment-Anything/issues) and submit an issue there. Despite this, you can still use this extension for point prompts->segmentation masks even if you cannot install GroundingDINO, don't worry.
3. If you want to use point prompts, SAM can accept at most one bounding box. This extension will check whether there are multiple bounding boxes; if there are, it will discard all point prompts. Otherwise, all point prompts will take effect. You may always select the single bounding box you want.
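
If you need to download the model manually, a minimal sketch (the checkpoint file name is an assumption based on the HuggingFace repository above; adjust the destination to your actual `${sd-webui-sam}` directory):

```python
import os
import urllib.request

# Assumed checkpoint name from https://huggingface.co/ShilongLiu/GroundingDINO
url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
dest_dir = os.path.join("extensions", "sd-webui-segment-anything", "models", "grounding-dino")
os.makedirs(dest_dir, exist_ok=True)
urllib.request.urlretrieve(url, os.path.join(dest_dir, os.path.basename(url)))
```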

For more detail, check [How to Use](#how-to-use) and [Demo](#demo).
311 changes: 311 additions & 0 deletions local_groundingdino/datasets/transforms.py
@@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import os
import random

import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

from local_groundingdino.util.box_ops import box_xyxy_to_cxcywh
from local_groundingdino.util.misc import interpolate


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd", "positive_map"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")

    # remove elements whose boxes or masks have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)

        for field in fields:
            if field in target:
                target[field] = target[field][keep]

    if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
        # for debug and visualization only.
        if "strings_positive" in target:
            target["strings_positive"] = [
                _i for _i, _j in zip(target["strings_positive"], keep) if _j
            ]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
            [w, 0, w, 0]
        )
        target["boxes"] = boxes

    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height]
        )
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
        )

    return rescaled_image, target
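
As a quick sanity check on the aspect-ratio logic above, a small illustrative example (the input size is made up):

```python
from PIL import Image
from local_groundingdino.datasets.transforms import resize

# A 1000x500 (w, h) image with size=800, max_size=1333:
# 1000/500 * 800 = 1600 > 1333, so the short side is clamped to
# int(round(1333 * 500 / 1000)) = 666, and the long side scales to 1332.
img = Image.new("RGB", (1000, 500))
out, _ = resize(img, None, 800, max_size=1333)
print(out.size)  # (1332, 666)
```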


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
    return padded_image, target


class ResizeDebug(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        return resize(img, target, self.size)


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)


class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
        # respect_boxes: True to keep all boxes;
        # False to tolerate boxes being filtered out by the crop
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes

    def __call__(self, img: PIL.Image.Image, target: dict):
        init_boxes = len(target["boxes"])
        max_patience = 10
        for i in range(max_patience):
            w = random.randint(self.min_size, min(img.width, self.max_size))
            h = random.randint(self.min_size, min(img.height, self.max_size))
            region = T.RandomCrop.get_params(img, [h, w])
            result_img, result_target = crop(img, target, region)
            if (
                not self.respect_boxes
                or len(result_target["boxes"]) == init_boxes
                or i == max_patience - 1
            ):
                return result_img, result_target
        return result_img, result_target


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.0))
        crop_left = int(round((image_width - crop_width) / 2.0))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)


class RandomPad(object):
    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))


class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)


class ToTensor(object):
    def __call__(self, img, target):
        return F.to_tensor(img), target


class RandomErasing(object):
    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
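
For orientation, here is a minimal sketch of how these transforms are typically composed for inference; the sizes and normalization constants below are the common DETR/GroundingDINO defaults and are an assumption, not something fixed by this file:

```python
import local_groundingdino.datasets.transforms as T
from PIL import Image

# Assumed inference pipeline: resize the short side to 800 (long side capped
# at 1333), convert to a tensor, then normalize with ImageNet statistics.
transform = T.Compose(
    [
        T.RandomResize([800], max_size=1333),  # one size only, so "random" is deterministic
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

image = Image.open("example.jpg").convert("RGB")  # hypothetical input path
image_tensor, _ = transform(image, None)  # no annotation target needed at inference
```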
18 changes: 18 additions & 0 deletions local_groundingdino/models/__init__.py
@@ -0,0 +1,18 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .GroundingDINO import build_groundingdino


def build_model(args):
    # we use a registry to maintain models from catdet6 on.
    from .registry import MODULE_BUILD_FUNCS

    assert args.modelname in MODULE_BUILD_FUNCS._module_dict
    build_func = MODULE_BUILD_FUNCS.get(args.modelname)
    model = build_func(args)
    return model
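
The registry module itself is not part of this excerpt; a minimal sketch of the pattern `build_model` relies on could look like this (names other than `MODULE_BUILD_FUNCS`, `_module_dict`, and `get` are illustrative assumptions):

```python
# Illustrative registry sketch -- not the actual local_groundingdino code.
class Registry:
    def __init__(self):
        self._module_dict = {}

    def register_with_name(self, name):  # hypothetical decorator name
        def decorator(build_func):
            self._module_dict[name] = build_func
            return build_func
        return decorator

    def get(self, name):
        return self._module_dict[name]

MODULE_BUILD_FUNCS = Registry()

@MODULE_BUILD_FUNCS.register_with_name("groundingdino")  # assumed registered name
def build_groundingdino(args):
    ...  # construct and return the model from the parsed config

# build_model(args) then simply dispatches on args.modelname:
#   build_func = MODULE_BUILD_FUNCS.get(args.modelname)
#   model = build_func(args)
```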