Allow users to choose whether to return Datapoint subclasses or pure Tensor #7825
Conversation
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 2:
    raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes:  # type: ignore[override]
Had to add the `check_dims` flag because in some cases, like for `bbox.sum()`, the dims won't be correct.
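A minimal sketch of why such an escape hatch is needed. The helper name `wrap_boxes` is illustrative, not torchvision's actual `_wrap`: the point is that reductions such as `bbox.sum()` produce 0-D tensors that would fail the 1D/2D validation, so re-wrapping their outputs must be able to skip it.

```python
import torch

def wrap_boxes(tensor: torch.Tensor, *, check_dims: bool = True) -> torch.Tensor:
    # Validate/normalize dims only when asked to; ops like bbox.sum()
    # legitimately return 0-D results that must bypass this check.
    if check_dims:
        if tensor.ndim == 1:
            tensor = tensor.unsqueeze(0)  # promote a single box to shape (1, 4)
        elif tensor.ndim != 2:
            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
    return tensor

bbox = torch.tensor([0, 0, 10, 10])
boxes = wrap_boxes(bbox)                          # 1D input -> shape (1, 4)
total = wrap_boxes(bbox.sum(), check_dims=False)  # 0-D result passes through
```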
Should we move these checks to `_wrap` for images as well?
vision/torchvision/datapoints/_image.py
Lines 39 to 42 in 3065ad5
if tensor.ndim < 2:
    raise ValueError
elif tensor.ndim == 2:
    tensor = tensor.unsqueeze(0)
It's not needed
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
-listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
+listed in _FORCE_TORCHFUNCTION_SUBCLASS
Note: the docstring above is mostly wrong / obsolete now. If this is merged I would rewrite everything. Same for a lot of the comments.
Not going to block, but I would prefer doing this here. I'm ok with the docs being updated later, since the default behavior doesn't change. But I feel the comments here should be updated right away, since they are wrong / obsolete now.
I think I will be able to write better quality comments / docstring once I start writing the user-facing docs.
_TORCHFUNCTION_SUBCLASS = False

def set_return_type(type="Tensor"):
This is the only function that is publicly exposed. We'll probably want to make this a context manager on top of a global flag switch. We can bikeshed on its name and refine the actual UX, but let's first focus on whether we want to expose this functionality.
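One way to realize "a context manager on top of a global flag switch" as suggested above. This is a hedged sketch under assumed names (`_return_type` is a hypothetical module-level flag), not torchvision's actual implementation:

```python
_return_type = "Tensor"  # hypothetical global flag consulted by __torch_function__

class set_return_type:
    """Usable both as a plain call and as a context manager:
    constructing it flips the global flag; __exit__ restores the prior value."""

    def __init__(self, type="Tensor"):
        global _return_type
        self._previous = _return_type  # remember the value to restore on exit
        _return_type = type

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        global _return_type
        _return_type = self._previous

# Plain call: the flag stays flipped until changed again.
set_return_type("datapoints")
set_return_type("Tensor")

# Context manager: the flag is restored automatically on exit.
with set_return_type("datapoints"):
    flag_inside = _return_type
flag_after = _return_type
```

The same object serving as both a function and a context manager keeps the API to a single public name, at the cost of the constructor having a side effect.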
# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
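To illustrate how such an allowlist is consulted, here is a hedged, self-contained sketch of a Tensor subclass whose `__torch_function__` unwraps every output to a pure Tensor except for ops in the set. `FakeDatapoint` and `_FORCE_SUBCLASS` are stand-in names mirroring the idea, not torchvision's exact code:

```python
import torch

# Stand-in for the allowlist of ops that must preserve the subclass.
_FORCE_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach}

class FakeDatapoint(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # The base implementation runs the op with subclass-ness preserved.
        result = super().__torch_function__(func, types, args, kwargs or {})
        if func in _FORCE_SUBCLASS or not isinstance(result, torch.Tensor):
            return result  # keep the subclass for allowlisted ops
        return result.as_subclass(torch.Tensor)  # otherwise unwrap to pure Tensor

dp = torch.rand(3).as_subclass(FakeDatapoint)
```

With this in place, `dp.clone()` stays a `FakeDatapoint` while `dp + dp` comes back as a plain `torch.Tensor`.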
Note: I'm not super happy about our names ("unwrapping" vs "subclass"), a lot of it actually coming from the base implementations. But we can clean that up later.
Stamping. Thanks Nicolas!
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
… or pure Tensor (#7825) Reviewed By: matteobettini Differential Revision: D48642251 fbshipit-source-id: 9a59123410585c4b0523069089803784168ca707
This is basically the same as #7807, but it preserves the current default behaviour, i.e. we still return tensors by default.

This adds a public `datapoints.set_return_type("datapoints")` switch that allows users to decide whether they want datapoints or tensors as output. This does NOT change anything about the unwrap/wrapping logic of our functional kernels.
cc @vfdev-5