
Commit 25580a3

torch related fix (apache#2016)
* we actually don't need to link against libnn
* torch module fix
* adam optimizer fix to match torch
* lint
* add save/load params
* lint
* lint
1 parent 1afccc6 commit 25580a3

File tree: 8 files changed (+64, -52 lines)


make/config.mk (+1)
@@ -112,6 +112,7 @@ EXTRA_OPERATORS =
 #----------------------------
 
 # whether to use torch integration. This requires installing torch.
+# You also need to add TORCH_PATH/install/lib to your LD_LIBRARY_PATH
 # TORCH_PATH = $(HOME)/torch
 # MXNET_PLUGINS += plugin/torch/torch.mk
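A quick way to sanity-check the LD_LIBRARY_PATH requirement from Python before importing mxnet with the Torch plugin enabled (a hypothetical helper, not part of the commit; the path is illustrative and should match your TORCH_PATH):

    import os

    torch_lib = os.path.expanduser("~/torch/install/lib")  # adjust to your TORCH_PATH
    ld_paths = os.environ.get("LD_LIBRARY_PATH", "").split(":")
    if torch_lib not in ld_paths:
        print("Warning: %s is not on LD_LIBRARY_PATH; libluajit.so may fail to load." % torch_lib)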

plugin/torch/torch.mk (+2, -17)
@@ -1,22 +1,7 @@
 CFLAGS += -I$(TORCH_PATH)/install/include -I$(TORCH_PATH)/install/include/TH -I$(TORCH_PATH)/install/include/THC/ -DMXNET_USE_TORCH=1
-LDFLAGS += -L$(TORCH_PATH)/install/lib -lluajit -lluaT -lTH -lTHC -L$(TORCH_PATH)/install/lib/lua/5.1 -lpaths -ltorch
-
-ifneq ("$(wildcard $(TORCH_PATH)/install/lib/lua/5.1/libnn.so)","")
-LDFLAGS += -lnn
-else
-LDFLAGS += -lnnx
-endif
-
-ifeq ($(USE_CUDA), 1)
-LDFLAGS += -lcutorch
-ifneq ("$(wildcard $(TORCH_PATH)/install/lib/lua/5.1/libcunn.so)","")
-LDFLAGS += -lcunn
-else
-LDFLAGS += -lcunnx
-endif
-endif
+LDFLAGS += -L$(TORCH_PATH)/install/lib -lluajit -lluaT -lTH -lTHC
 
 TORCH_SRC = $(wildcard plugin/torch/*.cc)
 PLUGIN_OBJ += $(patsubst %.cc, build/%.o, $(TORCH_SRC))
 TORCH_CUSRC = $(wildcard plugin/torch/*.cu)
-PLUGIN_CUOBJ += $(patsubst %.cu, build/%_gpu.o, $(TORCH_CUSRC))
+PLUGIN_CUOBJ += $(patsubst %.cu, build/%_gpu.o, $(TORCH_CUSRC))

plugin/torch/torch_module-inl.h (+9, -5)
@@ -310,12 +310,16 @@ class TorchModuleProp : public OperatorProperty {
   }
 
   virtual std::vector<std::string> ListOutputs() const {
-    std::vector<std::string> ret;
-    std::string output = "output";
-    for (uint32_t i = 0; i < param_.num_outputs; ++i) {
-      ret.push_back(output + "_" + std::to_string(i));
+    if (param_.num_outputs > 1) {
+      std::vector<std::string> ret;
+      std::string output = "output";
+      for (uint32_t i = 0; i < param_.num_outputs; ++i) {
+        ret.push_back(output + std::to_string(i));
+      }
+      return ret;
+    } else {
+      return {"output"};
     }
-    return ret;
   }
   void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
     param_.Init(kwargs);
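After this change a single-output module keeps the plain name "output", while multi-output modules are numbered "output0", "output1", ... (the underscore is dropped). A small Python sketch of the same naming rule (a hypothetical helper mirroring the C++ above, not part of the commit):

    def list_outputs(num_outputs):
        # single output keeps the plain name; multiple outputs are numbered
        if num_outputs > 1:
            return ["output" + str(i) for i in range(num_outputs)]
        return ["output"]

    print(list_outputs(1))  # ['output']
    print(list_outputs(3))  # ['output0', 'output1', 'output2']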

python/mxnet/io.py (+1, -1)
@@ -20,7 +20,7 @@
 
 class DataBatch(object):
     """Default object for holding a mini-batch of data and related information."""
-    def __init__(self, data, label, pad, index,
+    def __init__(self, data, label, pad=None, index=None,
                  bucket_key=None, provide_data=None, provide_label=None):
         self.data = data
         self.label = label
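With pad and index now defaulting to None, a DataBatch can be constructed from data and label alone. A minimal sketch (shapes are illustrative; assumes an mxnet build containing this change):

    import mxnet as mx

    data = [mx.nd.ones((2, 3))]
    label = [mx.nd.zeros((2,))]
    batch = mx.io.DataBatch(data=data, label=label)  # pad and index default to None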

python/mxnet/module/base_module.py (+34)
@@ -461,6 +461,40 @@ def set_params(self, arg_params, aux_params):
         self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
                          allow_missing=False, force_init=True)
 
+    def save_params(self, fname):
+        """Save model parameters to file.
+
+        Parameters
+        ----------
+        fname : str
+            Path to output param file.
+        """
+        arg_params, aux_params = self.get_params()
+        save_dict = {('arg:%s' % k) : v for k, v in arg_params.items()}
+        save_dict.update({('aux:%s' % k) : v for k, v in aux_params.items()})
+        ndarray.save(fname, save_dict)
+
+    def load_params(self, fname):
+        """Load model parameters from file.
+
+        Parameters
+        ----------
+        fname : str
+            Path to input param file.
+        """
+        save_dict = ndarray.load(fname)
+        arg_params = {}
+        aux_params = {}
+        for k, value in save_dict.items():
+            arg_type, name = k.split(':', 1)
+            if arg_type == 'arg':
+                arg_params[name] = value
+            elif arg_type == 'aux':
+                aux_params[name] = value
+            else:
+                raise ValueError("Invalid param file " + fname)
+        self.set_params(arg_params, aux_params)
+
 ################################################################################
 # Computations
 ################################################################################
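The new methods round-trip parameters through an NDArray file whose keys carry "arg:"/"aux:" prefixes. A minimal usage sketch, assuming mod is a Module that has already been bound and had its parameters initialized (the file name is illustrative):

    mod.save_params('model.params')   # writes arg:*/aux:* entries via ndarray.save
    # ... later, on a module built from the same symbol ...
    mod.load_params('model.params')   # restores them through set_params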

python/mxnet/optimizer.py (+13, -28)
@@ -483,15 +483,13 @@ class Adam(Optimizer):
     clip_gradient : float, optional
         clip gradient in range [-clip_gradient, clip_gradient]
     """
-    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8,
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                  decay_factor=(1 - 1e-8), **kwargs):
         super(Adam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
         self.decay_factor = decay_factor
-        self.time = 0
-        self.time_first_index = None
 
     def create_state(self, index, weight):
         """Create additional optimizer state: mean, variance

@@ -502,7 +500,6 @@ def create_state(self, index, weight):
             The weight data
 
         """
-        self.time_first_index = None  # time is incremented only on the first index
         return (zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
                 zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance
 

@@ -528,37 +525,25 @@ def update(self, index, weight, grad, state):
         lr = self._get_lr(index)
         self._update_count(index)
 
+        t = self._index_update_count[index]
         mean, variance = state
 
-        # increment time only when the first parameters is called
-        if self.time_first_index is None:
-            self.time_first_index = index
-            self.time = 0  # all parameters share the same time
-        elif self.time_first_index == index:
-            self.time += 1
+        grad *= self.rescale_grad
+        if self.clip_gradient is not None:
+            clip(grad, -self.clip_gradient, self.clip_gradient, out=grad)
 
-        t1 = self.time + 1
-        learning_rate = (lr *
-                         math.sqrt(1. - self.beta2**t1) /
-                         (1. - self.beta1**t1))
-        beta_1t = self.beta1 * self.decay_factor ** (t1 - 1)
+        mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
+        variance[:] = self.beta2 * variance + (1. - self.beta2) * grad * grad
 
-        grad = grad * self.rescale_grad
-        if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+        coef1 = 1. - self.beta1**t
+        coef2 = 1. - self.beta2**t
+        lr *= math.sqrt(coef2)/coef1
+
+        weight[:] -= lr*mean/(sqrt(variance) + self.epsilon)
 
-        mean_t = beta_1t * mean + (1. - beta_1t) * grad
-        variance_t = (self.beta2 * variance +
-                      (1. - self.beta2) * grad * grad)
-        step = (learning_rate * mean_t /
-                (sqrt(variance_t) + self.epsilon))
         wd = self._get_wd(index)
         if wd > 0.:
-            step += lr * wd * weight
-
-        weight[:] += -step
-        mean[:] = mean_t
-        variance[:] = variance_t
+            weight[:] -= (lr * wd) * weight
 
 @register
 class AdaGrad(Optimizer):
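The rewritten update follows Torch's Adam: the per-index update count drives the bias correction, and weight decay is applied as a separate multiplicative term after the Adam step. A standalone NumPy sketch of one step under the same rule (a re-implementation for illustration only, not the MXNet code; gradient rescaling and clipping are omitted):

    import numpy as np

    def adam_step(weight, grad, mean, variance, t,
                  lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, wd=0.0):
        # update the running first and second moments in place
        mean[:] = beta1 * mean + (1. - beta1) * grad
        variance[:] = beta2 * variance + (1. - beta2) * grad * grad
        # bias-corrected learning rate for step t (t starts at 1)
        lr_t = lr * np.sqrt(1. - beta2 ** t) / (1. - beta1 ** t)
        weight[:] -= lr_t * mean / (np.sqrt(variance) + epsilon)
        if wd > 0.:
            weight[:] -= (lr_t * wd) * weight

    w = np.ones(3)
    m, v = np.zeros(3), np.zeros(3)
    adam_step(w, np.full(3, 0.1), m, v, t=1)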

python/mxnet/symbol.py (+3)
@@ -98,6 +98,9 @@ def __pow__(self, other):
         else:
             raise TypeError('type %s not supported' % str(type(other)))
 
+    def __neg__(self):
+        return self.__mul__(-1.0)
+
     def __del__(self):
         check_call(_LIB.MXSymbolFree(self.handle))
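With __neg__ defined, negating a symbol becomes sugar for multiplying it by -1.0 (assumes an mxnet build containing this change):

    import mxnet as mx

    a = mx.symbol.Variable('a')
    b = -a   # equivalent to a * -1.0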

python/mxnet/torch.py (+1, -1)
@@ -14,7 +14,7 @@
 try:
     _LUAJIT = ctypes.CDLL("libluajit.so", mode=ctypes.RTLD_GLOBAL)
 except OSError:
-    pass
+    _LUAJIT = None
 
 # pylint: disable=too-many-locals, invalid-name
 def _make_torch_function(handle):
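Binding _LUAJIT to None instead of silently passing keeps the name defined, so callers can test for LuaJIT availability rather than hitting a NameError. The same guarded-load pattern in isolation (a standalone sketch, not the plugin module itself):

    import ctypes

    try:
        _LUAJIT = ctypes.CDLL("libluajit.so", mode=ctypes.RTLD_GLOBAL)
    except OSError:
        _LUAJIT = None  # LuaJIT not installed; Torch-backed functions stay unavailable

    if _LUAJIT is None:
        print("libluajit.so not found; skipping Torch integration.")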
