## Grounded-Fast-SAM

Combining [Grounding-DINO](https://github.com/IDEA-Research/GroundingDINO) and [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM) for faster zero-shot detection and segmentation of anything.

### Contents
- [Installation](#installation)

### Installation

- Install [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything#installation)
- Install [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM#installation)
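Once both dependencies are installed, a run of the demo looks roughly like the sketch below. The script name (`grounded_fast_sam.py`) and checkpoint locations are assumptions here; point `--model_path` at your Fast-SAM checkpoint and make sure the GroundingDINO config and weights referenced in the script are in place.

```bash
# Hypothetical script name and paths -- adjust to your checkout and checkpoints.
python grounded_fast_sam.py \
  --model_path "./FastSAM-x.pt" \
  --img_path "./images/dogs.jpg" \
  --text "the black dog." \
  --output "./output/"
```

For each detected phrase, the demo writes one annotated image to `--output`, named `<image-basename>_<box-index>_caption_<phrase>.jpg`.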
"""Grounded-Fast-SAM demo: detect objects matching a text prompt with
GroundingDINO, then segment each detected box with Fast-SAM."""
import argparse
import ast
import os

import cv2
import numpy as np
import torch
from torchvision.ops import box_convert
from ultralytics import YOLO

from groundingdino.util.inference import load_model, load_image, predict, annotate, Model
from tools import *  # Fast-SAM helpers: box_prompt, fast_process, ...


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default="/comp_robot/rentianhe/code/Grounded-Segment-Anything/FastSAM/FastSAM-x.pt", help="path to the Fast-SAM checkpoint"
    )
    parser.add_argument(
        "--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
    )
    parser.add_argument(
        "--text", type=str, default="the black dog.", help="text prompt for GroundingDINO"
    )
    parser.add_argument("--imgsz", type=int, default=1024, help="image size")
    parser.add_argument(
        "--iou",
        type=float,
        default=0.9,
        help="iou threshold for filtering the annotations",
    )
    parser.add_argument(
        "--conf", type=float, default=0.4, help="object confidence threshold"
    )
    parser.add_argument(
        "--output", type=str, default="./output/", help="image save path"
    )
    # NOTE: with type=bool, argparse treats any non-empty string (even "False") as True,
    # so these flags are only reliable when left at their defaults.
    parser.add_argument(
        "--randomcolor", type=bool, default=True, help="mask random color"
    )
    parser.add_argument(
        "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
    )
    parser.add_argument(
        "--point_label",
        type=str,
        default="[0]",
        help="[1,0] 0:background, 1:foreground",
    )
    parser.add_argument("--box_prompt", type=str, default="[0,0,0,0]", help="[x,y,w,h]")
    parser.add_argument(
        "--better_quality",
        type=str,
        default=False,
        help="better quality using morphologyEx",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument(
        "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
    )
    parser.add_argument(
        "--retina",
        type=bool,
        default=True,
        help="draw high-resolution segmentation masks",
    )
    parser.add_argument(
        "--withContours", type=bool, default=False, help="draw the edges of the masks"
    )
    return parser.parse_args()


def main(args):
    # Input image and text prompt
    img_path = args.img_path
    text = args.text

    # Output directory for the annotated images
    save_path = args.output
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    basename = os.path.basename(args.img_path).split(".")[0]

    # Build the Fast-SAM model (a YOLO-format checkpoint, e.g. FastSAM-x.pt) and
    # segment everything in the image; results[0].masks.data holds all candidate masks.
    model = YOLO(args.model_path)

    results = model(
        args.img_path,
        imgsz=args.imgsz,
        device=args.device,
        retina_masks=args.retina,
        iou=args.iou,
        conf=args.conf,
        max_det=100,
    )

    # Build the GroundingDINO model and detect boxes for the text prompt.
    groundingdino_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    groundingdino_ckpt_path = "./groundingdino_swint_ogc.pth"

    image_source, image = load_image(img_path)
    model = load_model(groundingdino_config, groundingdino_ckpt_path)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text,
        box_threshold=0.3,
        text_threshold=0.25,
        device=args.device,
    )

    # Grounded-Fast-SAM: use each GroundingDINO box as a box prompt to pick the
    # best-matching Fast-SAM mask, then render and save one image per detection.
    ori_img = cv2.imread(img_path)
    ori_h = ori_img.shape[0]
    ori_w = ori_img.shape[1]

    # GroundingDINO returns normalized cxcywh boxes; scale them to pixel
    # coordinates and convert to xyxy before prompting Fast-SAM.
    boxes = boxes * torch.Tensor([ori_w, ori_h, ori_w, ori_h])
    print(f"Detected Boxes: {len(boxes)}")
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy().tolist()

    for box_idx in range(len(boxes)):
        # Select the Fast-SAM mask that best matches this box.
        mask, _ = box_prompt(
            results[0].masks.data,
            boxes[box_idx],
            ori_h,
            ori_w,
        )
        annotations = np.array([mask])
        img_array = fast_process(
            annotations=annotations,
            args=args,
            mask_random_color=True,
            bbox=boxes[box_idx],
        )
        cv2.imwrite(
            os.path.join(save_path, basename + f"_{box_idx}_caption_{phrases[box_idx]}.jpg"),
            cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
        )


if __name__ == "__main__":
    args = parse_args()
    main(args)