[SAM] Fixes pipeline and adds a dummy pipeline test #23684
@@ -20,7 +20,7 @@

 import requests

-from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
+from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
 from transformers.utils import is_torch_available, is_vision_available
@@ -751,3 +751,9 @@ def test_inference_mask_generation_three_boxes_point_batch(self):
         iou_scores = outputs.iou_scores.cpu()
         self.assertTrue(iou_scores.shape == (1, 3, 3))
         torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
+
+    def test_dummy_pipeline_generation(self):
+        generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
+        raw_image = prepare_image()
+
+        _ = generator(raw_image, points_per_batch=64)
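For context, here is a minimal sketch of how the mask-generation pipeline exercised by the new test could be run on its own. The checkpoint, task name, and `points_per_batch` value come from the diff above; the image URL and the `"masks"`/`"scores"` output keys are assumptions for illustration, not taken from the PR.

```python
import requests
from PIL import Image

from transformers import pipeline

# Same checkpoint and task as in test_dummy_pipeline_generation above.
generator = pipeline("mask-generation", model="facebook/sam-vit-base")

# Any RGB image works; this COCO URL is only an illustrative assumption,
# not necessarily the image used by prepare_image() in the test file.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# points_per_batch controls how many prompt points are batched per forward pass.
outputs = generator(raw_image, points_per_batch=64)
print(len(outputs["masks"]), "masks,", len(outputs["scores"]), "scores")
```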
Comment on lines +755 to +759
I am surprised that this is not decorated by

    if is_vision_available():
        from PIL import Image

This has to be fixed, but it's OK if we do this in a separate PR and maybe delegate the task to the original SAM modeling author.
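For reference, a minimal sketch of the guard pattern the comment refers to, as it is commonly written in transformers test files. The `prepare_image` body here is hypothetical and only illustrates why the PIL import needs the guard; the real helper's image is not shown in this diff.

```python
import requests

from transformers.utils import is_vision_available

# Only import PIL when the vision extras are installed, so that collecting
# the test module does not fail in a torch-only environment.
if is_vision_available():
    from PIL import Image


def prepare_image():
    # Hypothetical stand-in for the helper used by the new test; the actual
    # image the real helper loads is not shown in this diff.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    return Image.open(requests.get(url, stream=True).raw).convert("RGB")
```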
When it currently fails, what is the type of `crop_boxes`? Is it a torch tensor? If so, we can't rely on `np.array` unless we are 100% sure `crop_boxes` is never on cuda.
Pre-processing should never be done on cuda, no?
OK, my bad, I didn't realize it's pre- instead of post-processing. And #22970 actually removes the device stuff.
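For reference, a minimal sketch of the kind of defensive conversion the discussion above is about. The helper name is mine, not from the PR; it only illustrates why `np.array` alone is unsafe if `crop_boxes` could be a CUDA tensor.

```python
import numpy as np
import torch


def crop_boxes_to_numpy(crop_boxes):
    # np.array() on a CUDA tensor raises
    # "TypeError: can't convert cuda:0 device type tensor to numpy",
    # so move the tensor to CPU first when the input may be a torch tensor.
    if isinstance(crop_boxes, torch.Tensor):
        return crop_boxes.detach().cpu().numpy()
    return np.array(crop_boxes)
```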