Add flags to fix half comparison and test (pytorch#11395)
Summary:
It was reported that there are some issues when using comparison operators for half types when certain THC headers are included. I was able to reproduce this and added a test. I also fixed the issue by adding the proper definitions to avoid it.

Reported in pytorch#10301 (comment)
Related: pytorch/tutorials#292

soumith fmassa
Pull Request resolved: pytorch#11395

Differential Revision: D9725102

Pulled By: goldsborough

fbshipit-source-id: 630425829046bbebea3409bb792a9d62c91f41ad
goldsborough authored and facebook-github-bot committed Sep 10, 2018
1 parent 18e5fd3 commit 35008e0
Showing 5 changed files with 69 additions and 2 deletions.
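For context, here is a minimal sketch of the failure mode these flags guard against, using the same load_inline machinery as the new test below. It is a hypothetical repro, not part of the commit: it assumes a CUDA-enabled build that already contains this change, and it undoes the new define with -U__CUDA_NO_HALF_OPERATORS__ (the standard preprocessor "undefine" flag, applied after the -D that COMMON_NVCC_FLAGS prepends), bringing cuda_fp16.h's __half operators back into scope so the half comparison becomes ambiguous at compile time.

import torch
import torch.utils.cpp_extension

# Hypothetical repro source; mirrors the test added in this commit.
cuda_source = '''
#include <THC/THCNumerics.cuh>

__global__ void half_cmp_kernel(const at::Half* input, float* output) {
  // Ambiguous before this fix: at::Half converts implicitly both to float
  // and to __half, and cuda_fp16.h also defines operator< for __half.
  output[0] = (input[0] < input[1]) ? 1.f : 0.f;
}

at::Tensor half_cmp(at::Tensor input) {
  auto output = at::empty(1, input.options().dtype(at::kFloat));
  half_cmp_kernel<<<1, 1>>>(input.data<at::Half>(), output.data<float>());
  return output;
}
'''

# Expected to fail with an ambiguous-overload compiler error on toolchains
# where the original issue reproduces, because -U re-enables the __half
# operators that the new COMMON_NVCC_FLAGS defaults disable:
torch.utils.cpp_extension.load_inline(
    name='half_cmp_repro',
    cpp_sources='at::Tensor half_cmp(at::Tensor input);',
    cuda_sources=cuda_source,
    functions=['half_cmp'],
    extra_cuda_cflags=['-U__CUDA_NO_HALF_OPERATORS__'],
    verbose=True)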
1 change: 1 addition & 0 deletions .gitignore
@@ -35,6 +35,7 @@ test/data/legacy_modules.t7
test/data/legacy_serialized.pt
test/data/linear.pt
test/htmlcov
test/cpp_extensions/install/
third_party/build/
tools/shared/_utils_internal.py
torch.egg-info/
Empty file.
19 changes: 19 additions & 0 deletions test/cpp_extensions/half_support.cu
@@ -0,0 +1,19 @@
#include <torch/torch.h>

#include <THC/THCNumerics.cuh>

template <typename T, typename U>
__global__ void half_test_kernel(const T* input, U* output) {
  if (input[0] < input[1] || input[0] >= input[1]) {
    output[0] = 123;
  }
}

at::Tensor half_test(at::Tensor input) {
  auto output = at::empty(1, input.options().dtype(at::kFloat));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
    half_test_kernel<scalar_t>
        <<<1, 1>>>(input.data<scalar_t>(), output.data<float>());
  });
  return output;
}
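Note that the test added below embeds essentially this same source as a string, so it can be JIT-compiled with torch.utils.cpp_extension.load_inline.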
41 changes: 41 additions & 0 deletions test/test_cpp_extensions.py
@@ -274,6 +274,47 @@ def test_complex_registration(self):

        torch.empty(2, 2, dtype=torch.complex64)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_half_support(self):
        '''
        Checks for an issue with operator< ambiguity for half when certain
        THC headers are included.
        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
        for the corresponding issue.
        '''
        cuda_source = '''
        #include <THC/THCNumerics.cuh>

        template<typename T, typename U>
        __global__ void half_test_kernel(const T* input, U* output) {
          if (input[0] < input[1] || input[0] >= input[1]) {
            output[0] = 123;
          }
        }

        at::Tensor half_test(at::Tensor input) {
          auto output = at::empty(1, input.options().dtype(at::kFloat));
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
            half_test_kernel<scalar_t><<<1, 1>>>(
                input.data<scalar_t>(),
                output.data<float>());
          });
          return output;
        }
        '''

        module = torch.utils.cpp_extension.load_inline(
            name='half_test_extension',
            cpp_sources='at::Tensor half_test(at::Tensor input);',
            cuda_sources=cuda_source,
            functions=['half_test'],
            verbose=True)

        x = torch.randn(3, device='cuda', dtype=torch.half)
        result = module.half_test(x)
        self.assertEqual(result[0], 123)


if __name__ == '__main__':
    common.run_tests()
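Assuming the suite's existing test class name (not shown in this diff), the new test can be run on its own with:

python test/test_cpp_extensions.py TestCppExtension.test_half_support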
10 changes: 8 additions & 2 deletions torch/utils/cpp_extension.py
@@ -69,6 +69,12 @@ def _find_cuda_home():
# it the below pattern.
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

COMMON_NVCC_FLAGS = [
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
]


def is_binary_build():
    return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
@@ -165,7 +171,7 @@ def unix_wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
                 self.compiler.set_executable('compiler_so', nvcc)
                 if isinstance(cflags, dict):
                     cflags = cflags['nvcc']
-                cflags += ['--compiler-options', "'-fPIC'"]
+                cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + cflags
             elif isinstance(cflags, dict):
                 cflags = cflags['cxx']
             # NVCC does not allow multiple -std to be passed, so we avoid
@@ -831,7 +837,7 @@ def _write_ninja_file(path,
     flags = ['cflags = {}'.format(' '.join(cflags))]
 
     if with_cuda:
-        cuda_flags = common_cflags
+        cuda_flags = common_cflags + COMMON_NVCC_FLAGS
         if sys.platform == 'win32':
             cuda_flags = _nt_quote_args(cuda_flags)
         else:
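With these defaults in place, every CUDA extension build gets the defines automatically. On a PyTorch build that predates this commit, the same effect can be approximated by passing them explicitly through extra_cuda_cflags, an existing parameter of load/load_inline. A minimal sketch (module name hypothetical):

import torch.utils.cpp_extension

module = torch.utils.cpp_extension.load_inline(
    name='half_test_extension_manual',
    cpp_sources='at::Tensor half_test(at::Tensor input);',
    cuda_sources=cuda_source,  # e.g. the source string from the test above
    functions=['half_test'],
    extra_cuda_cflags=[
        '-D__CUDA_NO_HALF_OPERATORS__',
        '-D__CUDA_NO_HALF_CONVERSIONS__',
        '-D__CUDA_NO_HALF2_OPERATORS__',
    ],
    verbose=True)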
