Make _get_perspective_coeffs device agnostic #9076

Open
@ptrblck

🐛 Describe the bug

Currently, `_get_perspective_coeffs` creates internal tensors on the CPU as seen here and fails in the `torch.linalg.lstsq` call if `startpoints`/`endpoints` is a tensor containing data on the GPU (or another device).

The docs state that a list of lists of Python ints is expected, but tensors are still accepted and do not fail on their own.

A fix would be to create `a_matrix` using the `device` attribute of the input. Alternatives would be to move the points to the host, which would synchronize the code and disallow graph capture, or to raise an error if tensors are passed.

If we want to accept tensor inputs, the tensor copy should also be fixed: `b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)`.
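A minimal sketch of the first option, assuming the coefficients come from the usual 8x8 linear system for a perspective transform. The function name `perspective_coeffs_sketch` is hypothetical and this is not the torchvision implementation; it only illustrates allocating `a_matrix` on the input's device and using `torch.as_tensor` instead of `torch.tensor` on a tensor input.

```python
import torch


def perspective_coeffs_sketch(startpoints, endpoints):
    """Hypothetical device-agnostic variant, not torchvision's code."""
    # Use the device of whichever input is a tensor; fall back to CPU for lists.
    device = next(
        (p.device for p in (startpoints, endpoints) if isinstance(p, torch.Tensor)),
        torch.device("cpu"),
    )

    # as_tensor accepts nested lists and tensors alike and avoids
    # re-wrapping an existing tensor with torch.tensor(...).
    start = torch.as_tensor(startpoints, dtype=torch.float64, device=device)
    end = torch.as_tensor(endpoints, dtype=torch.float64, device=device)
    if start.shape != (4, 2) or end.shape != (4, 2):
        raise ValueError("expected four (x, y) points for startpoints and endpoints")

    # Build the system matrix on the same device as the points.
    a_matrix = torch.zeros(8, 8, dtype=torch.float64, device=device)
    x, y = end[:, 0], end[:, 1]
    u, v = start[:, 0], start[:, 1]
    a_matrix[0::2, 0] = x
    a_matrix[0::2, 1] = y
    a_matrix[0::2, 2] = 1
    a_matrix[0::2, 6] = -u * x
    a_matrix[0::2, 7] = -u * y
    a_matrix[1::2, 3] = x
    a_matrix[1::2, 4] = y
    a_matrix[1::2, 5] = 1
    a_matrix[1::2, 6] = -v * x
    a_matrix[1::2, 7] = -v * y

    # Right-hand side: start points flattened as [u0, v0, u1, v1, ...].
    b_matrix = start.reshape(8, 1)
    solution = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
    return solution.squeeze(-1).to(torch.float32)
```

The sketch returns a tensor on the input's device rather than a Python list, so it does not force a host copy by itself; whether the caller needs a list is outside its scope.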

Original error reported in the discussion board and reproduced using:

```python
import torch
import torchvision.transforms.functional as TF


device = "cuda"
reference_image = torch.randn(1, 3, 224, 224, device=device)
B, C, H, W = reference_image.shape

W = 200
H = 200
# Define source points (original corners of the target image)
src_points = torch.tensor([
    [0, 0],  # Top-left
    [W - 1, 0],  # Top-right
    [W - 1, H - 1],  # Bottom-right
    [0, H - 1]  # Bottom-left
], dtype=torch.float32, device=reference_image.device)
src_points = src_points.unsqueeze(0).repeat(B, 1, 1)  # (B, 4, 2)

predicted_points = torch.tensor([
    [0, 0],  # Top-left
    [W - 10, 0],  # Top-right
    [W - 10, H - 10],  # Bottom-right
    [0, H - 10]  # Bottom-left
], dtype=torch.float32, device=reference_image.device)
predicted_points = predicted_points.unsqueeze(0).repeat(B, 1, 1)  # (B, 4, 2)

warped_images = []
for i in range(B):
    warped = TF.perspective(
        reference_image[i],
        src_points[i],
        predicted_points[i],
        interpolation=TF.InterpolationMode.BILINEAR,
        fill=0,
    )

# RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument b in method wrapper_CUDA_out_linalg_lstsq_out)
```
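
Until tensor inputs are handled, the host-side alternative mentioned above could look roughly like the following sketch. The helper `points_to_int_lists` is hypothetical; it converts a point tensor into the documented list-of-lists-of-ints form, at the cost of a device-to-host copy and synchronization.

```python
import torch


def points_to_int_lists(points):
    """Hypothetical helper: turn an (N, 2) point tensor into [[x, y], ...] ints."""
    # .cpu() and .tolist() copy to the host and synchronize,
    # so this is not suitable for graph capture.
    return points.detach().round().to(torch.int64).cpu().tolist()


src = torch.tensor([[0, 0], [199, 0], [199, 199], [0, 199]], dtype=torch.float32)
print(points_to_int_lists(src))  # [[0, 0], [199, 0], [199, 199], [0, 199]]
```

In the reproduction above, passing `points_to_int_lists(src_points[i])` and `points_to_int_lists(predicted_points[i])` to `TF.perspective` should sidestep the device mismatch, since only the documented list input is then used.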

### Versions


`0.22.0.dev20250404+cu128`
