Skip to content

Commit

Permalink
Add C++ bindings for cuDNN (pytorch#167)
Browse files Browse the repository at this point in the history
The Python ctypes bindings overhead was high enough that it slowed down
multi-gpu training when using 4+ Maxwell GPUs.
  • Loading branch information
colesbury authored and soumith committed Oct 26, 2016
1 parent 30924ff commit ad2d413
Show file tree
Hide file tree
Showing 24 changed files with 1,112 additions and 160 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
Expand Down
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

# TODO: make this more robust
WITH_CUDA = os.path.exists('/Developer/NVIDIA/CUDA-7.5/include') or os.path.exists('/usr/local/cuda/include')
WITH_CUDNN = WITH_CUDA
DEBUG = False

################################################################################
Expand Down Expand Up @@ -81,10 +82,15 @@ def run(self):
from tools.cwrap.plugins.AutoGPU import AutoGPU
from tools.cwrap.plugins.BoolOption import BoolOption
from tools.cwrap.plugins.KwargsPlugin import KwargsPlugin
from tools.cwrap.plugins.NullableArguments import NullableArguments
from tools.cwrap.plugins.CuDNNPlugin import CuDNNPlugin
cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
AutoGPU(condition='IS_CUDA'), THPLongArgsPlugin(), BoolOption(),
THPPlugin(), ArgcountSortPlugin(), KwargsPlugin(),
])
cwrap('torch/csrc/cudnn/cuDNN.cwrap', plugins=[
CuDNNPlugin(), NullableArguments()
])
# It's an old-style class in Python 2.7...
setuptools.command.build_ext.build_ext.run(self)

Expand Down Expand Up @@ -192,6 +198,18 @@ def run(self):
"torch/csrc/cuda/serialization.cpp",
]

if WITH_CUDNN:
main_libraries += ['cudnn']
main_sources += [
"torch/csrc/cudnn/Module.cpp",
"torch/csrc/cudnn/Conv.cpp",
"torch/csrc/cudnn/cuDNN.cpp",
"torch/csrc/cudnn/Types.cpp",
"torch/csrc/cudnn/Handles.cpp",
"torch/csrc/cudnn/CppWrapper.cpp",
]
extra_compile_args += ['-DWITH_CUDNN']

if DEBUG:
extra_compile_args += ['-O0', '-g']
extra_link_args += ['-O0', '-g']
Expand Down
15 changes: 13 additions & 2 deletions tools/cwrap/cwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ def get_wrapper_template(self, declaration):
return self.search_plugins('get_wrapper_template', (declaration,), lambda _: None)

def get_arg_accessor(self, arg, option):
return self.search_plugins('get_arg_accessor', (arg, option), lambda arg,_: 'PyTuple_GET_ITEM(args, {})'.format(arg['idx']))
def wrap_accessor(arg, _):
if arg.get('idx') is None:
raise RuntimeError("Missing accessor for '{} {}'".format(
arg['type'], arg['name']))
return 'PyTuple_GET_ITEM(args, {})'.format(arg['idx'])

return self.search_plugins('get_arg_accessor', (arg, option), wrap_accessor)

