Skip to content

Commit

Permalink
mask_metadata and requirement.txt
Browse files Browse the repository at this point in the history
  • Loading branch information
Andy1621 committed Apr 8, 2023
1 parent db6c43b commit e28455d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
38 changes: 37 additions & 1 deletion grounded_sam_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy

import numpy as np
import json
import torch
from PIL import Image, ImageDraw, ImageFont

Expand Down Expand Up @@ -98,6 +99,35 @@ def show_box(box, ax, label):
ax.text(x0, y0, label)


def save_mask_data(output_dir, mask_list, box_list, label_list):
value = 0 # 0 for background

mask_img = torch.zeros(mask_list.shape[-2:])
for idx, mask in enumerate(mask_list):
mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
plt.figure(figsize=(10, 10))
plt.imshow(mask_img.numpy())
plt.axis('off')
plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)

json_data = [{
'value': value,
'label': 'background'
}]
for label, box in zip(label_list, box_list):
value += 1
name, logit = label.split('(')
logit = logit[:-1] # the last is ')'
json_data.append({
'value': value,
'label': name,
'logit': float(logit),
'box': box.numpy().tolist(),
})
with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
json.dump(json_data, f)


if __name__ == "__main__":

parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
Expand Down Expand Up @@ -176,6 +206,12 @@ def show_box(box, ax, label):
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box, label in zip(boxes_filt, pred_phrases):
show_box(box.numpy(), plt.gca(), label)

plt.axis('off')
plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
plt.savefig(
os.path.join(output_dir, "grounded_sam_output.jpg"),
bbox_inches="tight", dpi=300, pad_inches=0.0
)

save_mask_data(output_dir, masks, boxes_filt, pred_phrases)

20 changes: 20 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
addict
diffusers
gradio
huggingface_hub
matplotlib
numpy
onnxruntime
opencv_python
Pillow
pycocotools
PyYAML
requests
setuptools
supervision
termcolor
timm
torch
torchvision
transformers
yapf

2 comments on commit e28455d

@stevezkw1998
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Andy1621
In my humble oppions, adding specific version is more suitable, because torch/torchvison 's default version is CPU, but when you use CPU you will face some issues in this app
I suggest to use:
torch==2.0.0+cu117
torchvision==0.15.1+cu117
Or adding some 说明 in README.md

@Andy1621
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! I do not add specific versions since either pytorch>=1.7 and torchvision>=0.8 is ok. Just make sure to add cuda version when install torch-related packages.

Please sign in to comment.