
SAM2 more than one point and label error #3484

Open
canglaoshidaidui opened this issue Sep 27, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@canglaoshidaidui

Description

Dear DJL team,
SAM2 throws an error when more than one point and label are provided in a Sam2Input.

Expected Behavior

public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
    String url =
            "https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg";
    // Sam2Input input = Sam2Input.newInstance(url, 500, 375);

    Image image = ImageFactory.getInstance().fromUrl(url);
    // Sam2Input input = new Sam2Input(image, Arrays.asList(new Point(500, 375), new Point(1000, 375)), Arrays.asList(1, 0));
    // Sam2Input input = new Sam2Input(image, Arrays.asList(new Point(1000, 375)), Arrays.asList(1));
    // Sam2Input input = new Sam2Input(image, Arrays.asList(new Point(1700, 1100)), Arrays.asList(1));
    Sam2Input input = new Sam2Input(image, Arrays.asList(new Point(500, 375), new Point(1100, 600)), Arrays.asList(1, 1));

    Criteria<Sam2Input, DetectedObjects> criteria =
            Criteria.builder()
                    .setTypes(Sam2Input.class, DetectedObjects.class)
                    .optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny")
                    // .optModelPath(modelPath)
                    .optEngine("PyTorch")
                    .optDevice(Device.cpu()) // use sam2-hiera-tiny-gpu for GPU
                    .optTranslator(new Sam2Translator())
                    .optProgress(new ProgressBar())
                    .build();
    try (ZooModel<Sam2Input, DetectedObjects> model = criteria.loadModel();
            Predictor<Sam2Input, DetectedObjects> predictor = model.newPredictor()) {
        DetectedObjects detection = predictor.predict(input);
        showMask(input, detection);
        return detection;
    }
}

Error Message

Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.engine.EngineException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/torch.py", line 63, in forward
feat_s1 = torch.view(torch.permute(feat2, [1, 2, 0]), [1, -1, 128, 128])
feat_s0 = torch.view(torch.permute(feat, [1, 2, 0]), [1, -1, 256, 256])
_23 = (sam_prompt_encoder0).forward(point_coords, point_labels, )
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_24, _25, = _23
image_embeddings = torch.unsqueeze(torch.select(image_embed, 0, 0), 0)
File "code/torch/sam2/modeling/sam/prompt_encoder.py", line 78, in forward
_32 = annotate(List[Optional[Tensor]], [_31])
_33 = torch.add(torch.index(point_embedding2, _32), weight2)
_34 = torch.view(_33, [256])
~~~~~~~~~~ <--- HERE
_35 = annotate(List[Optional[Tensor]], [_31])
point_embedding3 = torch.index_put(point_embedding2, _35, _34)

Traceback of TorchScript, original code (most recent call last):
/Users/lufen/source/venv/lib/python3.11/site-packages/sam2/modeling/sam/prompt_encoder.py(98): _embed_points
/Users/lufen/source/venv/lib/python3.11/site-packages/sam2/modeling/sam/prompt_encoder.py(169): forward
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/nn/modules/module.py(1543): _slow_forward
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/nn/modules/module.py(1562): _call_impl
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
/Users/lufen/source/ptest/p_sam2/trace_sam2_img.py(74): predict
/Users/lufen/source/ptest/p_sam2/trace_sam2_img.py(62): forward
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/nn/modules/module.py(1543): _slow_forward
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/nn/modules/module.py(1562): _call_impl
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
/Users/lufen/source/venv/lib/python3.11/site-packages/torch/jit/_trace.py(1275): trace_module
/Users/lufen/source/ptest/p_sam2/trace_sam2_img.py(104): trace_model
/Users/lufen/source/ptest/p_sam2/trace_sam2_img.py(111):
/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py(18): execfile
/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py(1535): _exec
/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py(1528): run
/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py(2218): main
/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py(2236):
RuntimeError: shape '[256]' is invalid for input of size 512

at ai.djl.inference.Predictor.batchPredict(Predictor.java:197)
at ai.djl.inference.Predictor.predict(Predictor.java:133)
at SegmentAnything2.predict(SegmentAnything2.java:79)
at SegmentAnything2.main(SegmentAnything2.java:49)

Caused by: ai.djl.engine.EngineException: The following operation failed in the TorchScript interpreter.
RuntimeError: shape '[256]' is invalid for input of size 512

at ai.djl.pytorch.jni.PyTorchLibrary.moduleRunMethod(Native Method)
at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:57)
at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:146)
at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
at ai.djl.nn.Block.forward(Block.java:127)
at ai.djl.inference.Predictor.predictInternal(Predictor.java:147)
at ai.djl.inference.Predictor.batchPredict(Predictor.java:172)
... 3 more
@canglaoshidaidui canglaoshidaidui added the bug Something isn't working label Sep 27, 2024
@frankfliu
Contributor

@canglaoshidaidui
This is a limitation of the traced model. During JIT tracing, the input shapes are fixed. Currently the model is traced with a single point and doesn't support boxes. You would have to manually re-trace the model with two points for your use case.
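This trace-time shape freezing can be reproduced with a toy module (the `ToyPointEncoder` below is a hypothetical stand-in, not SAM2's actual prompt encoder). In SAM2 the mismatch surfaces as the `shape '[256]' is invalid for input of size 512` RuntimeError; in this minimal sketch it surfaces as silent truncation instead, but the root cause is the same: loop bounds and shapes observed at trace time are baked into the graph.

```python
import torch

# Hypothetical stand-in for a prompt encoder whose graph depends on the
# number of points passed in. NOT the real SAM2 code, just the same
# failure mode: the point count is frozen into the traced graph.
class ToyPointEncoder(torch.nn.Module):
    def forward(self, points):
        out = []
        for i in range(points.shape[0]):  # Python loop: unrolled at trace time
            out.append(points[i] * 2.0)
        return torch.stack(out)

model = ToyPointEncoder()
traced = torch.jit.trace(model, torch.zeros(1, 2))  # traced with ONE point

# Calling the traced graph with two points silently processes only the
# first one, because the loop was unrolled for a single iteration.
result = traced(torch.zeros(2, 2))
print(result.shape)  # torch.Size([1, 2]) -- the second point is dropped
```

Re-tracing the same module with a `torch.zeros(2, 2)` example input produces a graph that handles exactly two points, which is why the comment above suggests tracing manually for the desired point count.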

@frankfliu
Contributor

@canglaoshidaidui
There is actually a way to resolve this issue:

  1. Trace the model into separate encoder and decoder models instead of a single model.
  2. Load the encoder model as part of the Translator.
  3. Based on the input points, generate a different input for the decoder model in the Translator.

Let me see if I can improve it.