def generate_wrapper(self, declaration):
wrapper = ''
Expand All @@ -153,7 +159,12 @@ def map_selected_arguments(self, base_fn_name, plugin_fn_name, option, arguments
result = []
for arg in arguments:
accessor = self.get_arg_accessor(arg, option)
res = getattr(self, base_fn_name)(arg, option).substitute(arg=accessor)
tmpl = getattr(self, base_fn_name)(arg, option)
if tmpl is None:
fn = 'check' if base_fn_name == 'get_type_check' else 'unpack'
raise RuntimeError("Missing type {} for '{} {}'".format(
fn, arg['type'], arg['name']))
res = tmpl.substitute(arg=accessor)
for plugin in self.plugins:
res = getattr(plugin, plugin_fn_name)(res, arg, accessor)
result.append(res)
Expand Down
159 changes: 159 additions & 0 deletions tools/cwrap/plugins/CuDNNPlugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from string import Template
from copy import deepcopy
from . import CWrapPlugin
from itertools import product

class CuDNNPlugin(CWrapPlugin):
    """cwrap plugin that generates Python C-API bindings for the cuDNN C++ wrappers.

    For each declaration in a .cwrap file it emits a METH_VARARGS|METH_KEYWORDS
    wrapper function, plus a PyMethodDef table exposed through THCUDNN_methods().
    """

    # C expressions that convert a PyObject* ($arg) into the C value the
    # wrapped function expects, keyed by cwrap argument type.
    TYPE_UNPACK = {
        'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'),
        'int': Template('THPUtils_unpackLong($arg)'),
        # dataType/handle arguments are synthesized by get_arg_accessor (not
        # parsed from the Python args tuple), so they pass through unchanged.
        'cudnnDataType_t': Template('$arg'),
        'cudnnHandle_t': Template('$arg'),
        'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
        'bool': Template('$arg == Py_True'),
    }

    # Boolean C expressions used for overload (option) resolution on $arg.
    TYPE_CHECK = {
        'Convolution*': Template('THPWrapper_check($arg)'),
        'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
        'int': Template('THPUtils_checkLong($arg)'),
        'bool': Template('PyBool_Check($arg)'),
    }

    # How a C return value ($result) is converted back into a PyObject*.
    # The Convolution* case hands ownership to a THPWrapper with a deleter.
    RETURN_WRAPPER = {
        'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
        'THTensor*': Template('return THPTensor_(New)($result);'),
    }

    # Module-level method table emitted once at the end of the generated file.
    METHODS_DECLARATION = Template("""
static PyMethodDef _THCUDNN_methods[] = {
$methods
{NULL}
};
PyMethodDef* THCUDNN_methods()
{
return _THCUDNN_methods;
}
""")

    # Skeleton of one generated wrapper function. $options expands to the
    # per-overload argument-check/dispatch code; falling through all options
    # raises the invalid-arguments error.
    WRAPPER_TEMPLATE = Template("""\
static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
int __tuplecount = args ? PyTuple_Size(args) : 0;
int __dictcount = kwargs ? PyDict_Size(kwargs) : 0;
int __argcount = __tuplecount + __dictcount;
PyObject* tensorClass = getTensorClass(args);
THCPAutoGPU __autogpu_guard = THCPAutoGPU(args);
$options
}
THPUtils_invalidArguments(args, "$readable_name", $num_options, $expected_args);
return NULL;
END_HANDLE_TH_ERRORS
}
""")

    RELEASE_ARG = Template("_${name}_guard.release();")

    # Human-readable type names used in the invalid-arguments message.
    TYPE_NAMES = {
        'THTensor*': '" THPTensorStr "',
        'long': 'int',
        'bool': 'bool',
        'int': 'int',
    }

    def __init__(self):
        # Declarations seen by process_declarations; consumed later by
        # declare_methods() to build the PyMethodDef table.
        self.declarations = []

    def get_type_unpack(self, arg, option):
        """Return the unpack template for arg's type, or None to defer to other plugins."""
        return self.TYPE_UNPACK.get(arg['type'], None)

    def get_type_check(self, arg, option):
        """Return the type-check template for arg's type, or None to defer to other plugins."""
        return self.TYPE_CHECK.get(arg['type'], None)

    def get_wrapper_template(self, declaration):
        """Build the wrapper-function template with the error-message pieces filled in."""
        arg_desc = []
        for option in declaration['options']:
            option_desc = [self.TYPE_NAMES.get(arg['type'], arg['type']) + ' ' + arg['name']
                           for arg in option['arguments']
                           if not arg.get('ignore_check', False)]
            # TODO: this should probably go to THPLongArgsPlugin
            if option_desc:
                arg_desc.append('({})'.format(', '.join(option_desc)))
            else:
                arg_desc.append('no arguments')
        # Shortest signatures first in the error message.
        arg_desc.sort(key=len)
        arg_desc = ['"' + desc + '"' for desc in arg_desc]
        arg_str = ', '.join(arg_desc)
        readable_name = declaration['python_name']
        # safe_substitute leaves $options in place for the later expansion pass.
        return Template(self.WRAPPER_TEMPLATE.safe_substitute(
            readable_name=readable_name, num_options=len(arg_desc),
            expected_args=arg_str))

    def get_return_wrapper(self, option):
        """Return the return-value wrapper template, or None to defer to other plugins."""
        return self.RETURN_WRAPPER.get(option['return'], None)

    def get_arg_accessor(self, arg, option):
        """Return the C expression that yields arg's value.

        self/dataType/handle are synthesized rather than read from the args
        tuple; any other name implicitly returns None so cwrap falls back to
        the default PyTuple_GET_ITEM accessor.
        """
        name = arg['name']
        if name == 'self':
            return 'self'
        elif name == 'dataType':
            # Derived from the tensor class captured in the wrapper prologue.
            return 'getCudnnDataType(tensorClass)'
        elif name == 'handle':
            return 'getCudnnHandle()'

    def process_declarations(self, declarations):
        """Rename declarations, mark implicit args as unchecked, and dedupe overloads."""
        for declaration in declarations:
            # Python-visible name gets a leading underscore; the C symbol a
            # THCUDNN_ prefix.
            declaration.setdefault('python_name', '_{}'.format(declaration['name']))
            declaration['name'] = 'THCUDNN_{}'.format(declaration['name'])
            self.declarations.append(declaration)
            for option in declaration['options']:
                for arg in option['arguments']:
                    # These are filled in by get_arg_accessor, never parsed
                    # from Python args, so they must not count toward overload
                    # resolution or argcount checks.
                    if arg['name'] in ['self', 'state', 'dataType', 'handle']:
                        arg['ignore_check'] = True
            declaration['options'] = self.filter_unique_options(declaration['options'])
        return declarations

    def filter_unique_options(self, options):
        """Drop overloads whose Python-visible argument types duplicate an earlier one."""
        def signature(option):
            # Signature over checked args only; ignored args are invisible to Python.
            return '#'.join(arg['type'] for arg in option['arguments'] if not 'ignore_check' in arg or not arg['ignore_check'])
        seen_signatures = set()
        unique = []
        for option in options:
            sig = signature(option)
            if sig not in seen_signatures:
                unique.append(option)
                seen_signatures.add(sig)
        return unique

    def preprocessor_guard(self, code, condition):
        """Wrap generated code in an #if <condition> ... #endif guard."""
        return '#if ' + condition + '\n' + code + '#endif\n'

    def process_wrapper(self, code, declaration):
        """Apply the declaration's defined_if preprocessor guard, if any."""
        if 'defined_if' in declaration:
            return self.preprocessor_guard(code, declaration['defined_if'])
        return code

    def process_all_unpacks(self, code, option):
        """Prepend the implicit THCState* argument to every generated call."""
        return 'state, ' + code

    def declare_methods(self):
        """Emit the PyMethodDef table covering every processed declaration."""
        methods = ''
        for declaration in self.declarations:
            extra_flags = ' | ' + declaration.get('method_flags') if 'method_flags' in declaration else ''
            if not declaration.get('only_register'):
                extra_flags += ' | METH_KEYWORDS'
            entry = Template('  {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n').substitute(
                python_name=declaration['python_name'], name=declaration['name'], extra_flags=extra_flags
            )
            if 'defined_if' in declaration:
                entry = self.preprocessor_guard(entry, declaration['defined_if'])
            methods += entry
        return self.METHODS_DECLARATION.substitute(methods=methods)

    def process_full_file(self, code):
        """Append the method table after all generated wrappers."""
        return code + self.declare_methods()
