diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py index 4297b3129..493950ab2 100644 --- a/segment_anything/utils/onnx.py +++ b/segment_anything/utils/onnx.py @@ -81,8 +81,8 @@ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) - align_corners=False, ) - prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) - masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] orig_im_size = orig_im_size.to(torch.int64) h, w = orig_im_size[0], orig_im_size[1]