Skip to content

Commit

Permalink
[Fix] Fix unittest of TOF-VFI (open-mmlab#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yshuo-Li authored Apr 30, 2022
1 parent 1f1e9bc commit 15c1051
Showing 1 changed file with 2 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,21 @@ def test_tof_vfi_net():
assert model.__class__.__name__ == 'TOFlowVFINet'

# prepare data
inputs = torch.rand(1, 2, 3, 256, 248)
inputs = torch.rand(1, 2, 3, 256, 256)

# test on cpu
output = model(inputs)
assert torch.is_tensor(output)
assert output.shape == (1, 3, 256, 248)
assert output.shape == (1, 3, 256, 256)

# test on gpu
if torch.cuda.is_available():
model = model.cuda()
inputs = inputs.cuda()
output = model(inputs)
output = model(inputs, True)
assert torch.is_tensor(output)
assert output.shape == (1, 3, 256, 256)

inputs = torch.rand(1, 2, 3, 256, 256)
output = model(inputs)
assert torch.is_tensor(output)

with pytest.raises(OSError):
model.init_weights('')
with pytest.raises(TypeError):
Expand Down

0 comments on commit 15c1051

Please sign in to comment.