3 changes: 3 additions & 0 deletions tools/cwrap/plugins/THPPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class THPPlugin(CWrapPlugin):
'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
'THStorage*': Template('((THPStorage*)$arg)->cdata'),
'THGenerator*': Template('((THPGenerator*)$arg)->cdata'),
'THSize*': Template('THPUtils_unpackSize($arg)'),
'void*': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'),
'int': Template('THPUtils_unpackLong($arg)'),
Expand All @@ -38,6 +39,7 @@ class THPPlugin(CWrapPlugin):
'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
'THSize*': Template('(PyObject*)Py_TYPE($arg) == THPSizeClass'),
'void*': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'),
'int': Template('THPUtils_checkLong($arg)'),
Expand Down Expand Up @@ -152,6 +154,7 @@ class THPPlugin(CWrapPlugin):
'THIndexTensor*': '" THPModuleStr "LongTensor',
'THFloatTensor*': '" THPModuleStr "FloatTensor',
'THDoubleTensor*': '" THPModuleStr "DoubleTensor',
'THSize*': 'torch.Size',
'long': 'int',
'real': '" RealStr "',
'double': 'float',
Expand Down
1 change: 1 addition & 0 deletions tools/cwrap/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ def process_option_code_template(self, template, option):
from .ReturnArguments import ReturnArguments
from .GILRelease import GILRelease
from .AutoGPU import AutoGPU
from .CuDNNPlugin import CuDNNPlugin
Loading

0 comments on commit ad2d413

Please sign in to comment.