Skip to content

Commit dd5c7b0

Browse files
authored
Fix AO SAM2 issues (#2109)
Fix AO SAM2 issues (#2109) Summary: Pull Request resolved: #2109 SAM2 issues - Whenever ```clear_old_points``` was enabled SAM2 would crash AAS Track mult issues - Enables ```multimask``` flags Rootcaused issues to failed assertion in the following lines in ```sam2_base.py::_track_step:L788```: ``` if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) assert mask_inputs is None ``` Whenever ```prev_sam_mask_logits``` has a value it results in a crash. There are several situations where this is expected to be the case including during streamed runs, or when clearing points. Test Plan: aistudio test local aas_track_mult ``` Retrieving package values for `fbcode//ai_demos/server_model_zoo/models/aas_track_mult`: buck2 audit package-values --reuse-current-config fbcode//ai_demos/server_model_zoo/models/aas_track_mult Buck command to find test owners: buck2 uquery --reuse-current-config owner(/data/sandcastle/boxes/fbsource/fbcode/ai_demos/server_model_zoo/models/aas_track_mult/test_aas_track_mult_model.py) -a labels Buck command to invoke a test: buck2 test --reuse-current-config --write-build-id /tmp/.tmpS35tJk --client-metadata language=python --client-metadata id=testify.codelens --client-metadata session_id=d0229502-10cc-45e7-a6f6-6c5c276c2e17 fbcode//ai_demos/server_model_zoo/models/aas_track_mult:tests -- --regex ai_demos/server_model_zoo/models/aas_track_mult:tests \- .*(?:\(.*TestAasTrackMultModel\)$|TestAasTrackMultModel: .*) --run-disabled Buck UI: https://www.internalfb.com/buck2/bf9cbfaa-ae6a-4568-876c-0b128dd474bd Test UI: https://www.internalfb.com/intern/testinfra/testrun/6473924727918606 Network: Up: 0B Down: 0B (reSessionID-8b7877b7-4cf8-4850-ac7b-ee84571b005d) Command: test. Time elapsed: 1:07.6s Tests finished: Pass 4. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D73460163 Pulled By: jlbmorales
1 parent 11472c9 commit dd5c7b0

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchao/_models/sam2/modeling/sam2_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,9 +788,10 @@ def _track_step(
788788
if prev_sam_mask_logits is not None:
789789
assert point_inputs is not None and mask_inputs is None
790790
mask_inputs = prev_sam_mask_logits
791+
else:
792+
assert mask_inputs is None
791793
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
792794

793-
assert mask_inputs is None
794795
assert multimask_output
795796
if point_inputs is not None:
796797
point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs}

0 commit comments

Comments
 (0)