Add RAFT model for optical flow #5022
Merged
Commits
All 18 commits are by NicolasHug:

- d51172f Add RAFT
- 12375b8 Merge branch 'main' of github.com:pytorch/vision into raft_model_arch
- 578373b Minor fixes
- 19629a3 add _init_weights() method
- 44b4290 weights -> pretrained
- 7ea67ea Use ConvNormActivation in MaskPredictor
- 6b3af06 Use nn.Identity instead of checking for None layers
- 173ccde Extract out _compute_corr_volume method
- 8e444fd Use F.relu instead of torch.relu
- c05b872 Re-organize file structure and rename raft() into raft_large
- 6678cff Added support for torchscript, and added expect test
- b1375ab avoid import
- 6f9bc85 ValueError instead of NotImplementedError
- f655ec6 Allow higher tolerance for expectTest
- a85b6b4 The docssssssss
- c891a93 fix hooks
- 9ae9e38 Fix hooks -- remastered
- f077d7c Merge branch 'main' into raft_model_arch
(Two binary files in the diff are not shown.)
One new file adds a single re-export line:

```diff
@@ -0,0 +1 @@
+from .raft import RAFT, raft_large, raft_small
```
A new 45-line utilities module adds the `grid_sample`, `make_coords_grid`, and `upsample_flow` helpers:

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
    h, w = img.shape[-2:]

    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (w - 1) - 1
    ygrid = 2 * ygrid / (h - 1) - 1
    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)

    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)


def make_coords_grid(batch_size: int, h: int, w: int):
    coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch_size, 1, 1, 1)


def upsample_flow(flow, up_mask: Optional[Tensor] = None):
    """Upsample flow by a factor of 8.

    If up_mask is None we just interpolate.
    If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
    Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
    """
    batch_size, _, h, w = flow.shape
    new_h, new_w = h * 8, w * 8

    if up_mask is None:
        return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)

    up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w)
    up_mask = torch.softmax(up_mask, dim=2)  # "convex" == weights sum to 1

    upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w)
    upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)

    return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w)
```
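To illustrate the coordinate convention that `grid_sample` relies on: torch's `F.grid_sample` expects coordinates normalized to [-1, 1], and the helper converts absolute pixel coordinates with `2 * x / (size - 1) - 1`. A minimal torch-free sketch of that mapping:

```python
def normalize_coord(x, size):
    # Map an absolute pixel coordinate in [0, size - 1] to the [-1, 1]
    # range that F.grid_sample expects (same formula as the helper).
    return 2 * x / (size - 1) - 1

# Corners map to the ends of the range; the center maps to 0.
print(normalize_coord(0, 8))    # -1.0
print(normalize_coord(7, 8))    # 1.0
print(normalize_coord(3.5, 8))  # 0.0
```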
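`make_coords_grid` stacks the meshgrid outputs in reverse order, so channel 0 holds x (column index) and channel 1 holds y (row index). A torch-free sketch of the same layout using nested lists:

```python
def make_coords(h, w):
    # Channel 0 holds x (column index), channel 1 holds y (row index),
    # mirroring make_coords_grid's stack of the reversed meshgrid output.
    xs = [[float(x) for x in range(w)] for _ in range(h)]
    ys = [[float(y) for _ in range(w)] for y in range(h)]
    return [xs, ys]

grid = make_coords(2, 3)
print(grid[0][1])  # x along row 1: [0.0, 1.0, 2.0]
print(grid[1][1])  # y along row 1: [1.0, 1.0, 1.0]
```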
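The convex-upsampling branch of `upsample_flow` can be illustrated in plain Python: softmax over the 9 mask scores yields weights that sum to 1, so each fine-resolution flow value is a convex combination of its 3x3 coarse neighbors (after the x8 scaling). A small sketch with made-up numbers:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

# 9 raw mask scores for one output pixel (one per 3x3 neighbor).
scores = [0.0] * 9
weights = softmax(scores)
print(round(sum(weights), 6))  # 1.0 -> the combination is convex

# With 9 neighboring (already x8-scaled) flow values, the output is a
# weighted average that stays inside the neighbors' min/max range.
neighbors = [8.0, 8.0, 8.0, 16.0, 8.0, 8.0, 8.0, 8.0, 8.0]
out = sum(w * f for w, f in zip(weights, neighbors))
print(round(out, 4))  # uniform weights: (8 * 8.0 + 16.0) / 9 ~ 8.8889
```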
I'm parametrizing over this because testing with `_check_jit_scriptable` unfortunately fails on very few entries. I could add `tol` parameters to the check, but I feel like this current test is just as fine.
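To make the `tol` idea concrete: the check could accept a per-call tolerance instead of one global threshold. A hypothetical torch-free sketch (the name `check_close` is illustrative, not an actual test helper):

```python
def check_close(actual, expected, tol=1e-6):
    # Hypothetical tolerance-aware comparison; `tol` could be widened
    # only for the few entries that fail the strict default check.
    return abs(actual - expected) <= tol

print(check_close(1.0, 1.0 + 1e-8))       # True: within default tol
print(check_close(1.0, 1.001))            # False: too far apart
print(check_close(1.0, 1.001, tol=1e-2))  # True: with a relaxed tol
```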