Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-99] Upgrade to cuda 9.1 cudnn 7 #10108

Merged
merged 4 commits into from
Mar 21, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Address review comments
  • Loading branch information
marcoabreu committed Mar 21, 2018
commit b76809fb22add03b93922a2ad52cfd4f202e2caf
9 changes: 3 additions & 6 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,13 @@ def set_gradient_compression(self, compression_params):
Other keys in this dictionary are optional and specific to the type
of gradient compression.
"""
# pylint: disable=unsupported-membership-test
if ('device' in self.type) or ('dist' in self.type):
if ('device' in self.type) or ('dist' in self.type): # pylint: disable=unsupported-membership-test
ckeys, cvals = _ctype_dict(compression_params)
check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
mx_uint(len(compression_params)),
ckeys, cvals))
else:
raise Exception('Gradient compression is not supported for this type of kvstore')
# pylint: enable=unsupported-membership-test

def set_optimizer(self, optimizer):
""" Registers an optimizer with the kvstore.
Expand Down Expand Up @@ -467,8 +465,8 @@ def set_optimizer(self, optimizer):
is_worker = ctypes.c_int()
check_call(_LIB.MXKVStoreIsWorkerNode(ctypes.byref(is_worker)))

# pylint: disable=invalid-name,unsupported-membership-test
if 'dist' in self.type and is_worker.value:
# pylint: disable=invalid-name
if 'dist' in self.type and is_worker.value: # pylint: disable=unsupported-membership-test
# send the optimizer to server
try:
# use ASCII protocol 0, might be slower, but not a big ideal
Expand All @@ -478,7 +476,6 @@ def set_optimizer(self, optimizer):
self._send_command_to_servers(0, optim_str)
else:
self._set_updater(opt.get_updater(optimizer))
# pylint: enable=unsupported-membership-test

@property
def type(self):
Expand Down
4 changes: 1 addition & 3 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,9 +1683,7 @@ def shape(self):
pdata = ctypes.POINTER(mx_uint)()
check_call(_LIB.MXNDArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
# pylint: disable=invalid-slice-index
return tuple(pdata[:ndim.value])
# pylint: enable=invalid-slice-index
return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index


@property
Expand Down
4 changes: 1 addition & 3 deletions python/mxnet/symbol/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,12 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
else:
keys.append(k)
vals.append(v)"""%(func_name.lower()))
# pylint: disable=using-constant-test
if key_var_num_args:
if key_var_num_args: # pylint: disable=using-constant-test
code.append("""
if '%s' not in kwargs:
keys.append('%s')
vals.append(len(sym_args) + len(sym_kwargs))"""%(
key_var_num_args, key_var_num_args))
# pylint: enable=using-constant-test

code.append("""
return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
Expand Down
1 change: 0 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,6 @@ def __len__(self):
output_count = mx_uint()
check_call(_LIB.MXSymbolGetNumOutputs(self.handle, ctypes.byref(output_count)))
return output_count.value
# pylint: enable=invalid-length-returned

def list_auxiliary_states(self):
"""Lists all the auxiliary states in the symbol.
Expand Down
20 changes: 9 additions & 11 deletions python/mxnet/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,19 @@ def generic_torch_function(*args, **kwargs):
for k in kwargs:
kwargs[k] = str(kwargs[k])

# pylint: disable=invalid-slice-index
check_call(_LIB.MXFuncInvokeEx( \
handle, \
c_handle_array(ndargs[n_mutate_vars:]), \
c_array(mx_float, []), \
c_handle_array(ndargs[:n_mutate_vars]),
ctypes.c_int(len(kwargs)),
c_str_array(kwargs.keys()),
c_str_array(kwargs.values())))
check_call(_LIB.MXFuncInvokeEx(
handle,
c_handle_array(ndargs[n_mutate_vars:]), # pylint: disable=invalid-slice-index
c_array(mx_float, []),
c_handle_array(ndargs[:n_mutate_vars]), # pylint: disable=invalid-slice-index
ctypes.c_int(len(kwargs)),
c_str_array(kwargs.keys()),
c_str_array(kwargs.values())))

if n_mutate_vars == 1:
return ndargs[0]
else:
return ndargs[:n_mutate_vars]
# pylint: enable=invalid-slice-index
return ndargs[:n_mutate_vars] # pylint: disable=invalid-slice-index

# End of function declaration
ret_function = generic_torch_function
Expand Down