Indices not moved to correct device in batched NMS #17

Closed
lmmx opened this issue Apr 5, 2023 · 1 comment

lmmx commented Apr 5, 2023

I'm trying to increase the points per side, but when I go over 80 I get the following error:

Traceback (most recent call last):
  File "/home/louis/dev/cv/segment-anything/scripts/amg.py", line 267, in <module>
    main(args)
  File "/home/louis/dev/cv/segment-anything/scripts/amg.py", line 244, in main
    masks = generator.generate(image)
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py", line 163, in generate
    mask_data = self._generate_masks(image)
  File "/home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py", line 206, in _generate_masks
    crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
  File "/home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py", line 251, in _process_crop
    keep_by_nms = batched_nms(
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py", line 73, in batched_nms
    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torch/jit/_trace.py", line 1220, in wrapper
    return fn(*args, **kwargs)
  File "/home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py", line 110, in _batched_nms_vanilla
    keep_mask[curr_indices[curr_keep_indices]] = True
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

I expect this means the curr_keep_indices need to be moved to the same device as the indexed keep_mask tensor.

If we set a breakpoint at line 251 in _process_crop of automatic_mask_generator.py

        # Remove duplicates within this crop.
        breakpoint()
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros(len(data["boxes"])),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)
  • Note that the categories argument, torch.zeros(len(data["boxes"])), does not specify a device

and enter the debugger

> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(253)_process_crop()
-> data["boxes"].float(),
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(254)_process_crop()
-> data["iou_preds"],
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(255)_process_crop()
-> torch.zeros(len(data["boxes"])),  # categories
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(256)_process_crop()
-> iou_threshold=self.box_nms_thresh,
(Pdb) n
> /home/louis/dev/cv/segment-anything/segment_anything/automatic_mask_generator.py(252)_process_crop()
-> keep_by_nms = batched_nms(
(Pdb) s
--Call--
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(44)batched_nms()
-> def batched_nms(
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(68)batched_nms()
-> if not torch.jit.is_scripting() and not torch.jit.is_tracing():
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(69)batched_nms()
-> _log_api_usage_once(batched_nms)
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(72)batched_nms()
-> if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
(Pdb) n
> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(73)batched_nms()
-> return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)

and print the devices of the variables being passed through batched_nms() to _batched_nms_vanilla():

> /home/louis/miniconda3/envs/sam/lib/python3.10/site-packages/torchvision/ops/boxes.py(73)batched_nms()
-> return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
(Pdb) p boxes
tensor([[   0.,    0.,  968.,   45.],
        [   0.,    0.,  968.,   45.],
        [   0.,    0.,  968.,   45.],
        ...,
        [   3.,    0.,  995., 1332.],
        [   3.,    0.,  995., 1332.],
        [   0.,    4.,  995., 1332.]], device='cuda:0')
(Pdb) p scores
tensor([0.9919, 0.9900, 0.9894,  ..., 1.0067, 1.0069, 0.9800], device='cuda:0')
(Pdb) p idxs
tensor([0., 0., 0.,  ..., 0., 0., 0.])
(Pdb) p iou_threshold
0.7

This confirms the source of the error: the zeros tensor (idxs) was created without specifying a device, so it is on the CPU, while boxes and scores are on CUDA.
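The same check can be done without a debugger. The following is a minimal sketch, not code from segment-anything or torchvision: find_device_mismatches is a hypothetical helper, and the tensor values are a small sample copied from the pdb output above. On a machine without CUDA everything lands on the CPU and no mismatch is reported.

```python
import torch


def find_device_mismatches(**tensors):
    """Return the names of tensors whose device differs from the first one's.

    Hypothetical helper for illustration only.
    """
    names = list(tensors)
    reference = tensors[names[0]].device
    return [name for name in names[1:] if tensors[name].device != reference]


# Use CUDA when present so the mismatch is reproduced; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

boxes = torch.tensor(
    [[0.0, 0.0, 968.0, 45.0], [3.0, 0.0, 995.0, 1332.0]], device=device
)
scores = torch.tensor([0.9919, 0.9800], device=device)
idxs = torch.zeros(len(boxes))  # no device argument, so this stays on the CPU

mismatched = find_device_mismatches(boxes=boxes, scores=scores, idxs=idxs)
print(mismatched)  # lists "idxs" only when boxes and scores are on CUDA
```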

To remedy it, create the categories tensor on the same device as the boxes, with either

torch.zeros_like(data["boxes"][:,0])

or

torch.zeros(len(data["boxes"]), device=data["boxes"].device)

but I think the first way is more elegant.
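One difference between the two options is worth checking before picking one: torch.zeros_like inherits the dtype of its input as well as the device, whereas torch.zeros with only a device argument uses the default dtype. The sketch below illustrates this under an assumption: the boxes tensor here is given an integer dtype purely for demonstration, and I have not verified which dtype data["boxes"] actually has at that point in automatic_mask_generator.py.

```python
import torch

# Use CUDA when present; the comparison also holds on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumption for illustration: integer-valued boxes.
boxes = torch.tensor([[0, 0, 968, 45], [3, 0, 995, 1332]], device=device)

# Option 1: inherits both device and dtype from the boxes column.
cats_like = torch.zeros_like(boxes[:, 0])

# Option 2: inherits only the device; dtype is torch's default (float32).
cats_explicit = torch.zeros(len(boxes), device=boxes.device)

print(cats_like.device, cats_like.dtype)
print(cats_explicit.device, cats_explicit.dtype)
```

Both tensors end up on the same device as boxes, which is what the indexing inside _batched_nms_vanilla needs according to the error message above.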

I expect the same edit should also be applied on line 360.

            torch.zeros(len(boxes)),  # categories

to

            torch.zeros_like(boxes[:,0]),  # categories

lmmx commented Apr 6, 2023

I've submitted a PR to fix this:
