
Add RAFT model for optical flow #5022


Merged: 18 commits, Dec 6, 2021
15 changes: 14 additions & 1 deletion docs/source/models.rst
@@ -7,7 +7,7 @@ Models and pre-trained weights
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection and video classification.
keypoint detection, video classification, and optical flow.

.. note ::
    Backward compatibility is guaranteed for loading a serialized
@@ -798,3 +798,16 @@ ResNet (2+1)D
    :template: function.rst

    torchvision.models.video.r2plus1d_18

Optical flow
============

Raft
----

.. autosummary::
    :toctree: generated/
    :template: function.rst

    torchvision.models.optical_flow.raft_large
    torchvision.models.optical_flow.raft_small
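Since the docs above only list the new builders, here is a minimal usage sketch for the new optical flow models. It assumes the builders can be called with default arguments and, as in the test added in this PR, that the model returns a list of iteratively refined flow predictions; pretrained weights are not part of this diff, so weight loading is not shown.

```python
import torch
from torchvision.models.optical_flow import raft_small

# Build the small RAFT variant with default arguments (assumption: defaults are valid)
model = raft_small().eval()

# Two consecutive frames; H and W should be divisible by 8, since RAFT works on 1/8-resolution features
img1 = torch.rand(1, 3, 128, 128)
img2 = torch.rand(1, 3, 128, 128)

with torch.no_grad():
    flow_predictions = model(img1, img2)  # list of flow estimates, one per refinement iteration

flow = flow_predictions[-1]  # final estimate, shape (1, 2, 128, 128): per-pixel (dx, dy)
print(flow.shape)
```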
Two binary files not shown.
35 changes: 32 additions & 3 deletions test/test_models.py
@@ -93,7 +93,7 @@ def _get_expected_file(name=None):
    return expected_file


def _assert_expected(output, name, prec):
def _assert_expected(output, name, prec=None, atol=None, rtol=None):
"""Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
@@ -110,10 +110,11 @@ def _assert_expected(output, name, prec):
        MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
        binary_size = os.path.getsize(expected_file)
        if binary_size > MAX_PICKLE_SIZE:
            raise RuntimeError(f"The output for {filename}, is larger than 50kb")
            raise RuntimeError(f"The output for {filename} is larger than 50KB - got {binary_size} bytes")
    else:
        expected = torch.load(expected_file)
        rtol = atol = prec
        rtol = rtol or prec  # keeping prec param for legacy reasons, but could be removed ideally
        atol = atol or prec
        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)


@@ -818,5 +819,33 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
    assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]


@needs_cuda
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
@pytest.mark.parametrize("scripted", (False, True))
Member (PR author) commented:
I'm parametrizing over this because testing with _check_jit_scriptable unfortunately fails on a small number of entries, e.g.:

Mismatched elements: 153 / 11520 (1.3%)
Greatest absolute difference: 0.0002608299255371094 at index (0, 0, 79, 45) (up to 0.0001 allowed)
Greatest relative difference: 0.021354377198448304 at index (0, 1, 53, 68) (up to 0.0001 allowed)

I could add tolerance parameters to that check, but I feel the current test is fine as it is.

def test_raft(model_builder, scripted):

    torch.manual_seed(0)

    # We need very small images, otherwise the pickle size would exceed the 50KB limit.
    # As a result we need to override the correlation pyramid to not downsample
    # too much, otherwise we would get nan values (effective H and W would be
    # reduced to 1).
    corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)

    model = model_builder(corr_block=corr_block).eval().to("cuda")
    if scripted:
        model = torch.jit.script(model)

    bs = 1
    img1 = torch.rand(bs, 3, 80, 72).cuda()
    img2 = torch.rand(bs, 3, 80, 72).cuda()

    preds = model(img1, img2)
    flow_pred = preds[-1]
    # Tolerance is fairly high, but there are 2 * H * W outputs to check.
    # The .pkl files were generated on the AWS cluster; on the CI the results are slightly different.
    _assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)


if __name__ == "__main__":
    pytest.main([__file__])
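As a follow-up to the review comment above about _check_jit_scriptable tolerances, here is a minimal sketch (not part of the PR) of the alternative it alludes to: comparing eager and scripted outputs directly with explicit, looser tolerances. The helper name, input sizes, and tolerance values are illustrative assumptions.

```python
import torch
from torchvision import models


def check_scripted_matches_eager(model_builder, atol=1e-3, rtol=1e-3):
    # Hypothetical helper: run the same random inputs through the eager and the
    # scripted model and compare the final flow estimates with explicit tolerances.
    torch.manual_seed(0)
    model = model_builder().eval()
    scripted = torch.jit.script(model)

    img1 = torch.rand(1, 3, 128, 128)
    img2 = torch.rand(1, 3, 128, 128)

    with torch.no_grad():
        eager_flow = model(img1, img2)[-1]
        scripted_flow = scripted(img1, img2)[-1]

    # The mismatches reported in the comment were ~2.6e-4 absolute, so an
    # explicit 1e-3 atol would have passed on that data.
    torch.testing.assert_close(eager_flow, scripted_flow, atol=atol, rtol=rtol)


check_scripted_matches_eager(models.optical_flow.raft_small)
```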
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -12,6 +12,7 @@
from .regnet import *
from . import detection
from . import feature_extraction
from . import optical_flow
from . import quantization
from . import segmentation
from . import video
1 change: 1 addition & 0 deletions torchvision/models/optical_flow/__init__.py
@@ -0,0 +1 @@
from .raft import RAFT, raft_large, raft_small
45 changes: 45 additions & 0 deletions torchvision/models/optical_flow/_utils.py
@@ -0,0 +1,45 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
    h, w = img.shape[-2:]

    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (w - 1) - 1
    ygrid = 2 * ygrid / (h - 1) - 1
    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)

    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)


def make_coords_grid(batch_size: int, h: int, w: int):
    coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch_size, 1, 1, 1)


def upsample_flow(flow, up_mask: Optional[Tensor] = None):
    """Upsample flow by a factor of 8.

    If up_mask is None we just interpolate.
    If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
    Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
    """
    batch_size, _, h, w = flow.shape
    new_h, new_w = h * 8, w * 8

    if up_mask is None:
        return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)

    up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w)
    up_mask = torch.softmax(up_mask, dim=2)  # "convex" == weights sum to 1

    upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w)
    upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)

    return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w)
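To make the helpers above more concrete, here is a small sketch of how they compose: build an absolute coordinate grid, displace it by a flow field, sample features at the displaced pixel locations, and upsample a low-resolution flow to full resolution. The tensors and shapes are toy assumptions, and this is only an illustration of the helpers, not of how the PR's CorrBlock actually indexes the correlation volume.

```python
import torch
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow

batch_size, channels, h, w = 2, 8, 16, 16
features = torch.rand(batch_size, channels, h, w)

coords = make_coords_grid(batch_size, h, w)  # (N, 2, H, W), absolute (x, y) pixel coordinates
flow = torch.randn(batch_size, 2, h, w)      # toy flow field at the same resolution

# grid_sample above expects the grid as (N, H, W, 2) in absolute coordinates
displaced = (coords + flow).permute(0, 2, 3, 1)
warped = grid_sample(features, displaced, mode="bilinear", align_corners=True)
print(warped.shape)  # torch.Size([2, 8, 16, 16])

# Without a mask, upsample_flow bilinearly upsamples by 8 and rescales the flow values accordingly
full_res = upsample_flow(flow)
print(full_res.shape)  # torch.Size([2, 2, 128, 128])
```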