Skip to content

Commit

Permalink
add unit tests for fp16 decorators (open-mmlab#2381)
Browse files Browse the repository at this point in the history
  • Loading branch information
hellock authored Apr 3, 2020
1 parent 3138027 commit cb348b2
Showing 1 changed file with 301 additions and 0 deletions.
301 changes: 301 additions & 0 deletions tests/test_fp16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
import numpy as np
import pytest
import torch
import torch.nn as nn

from mmdet.core import auto_fp16, force_fp32
from mmdet.core.fp16.utils import cast_tensor_type


def test_cast_tensor_type():
inputs = torch.FloatTensor([5.])
src_type = torch.float32
dst_type = torch.int32
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == dst_type

inputs = 'tensor'
src_type = str
dst_type = str
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, str)

inputs = np.array([5.])
src_type = np.ndarray
dst_type = np.ndarray
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, np.ndarray)

inputs = dict(
tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
src_type = torch.float32
dst_type = torch.int32
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, dict)
assert outputs['tensor_a'].dtype == dst_type
assert outputs['tensor_b'].dtype == dst_type

inputs = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
src_type = torch.float32
dst_type = torch.int32
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, list)
assert outputs[0].dtype == dst_type
assert outputs[1].dtype == dst_type

inputs = 5
outputs = cast_tensor_type(inputs, None, None)
assert isinstance(outputs, int)


def test_auto_fp16():

with pytest.raises(TypeError):
# ExampleObject is not a subclass of nn.Module

class ExampleObject(object):

@auto_fp16()
def __call__(self, x):
return x

model = ExampleObject()
input_x = torch.ones(1, dtype=torch.float32)
model(input_x)

# apply to all input args
class ExampleModule(nn.Module):

@auto_fp16()
def forward(self, x, y):
return x, y

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.float32)
input_y = torch.ones(1, dtype=torch.float32)
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32

model.fp16_enabled = True
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half

if torch.cuda.is_available():
model.cuda()
output_x, output_y = model(input_x.cuda(), input_y.cuda())
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half

# apply to specified input args
class ExampleModule(nn.Module):

@auto_fp16(apply_to=('x', ))
def forward(self, x, y):
return x, y

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.float32)
input_y = torch.ones(1, dtype=torch.float32)
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32

model.fp16_enabled = True
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.float32

if torch.cuda.is_available():
model.cuda()
output_x, output_y = model(input_x.cuda(), input_y.cuda())
assert output_x.dtype == torch.half
assert output_y.dtype == torch.float32

# apply to optional input args
class ExampleModule(nn.Module):

@auto_fp16(apply_to=('x', 'y'))
def forward(self, x, y=None, z=None):
return x, y, z

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.float32)
input_y = torch.ones(1, dtype=torch.float32)
input_z = torch.ones(1, dtype=torch.float32)
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32
assert output_z.dtype == torch.float32

model.fp16_enabled = True
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half
assert output_z.dtype == torch.float32

if torch.cuda.is_available():
model.cuda()
output_x, output_y, output_z = model(
input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half
assert output_z.dtype == torch.float32

# out_fp32=True
class ExampleModule(nn.Module):

@auto_fp16(apply_to=('x', 'y'), out_fp32=True)
def forward(self, x, y=None, z=None):
return x, y, z

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.half)
input_y = torch.ones(1, dtype=torch.float32)
input_z = torch.ones(1, dtype=torch.float32)
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.float32
assert output_z.dtype == torch.float32

model.fp16_enabled = True
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32
assert output_z.dtype == torch.float32

if torch.cuda.is_available():
model.cuda()
output_x, output_y, output_z = model(
input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32
assert output_z.dtype == torch.float32


def test_force_fp32():

with pytest.raises(TypeError):
# ExampleObject is not a subclass of nn.Module

class ExampleObject(object):

@force_fp32()
def __call__(self, x):
return x

model = ExampleObject()
input_x = torch.ones(1, dtype=torch.float32)
model(input_x)

# apply to all input args
class ExampleModule(nn.Module):

@force_fp32()
def forward(self, x, y):
return x, y

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.half)
input_y = torch.ones(1, dtype=torch.half)
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half

model.fp16_enabled = True
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32

if torch.cuda.is_available():
model.cuda()
output_x, output_y = model(input_x.cuda(), input_y.cuda())
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32

# apply to specified input args
class ExampleModule(nn.Module):

@force_fp32(apply_to=('x', ))
def forward(self, x, y):
return x, y

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.half)
input_y = torch.ones(1, dtype=torch.half)
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half

model.fp16_enabled = True
output_x, output_y = model(input_x, input_y)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.half

if torch.cuda.is_available():
model.cuda()
output_x, output_y = model(input_x.cuda(), input_y.cuda())
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.half

# apply to optional input args
class ExampleModule(nn.Module):

@force_fp32(apply_to=('x', 'y'))
def forward(self, x, y=None, z=None):
return x, y, z

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.half)
input_y = torch.ones(1, dtype=torch.half)
input_z = torch.ones(1, dtype=torch.half)
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half
assert output_z.dtype == torch.half

model.fp16_enabled = True
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32
assert output_z.dtype == torch.half

if torch.cuda.is_available():
model.cuda()
output_x, output_y, output_z = model(
input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.float32
assert output_z.dtype == torch.half

# out_fp16=True
class ExampleModule(nn.Module):

@force_fp32(apply_to=('x', 'y'), out_fp16=True)
def forward(self, x, y=None, z=None):
return x, y, z

model = ExampleModule()
input_x = torch.ones(1, dtype=torch.float32)
input_y = torch.ones(1, dtype=torch.half)
input_z = torch.ones(1, dtype=torch.half)
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.float32
assert output_y.dtype == torch.half
assert output_z.dtype == torch.half

model.fp16_enabled = True
output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half
assert output_z.dtype == torch.half

if torch.cuda.is_available():
model.cuda()
output_x, output_y, output_z = model(
input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
assert output_x.dtype == torch.half
assert output_y.dtype == torch.half
assert output_z.dtype == torch.half

0 comments on commit cb348b2

Please sign in to comment.