Guard the boundary of index computed in compute_source_index_and_lambda (pytorch#89252)

Improve the fix in pytorch#89210
See discussion in pytorch#89212 (comment)
Pull Request resolved: pytorch#89252
Approved by: https://github.com/mingfeima, https://github.com/weiwangmeta
Jiong Gong authored and pytorchmergebot committed Nov 29, 2022
1 parent 9377230 commit 620994c
Showing 2 changed files with 69 additions and 2 deletions.
10 changes: 8 additions & 2 deletions aten/src/ATen/native/UpSample.h
@@ -449,10 +449,16 @@ static inline void compute_source_index_and_lambda(
     const auto real_input_index =
         area_pixel_compute_source_index<opmath_t>(
             ratio, output_index, align_corners, /*cubic=*/false);
-    input_index0 = static_cast<int64_t>(real_input_index);
+    // When `real_input_index` exceeds the range the floating point type can
+    // represent exactly, casting it to `int64_t` may yield a value beyond
+    // `input_size - 1`, causing a buffer overflow; guard it with `std::min`.
+    input_index0 = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
     int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
     input_index1 = input_index0 + offset;
-    lambda1 = real_input_index - input_index0;
+    lambda1 = std::min(
+        std::max(real_input_index - input_index0, static_cast<opmath_t>(0)),
+        static_cast<opmath_t>(1)
+    );
     lambda0 = static_cast<scalar_t>(1.) - lambda1;
   }
 }
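To illustrate why the `std::min` guard above is needed, here is a small Python sketch (illustrative only, not part of this commit; the `input_size` value is just an example matching the new test): once an index passes the float32 exactly-representable integer range, the nearest representable value can round up to `input_size`, one past the last valid index.

import torch

# Illustrative only: float32 cannot represent every integer above 2**24,
# so a source index near the end of a large input can round up past the bound.
input_size = 2**24 + 4                                # 16777220
real_input_index = torch.tensor(input_size - 1, dtype=torch.float32)
print(int(real_input_index.item()))                   # 16777220, not 16777219
print(int(real_input_index.item()) > input_size - 1)  # True -> clamped by std::min in the fix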
61 changes: 61 additions & 0 deletions test/test_nn.py
@@ -6903,6 +6903,67 @@ def test_interpolate_illegal_memory_access(self):
         self.assertEqual(out_ref, out)
         self.assertEqual(input_ref.grad, input.grad)
 
+    def test_interpolate_buffer_overflow(self):
+        # Test the buffer overflow issue caused by inaccurate floating point
+        # representation of large integer values. See the issue below for details.
+        # https://github.com/pytorch/pytorch/issues/88939
+
+        def helper(size, dtype, mode, device, is_channels_last):
+            input = torch.ones(size, dtype=dtype, device=device)
+            if is_channels_last:
+                if len(size) == 3:
+                    input = input.transpose(1, 2).contiguous().transpose(1, 2)
+                elif len(size) == 4:
+                    input = input.to(memory_format=torch.channels_last)
+                else:
+                    input = input.to(memory_format=torch.channels_last_3d)
+            output1 = F.interpolate(input, 2, mode=mode, align_corners=True)
+            # Reset the corner value and expect the output to change as well;
+            # on buffer overflow the output would stay the same.
+            input[(-1,) * len(size)] = 0.5
+            output2 = F.interpolate(input, 2, mode=mode, align_corners=True)
+            self.assertNotEqual(output1, output2)
+
+        size_dtype_list = []
+        # Set the size beyond the range of integers the floating point type can represent exactly.
+        # float: exactly representable integer range (-2**24, 2**24)
+        size_dtype_list.append(([1, 10, 2**24 + 4], torch.float))
+        size_dtype_list.append(([1, 10, 2, 2**24 + 4], torch.float))
+        size_dtype_list.append(([1, 10, 2, 2, 2**24 + 4], torch.float))
+        # bfloat16: exactly representable integer range (-2**8, 2**8)
+        size_dtype_list.append(([1, 10, 2**8 + 4], torch.bfloat16))
+        size_dtype_list.append(([1, 10, 2, 2**8 + 4], torch.bfloat16))
+        size_dtype_list.append(([1, 10, 2, 2, 2**8 + 4], torch.bfloat16))
+        # half: exactly representable integer range (-2**11, 2**11)
+        size_dtype_list.append(([1, 10, 2**11 + 4], torch.half))
+        size_dtype_list.append(([1, 10, 2, 2**11 + 4], torch.half))
+        size_dtype_list.append(([1, 10, 2, 2, 2**11 + 4], torch.half))
+
+        # TODO: turn on the cuda test after the buffer overflow issue is fixed in the cuda kernel
+        # devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
+        devices = ['cpu']
+
+        for mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
+            for size_dtype in size_dtype_list:
+                size, dtype = size_dtype
+                if (
+                    (mode == 'linear' and len(size) != 3)
+                    or (mode == 'bilinear' and len(size) != 4)
+                    or (mode == 'bicubic' and len(size) != 4)
+                    or (mode == 'trilinear' and len(size) != 5)
+                ):
+                    continue
+                for device in devices:
+                    if (
+                        (device == 'cpu' and dtype == torch.half)
+                        or (device == 'cuda' and dtype == torch.bfloat16)
+                    ):
+                        # no half precision support on cpu and no bfloat16 support on cuda yet
+                        continue
+                    for is_channels_last in (True, False):
+                        helper(size, dtype, mode, device, is_channels_last)
+
+
     def test_interpolate(self):
         def _test_interpolate_helper(in_t, scale_factor, layer):
             out_size = int(math.floor(in_t.shape[-1] * scale_factor))
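For reference, the "exactly representable range" comments in the test above can be checked with a short sketch (illustrative only, not part of this commit; `precision_bits` counts the significand bits, including the implicit leading bit):

import torch

# Illustrative only: 2**precision_bits is stored exactly for each dtype,
# while 2**precision_bits + 1 already rounds to a neighbouring value.
for dtype, precision_bits in ((torch.float, 24), (torch.bfloat16, 8), (torch.half, 11)):
    exact = torch.tensor(2 ** precision_bits, dtype=dtype).item()
    inexact = torch.tensor(2 ** precision_bits + 1, dtype=dtype).item()
    print(dtype, exact == 2 ** precision_bits, inexact == 2 ** precision_bits + 1)
    # expected: True, False for every dtype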
