Skip to content

Commit c39f852

Browse files
committed
torch ndarray function backend
1 parent 10acc9e commit c39f852

18 files changed

+1179
-29
lines changed

Makefile

+35-11
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ endif
9393

9494
all: lib/libmxnet.a lib/libmxnet.so $(BIN)
9595

96-
SRC = $(wildcard src/*.cc src/*/*.cc)
97-
OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
98-
CUSRC = $(wildcard src/*/*.cu)
99-
CUOBJ = $(patsubst src/%.cu, build/%_gpu.o, $(CUSRC))
96+
SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
97+
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
98+
CUSRC = $(wildcard src/*/*.cu src/*/*/*.cu)
99+
CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC))
100100

101101
ifneq ($(EXTRA_OPERATORS),)
102102
EXTRA_SRC = $(wildcard $(EXTRA_OPERATORS)/*.cc $(EXTRA_OPERATORS)/*/*.cc)
@@ -110,10 +110,23 @@ else
110110
EXTRA_CUOBJ =
111111
endif
112112

113+
# plugin
114+
ifeq ($(USE_TORCH), 1)
115+
CFLAGS += -I$(TORCH_PATH)/install/include -I$(TORCH_PATH)/install/include/TH -I$(TORCH_PATH)/install/include/THC -DMXNET_USE_TORCH=1
116+
LDFLAGS += -Wl,-export-dynamic -L$(TORCH_PATH)/install/lib -L$(TORCH_PATH)/install/lib/lua/5.1 -lluajit -lluaT -lTH -lTHC -lpaths -ltorch -lcutorch -lnn -lcunn
117+
118+
TORCH_SRC = $(wildcard plugin/torch/*.cc)
119+
PLUGIN_OBJ += $(patsubst %.cc, build/%.o, $(TORCH_SRC))
120+
TORCH_CUSRC = $(wildcard plugin/torch/*.cu)
121+
PLUGIN_CUOBJ += $(patsubst %.cu, build/%_gpu.o, $(TORCH_CUSRC))
122+
else
123+
CFLAGS += -DMXNET_USE_TORCH=0
124+
endif
125+
113126
LIB_DEP += $(DMLC_CORE)/libdmlc.a
114-
ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(LIB_DEP)
127+
ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(PLUGIN_OBJ) $(LIB_DEP)
115128
ifeq ($(USE_CUDA), 1)
116-
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ)
129+
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) $(PLUGIN_CUOBJ)
117130
LDFLAGS += -lcuda
118131
endif
119132

@@ -125,16 +138,27 @@ else
125138
endif
126139

127140

128-
build/%.o: src/%.cc
141+
build/src/%.o: src/%.cc
142+
@mkdir -p $(@D)
143+
$(CXX) -std=c++0x $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d
144+
$(CXX) -std=c++0x -c $(CFLAGS) -c $< -o $@
145+
146+
build/src/%_gpu.o: src/%.cu
147+
@mkdir -p $(@D)
148+
$(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M -MT build/src/$*_gpu.o $< >build/src/$*_gpu.d
149+
$(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $<
150+
151+
build/plugin/%.o: plugin/%.cc
129152
@mkdir -p $(@D)
130-
$(CXX) -std=c++0x $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
153+
$(CXX) -std=c++0x $(CFLAGS) -MM -MT build/plugin/$*.o $< >build/plugin/$*.d
131154
$(CXX) -std=c++0x -c $(CFLAGS) -c $< -o $@
132155

133-
build/%_gpu.o: src/%.cu
156+
build/plugin/%_gpu.o: plugin/%.cu
134157
@mkdir -p $(@D)
135-
$(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M -MT build/$*_gpu.o $< >build/$*_gpu.d
158+
$(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M -MT build/plugin/$*_gpu.o $< >build/plugin/$*_gpu.d
136159
$(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $<
137160

161+
138162
$(EXTRA_OPERATORS)/build/%.o: $(EXTRA_OPERATORS)/%.cc
139163
@mkdir -p $(@D)
140164
$(CXX) -std=c++0x $(CFLAGS) -Isrc/operator -MM -MT $(EXTRA_OPERATORS)/build/$*.o $< >$(EXTRA_OPERATORS)/build/$*.d
@@ -173,7 +197,7 @@ include tests/cpp/unittest.mk
173197
test: $(TEST)
174198

175199
lint: rcpplint
176-
python2 dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src scripts python predict/python
200+
python2 dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src plugin scripts python predict/python
177201

178202
doc: doxygen
179203

include/mxnet/c_api.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,10 @@ MXNET_DLL int MXFuncDescribe(FunctionHandle fun,
382382
MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
383383
NDArrayHandle *use_vars,
384384
mx_float *scalar_args,
385-
NDArrayHandle *mutate_vars);
385+
NDArrayHandle *mutate_vars,
386+
int num_params,
387+
char **param_keys,
388+
char **param_vals);
386389

387390
//--------------------------------------------
388391
// Part 3: symbolic configuration generation

include/mxnet/ndarray.h

+34-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <dmlc/type_traits.h>
1313
#include <dmlc/registry.h>
1414
#include <vector>
15+
#include <map>
1516
#include <string>
1617
#include <memory>
1718
#include "./base.h"
@@ -446,7 +447,10 @@ MXNET_API void SampleGaussian(real_t mu, real_t sigma, NDArray *out);
446447
/*! \brief definition of NDArray function */
447448
typedef std::function<void (NDArray **used_vars,
448449
real_t *scalars,
449-
NDArray **mutate_vars)> NDArrayAPIFunction;
450+
NDArray **mutate_vars,
451+
int num_params,
452+
char **param_keys,
453+
char **param_vals)> NDArrayAPIFunction;
450454
/*! \brief mask information on how functions can be exposed */
451455
enum NDArrayFunctionTypeMask {
452456
/*! \brief all the use_vars should go before scalar */
@@ -491,7 +495,8 @@ struct NDArrayFunctionReg
491495
*/
492496
inline NDArrayFunctionReg &set_function(void (*fsetvalue)(const real_t &rhs,
493497
NDArray *out)) {
494-
body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) {
498+
body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
499+
int num_params, char **param_keys, char **param_vals) {
495500
(*fsetvalue)(s[0], mutate_vars[0]);
496501
};
497502
num_mutate_vars = 1; num_scalars = 1;
@@ -507,8 +512,8 @@ struct NDArrayFunctionReg
507512
inline NDArrayFunctionReg &set_function(void (*fbinary)(const NDArray &lhs,
508513
const NDArray &rhs,
509514
NDArray *out)) {
510-
body = [fbinary] (NDArray **used_vars,
511-
real_t *s, NDArray **mutate_vars) {
515+
body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
516+
int num_params, char **param_keys, char **param_vals) {
512517
(*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]);
513518
};
514519
num_use_vars = 2; num_mutate_vars = 1;
@@ -526,8 +531,8 @@ struct NDArrayFunctionReg
526531
inline NDArrayFunctionReg &set_function(void (*fscalar)(const NDArray &lhs,
527532
const real_t &rhs,
528533
NDArray *out)) {
529-
body = [fscalar] (NDArray **used_vars,
530-
real_t *s, NDArray **mutate_vars) {
534+
body = [fscalar] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
535+
int num_params, char **param_keys, char **param_vals) {
531536
(*fscalar)(*used_vars[0], s[0], mutate_vars[0]);
532537
};
533538
num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
@@ -544,15 +549,36 @@ struct NDArrayFunctionReg
544549
*/
545550
inline NDArrayFunctionReg &set_function(void (*funary)(const NDArray &src,
546551
NDArray *out)) {
547-
body = [funary] (NDArray **used_vars,
548-
real_t *s, NDArray **mutate_vars) {
552+
body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
553+
int num_params, char **param_keys, char **param_vals) {
549554
(*funary)(*used_vars[0], mutate_vars[0]);
550555
};
551556
num_use_vars = 1; num_mutate_vars = 1;
552557
type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget;
553558
this->add_argument("src", "NDArray", "Source input to the function.");
554559
return *this;
555560
}
561+
/*!
562+
* \brief set the function body to a unary NDArray function
563+
* this will also auto set the parameters correctly
564+
* \param funary function body to set
565+
* \return ref to the registered entry, used to set properties
566+
*/
567+
inline NDArrayFunctionReg &set_function(
568+
void (*fgeneric)(NDArray **used_vars,
569+
real_t *s,
570+
NDArray **mutate_vars,
571+
const std::map<std::string, std::string>& param)) {
572+
body = [fgeneric] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
573+
int num_params, char **param_keys, char **param_vals) {
574+
std::map<std::string, std::string> param;
575+
for (int i = 0; i < num_params; ++i) {
576+
param[param_keys[i]] = param_vals[i];
577+
}
578+
fgeneric(used_vars, s, mutate_vars, param);
579+
};
580+
return *this;
581+
}
556582
/*!
557583
* \brief set the number of mutate variables
558584
* \param n number of mutate variablesx

make/config.mk

+9
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,12 @@ USE_S3 = 0
105105

106106
# path to folders containing projects specific operators that you don't want to put in src/operators
107107
EXTRA_OPERATORS =
108+
109+
110+
#----------------------------
111+
# plugins
112+
#----------------------------
113+
114+
# whether to use torch integration. This requires installing torch.
115+
USE_TORCH = 0
116+
TORCH_PATH = $(HOME)/torch

make/osx.mk

+9
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,12 @@ USE_S3 = 0
9292

9393
# path to folders containing projects specific operators that you don't want to put in src/operators
9494
EXTRA_OPERATORS =
95+
96+
97+
#----------------------------
98+
# plugins
99+
#----------------------------
100+
101+
# whether to use torch integration. This requires installing torch.
102+
USE_TORCH = 0
103+
TORCH_PATH = $(HOME)/torch

plugin/torch/torch_base.cc

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file torch_base.cc
4+
* \brief torch_state
5+
* \author Junyuan Xie
6+
*/
7+
#include "./torch_base.h"
8+
9+
namespace mxnet {
10+
lua_State* TorchState::LuaState() {
11+
thread_local lua_State* state = NULL;
12+
if (!state) {
13+
state = luaL_newstate();
14+
luaL_openlibs(state);
15+
luaL_loadstring(state,
16+
"require 'torch'\n"
17+
"require 'nn'\n"
18+
#if MXNET_USE_CUDA
19+
"require 'cutorch'\n"
20+
"require 'cunn'\n"
21+
#if MXNET_USE_CUDNN
22+
"require 'cudnn'\n"
23+
#endif // MXNET_USE_CUDNN
24+
#endif // MXNET_USE_CUDA
25+
"local ss = require 'threads.sharedserialize'\n"
26+
"Serialize, Deserialize = ss.save, ss.load\n");
27+
int err = lua_pcall(state, 0, 0, 0);
28+
CHECK_EQ(err, 0) << lua_tostring(state, -1);
29+
}
30+
return state;
31+
}
32+
33+
template<>
34+
void TorchState::SetStream(mshadow::Stream<mshadow::cpu>* s) {
35+
return;
36+
}
37+
38+
#if MXNET_USE_CUDA
39+
template<>
40+
void TorchState::SetStream(mshadow::Stream<mshadow::gpu>* s) {
41+
TorchState::CudaState()->currentStream = mshadow::Stream<gpu>::GetStream(s);
42+
}
43+
#endif // MXNET_USE_CUDA
44+
} // namespace mxnet

0 commit comments

Comments
 (0)