Change bsum settings for neon conv layer unit tests
fix loader install location for sysinstall
apark263 committed May 24, 2016
1 parent fbb5033 commit 4792199
Showing 7 changed files with 37 additions and 28 deletions.
File renamed without changes.
41 changes: 25 additions & 16 deletions neon/backends/layer_cpu.py
@@ -117,7 +117,7 @@ def fprop_slice(self, q, S, X, padding, strides):
     def bprop_slice(self, x, S, Q, padding, strides):
         qs = x - (S - padding - 1)
         firstF = None
-        for s in range(S): #TODO remove loop logic here.
+        for s in range(S):  # TODO remove loop logic here.
             q = qs + s
             if q % strides == 0:
                 q //= strides
@@ -128,8 +128,8 @@ def bprop_slice(self, x, S, Q, padding, strides):
                     lastF = s
                     lastE = q
         if firstF is None:
-            return (slice(0,0,1), slice(0,0,1), 0)
-        return (slice(firstF,lastF+1,strides), slice(firstE,lastE+1,1), 0)
+            return (slice(0, 0, 1), slice(0, 0, 1), 0)
+        return (slice(firstF, lastF+1, strides), slice(firstE, lastE+1, 1), 0)
 
     def compound_ops(self, O, X, bias, bsum, relu, brelu, slope):
         if bias is not None:
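Note: bprop_slice inverts the forward coordinate mapping. For an input position x it finds the filter taps s and the output-gradient positions q that touched x during fprop, so the deltas can be gathered with one sliced dot product. A quick worked example of the logic above, treating the method as a free function (hypothetical, hand-picked values):

    # S=3 filter taps, Q=4 outputs, padding=1, stride=1, input position x=0:
    #   qs = 0 - (3 - 1 - 1) = -1
    #   s=0 -> q=-1 (out of range); s=1 -> q=0; s=2 -> q=1
    fslice, eslice, _ = bprop_slice(x=0, S=3, Q=4, padding=1, strides=1)
    assert fslice == slice(1, 3, 1)  # filter taps 1..2 contribute
    assert eslice == slice(0, 2, 1)  # paired with error positions 0..1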
@@ -189,8 +189,8 @@ def xprop_conv(self, I, F, O, X=None, bias=None, bsum=None, alpha=1.0, beta=0.0,
             return
 
         if backward:
-            # C <=> K and mirror T,R,S (0,1,2,3,4) => (4,1,2,3,0)
-            F = np.transpose(F[:,::-1,::-1,::-1,:], (4,1,2,3,0)).copy()
+            # C <=> K and mirror T, R, S (0, 1, 2, 3, 4) => (4, 1, 2, 3, 0)
+            F = np.transpose(F[:, ::-1, ::-1, ::-1, :], (4, 1, 2, 3, 0)).copy()
             mSlice, pSlice, qSlice = self.dSlice, self.hSlice, self.wSlice
         else:
             mSlice, pSlice, qSlice = self.mSlice, self.pSlice, self.qSlice
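Note: the mirrored-filter trick is what lets bprop reuse the fprop loop. The data gradient of a stride-1 cross-correlation is a full cross-correlation of the error with the filter flipped along its spatial axes (with the C and K axes swapped). A minimal 1-D NumPy sketch of that identity, assuming unit stride and no padding:

    import numpy as np

    W = np.array([1.0, 2.0, 3.0])  # filter, S = 3 taps
    E = np.ones(4)                 # upstream gradient for the Q = 4 outputs

    # bprop as fprop: full correlation of E with the mirrored filter W[::-1]
    padE = np.pad(E, 2)            # S - 1 zeros on each side
    dI = np.array([np.dot(padE[x:x + 3], W[::-1]) for x in range(6)])

    # brute-force check: scatter each E[q] * W back onto the input positions
    ref = np.zeros(6)
    for q in range(4):
        ref[q:q + 3] += E[q] * W
    assert np.allclose(dI, ref)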
@@ -204,13 +204,14 @@ def xprop_conv(self, I, F, O, X=None, bias=None, bsum=None, alpha=1.0, beta=0.0,
                 for q in range(Q):
                     sliceS, sliceW, _ = qSlice[q]
 
-                    slicedF = F[:,sliceT,sliceR,sliceS,:].reshape((-1, K))
-                    slicedI = I[:,sliceD,sliceH,sliceW,:].reshape((-1, N))
+                    slicedF = F[:, sliceT, sliceR, sliceS, :].reshape((-1, K))
+                    slicedI = I[:, sliceD, sliceH, sliceW, :].reshape((-1, N))
 
                     if beta:
-                        O[:,m,p,q,:] = alpha * np.dot( slicedF.T, slicedI ) + beta * X[:,m,p,q,:]
+                        O[:, m, p, q, :] = alpha * np.dot(slicedF.T, slicedI) + \
+                            beta * X[:, m, p, q, :]
                     else:
-                        O[:,m,p,q,:] = np.dot( slicedF.T, slicedI )
+                        O[:, m, p, q, :] = np.dot(slicedF.T, slicedI)
 
         if not beta:
             self.compound_ops(O, X, bias, bsum, relu, brelu, slope)
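Note: each output pixel is one small GEMM here; the sliced filter and receptive field are flattened to matrices and multiplied, producing all K feature maps for all N images at once. A shape sketch with hypothetical sizes:

    import numpy as np

    # C=16 channels, K=32 filters, N=64 images, a 3x3 tap window fully
    # inside the input (tlen = 1, rlen = slen = 3):
    slicedF = np.zeros((16 * 1 * 3 * 3, 32))  # (C*tlen*rlen*slen, K)
    slicedI = np.zeros((16 * 1 * 3 * 3, 64))  # (C*tlen*rlen*slen, N)
    out = np.dot(slicedF.T, slicedI)
    assert out.shape == (32, 64)              # (K, N): one O[:, m, p, q, :]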
@@ -249,7 +250,7 @@ def update_conv(self, I, E, U, alpha=1.0, beta=0.0):
 
                     slicedI = I[:, sliceD, sliceH, sliceW, :].reshape((-1, N))
                     slicedE = E[:, m, p, q, :]
-                    update = np.dot(slicedI, slicedE.T).reshape((C, tlen, rlen, slen, K))
+                    update = np.dot(slicedI, slicedE.T).reshape((C, tlen, rlen, slen, K))
                     if alpha == 1.0:
                         U[:, sliceT, sliceR, sliceS, :] += update
                     else:
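Note: the changed line above differs only in whitespace. The weight update is the companion GEMM: the same input slice multiplied against the error at one output pixel gives that pixel's contribution to the filter gradient. Sketch, same hypothetical sizes as above:

    import numpy as np

    slicedI = np.ones((16 * 1 * 3 * 3, 64))  # (C*tlen*rlen*slen, N)
    slicedE = np.ones((32, 64))              # (K, N) = E[:, m, p, q, :]
    update = np.dot(slicedI, slicedE.T)      # (C*tlen*rlen*slen, K), then
    assert update.shape == (16 * 1 * 3 * 3, 32)  # reshaped and added to U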
@@ -327,12 +328,20 @@ def __init__(self, lib, dtype,
         # nOut has to change because P and Q are now the inputs
         self.nOut = reduce(mul, self.DHW, 1) * C
 
-        self.dSlice = [self.bprop_slice(d, T, M, pad_d, str_d) for d in range(D)]
-        self.hSlice = [self.bprop_slice(h, R, P, pad_h, str_h) for h in range(H)]
-        self.wSlice = [self.bprop_slice(w, S, Q, pad_w, str_w) for w in range(W)]
-        self.mSlice = [self.fprop_slice(m, T, D, pad_d, str_d) for m in range(M)]
-        self.pSlice = [self.fprop_slice(p, R, H, pad_h, str_h) for p in range(P)]
-        self.qSlice = [self.fprop_slice(q, S, W, pad_w, str_w) for q in range(Q)]
+        if all(x == 1 for x in self.TRS) and \
+           all(p == 0 for p in self.padding) and \
+           all(s == 1 for s in self.strides):
+
+            self.dot = True
+        else:
+            self.dot = False
+
+        self.dSlice = [self.bprop_slice(d, T, M, pad_d, str_d) for d in range(D)]
+        self.hSlice = [self.bprop_slice(h, R, P, pad_h, str_h) for h in range(H)]
+        self.wSlice = [self.bprop_slice(w, S, Q, pad_w, str_w) for w in range(W)]
+        self.mSlice = [self.fprop_slice(m, T, D, pad_d, str_d) for m in range(M)]
+        self.pSlice = [self.fprop_slice(p, R, H, pad_h, str_h) for p in range(P)]
+        self.qSlice = [self.fprop_slice(q, S, W, pad_w, str_w) for q in range(Q)]
 
 
 class PoolLayer(object):
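Note: the new dot flag appears to mark the degenerate case where convolution is a plain matrix multiply; with 1x1x1 filters, zero padding and unit strides, every input pixel is touched exactly once. A NumPy sketch of that equivalence (hypothetical shapes, not the backend's actual fast path):

    import numpy as np

    C, K, N, P = 8, 4, 2, 25          # channels, filters, images, pixels
    I = np.random.rand(C, P, N)       # input with spatial dims flattened
    F = np.random.rand(C, K)          # 1x1 filter: just a C -> K map
    O = np.einsum('ck,cpn->kpn', F, I)
    ref = np.stack([np.dot(F.T, I[:, p, :]) for p in range(P)], axis=1)
    assert np.allclose(O, ref)        # per-pixel sliced dots == one GEMM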
3 changes: 2 additions & 1 deletion neon/backends/nervanacpu.py
@@ -1078,7 +1078,8 @@ def bprop_conv(self, layer, F, E, grad_I,
             grad_I *= (X > 0) + slope*(X < 0)
             can be combined with bsum tensor to output bprop_bias
         """
-        layer.xprop_conv(E, F, grad_I, X, bias, bsum, alpha, beta, relu, brelu, slope, backward=True)
+        layer.xprop_conv(E, F, grad_I, X, bias, bsum, alpha, beta, relu, brelu, slope,
+                         backward=True)
 
     def update_conv(self, layer, I, E, U, alpha=1.0, beta=0.0):
         """
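Note: the compound op quoted in the docstring is the leaky-ReLU derivative applied in place; gradients pass through where the pre-activation X was positive and are scaled by slope where it was negative. A standalone illustration (hypothetical values):

    import numpy as np

    X = np.array([-2.0, -0.5, 0.0, 1.5])   # saved pre-activations
    grad_I = np.ones_like(X)               # incoming gradient
    slope = 0.1
    grad_I *= (X > 0) + slope * (X < 0)    # 1 where X > 0, slope where X < 0
    assert np.allclose(grad_I, [0.1, 0.1, 0.0, 1.0])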
15 changes: 6 additions & 9 deletions neon/layers/layer.py
@@ -706,16 +706,15 @@ def __init__(self, fshape, strides={}, padding={}, init=None, bsum=False,
                  name=None, parallelism="Data"):
         super(Convolution, self).__init__(init, name, parallelism)
         self.nglayer = None
-        bsum = bsum and not self.be.deterministic
-        self.bsum = bsum
         self.convparams = {'str_h': 1, 'str_w': 1, 'str_d': 1,
                            'pad_h': 0, 'pad_w': 0, 'pad_d': 0,
-                           'T': 1, 'D': 1, 'bsum': bsum}  # 3D paramaters
+                           'T': 1, 'D': 1}  # 3D paramaters
 
         # keep around args in __dict__ for get_description.
         self.fshape = fshape
         self.strides = strides
         self.padding = padding
+        self.bsum = bsum
 
         if isinstance(fshape, tuple) or isinstance(fshape, list):
             fkeys = ('R', 'S', 'K') if len(fshape) == 3 else ('T', 'R', 'S', 'K')
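Note: the net effect of this hunk is that bsum is no longer forced off on deterministic backends, and the flag now lives only on the layer (self.bsum) instead of riding along in convparams. The batch sum itself is just a per-feature-map reduction over the batch; a sketch, assuming the (K, N) output layout used by the backend:

    import numpy as np

    O = np.random.rand(4, 8)               # K=4 feature maps, N=8 images
    bsum = O.sum(axis=1, keepdims=True)    # (K, 1), matching the
    assert bsum.shape == (4, 1)            # batch_sum_shape set below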
@@ -767,7 +766,7 @@ def configure(self, in_obj):
         self.out_shape = (K, P, Q) if M == 1 else (K, M, P, Q)
         if self.weight_shape is None:
             self.weight_shape = self.nglayer.dimF2  # (C * R * S, K)
-        if self.convparams['bsum']:
+        if self.bsum:
             self.batch_sum_shape = (self.nglayer.K, 1)
         return self
@@ -833,16 +832,14 @@ def __init__(self, fshape, strides={}, padding={}, init=None, bsum=False,
                  name=None):
         super(Deconvolution, self).__init__(init, name)
         self.nglayer = None
-        bsum = bsum and not self.be.deterministic
-        self.bsum = bsum
         self.deconvparams = {'str_h': 1, 'str_w': 1, 'str_d': 1,
-                             'pad_h': 0, 'pad_w': 0, 'pad_d': 0,
-                             'bsum': bsum}
+                             'pad_h': 0, 'pad_w': 0, 'pad_d': 0}
 
         # keep around args in __dict__ for get_description.
         self.fshape = fshape
         self.strides = strides
         self.padding = padding
+        self.bsum = bsum
 
         if isinstance(fshape, tuple):
             # fshape[2] should now map to C (nifm)
@@ -884,7 +881,7 @@ def configure(self, in_obj):
         self.out_shape = (self.nglayer.C, self.nglayer.H, self.nglayer.W)
         if self.weight_shape is None:
             self.weight_shape = self.nglayer.dimF2  # (C * R * S, K)
-        if self.deconvparams['bsum']:
+        if self.bsum:
             self.batch_sum_shape = (self.nglayer.C, 1)
         return self
2 changes: 1 addition & 1 deletion setup.cfg
@@ -15,5 +15,5 @@
 # ----------------------------------------------------------------------------
 
 [flake8]
-exclude = .git,__init__.py,neon/backends/test_pool.py,neon/backends/kernel_specs.py,neon/backends/convnet-benchmarks.py,neon/backends/layer_gpu.py,neon/backends/nervanagpu.py,neon/backends/float_ew.py,neon/backends/make_kernels.py,neon/backends/kernels/cuda/pooling.py,neon/backends/cuda_batchnorm.py,neon/backends/winograd4.py,neon/backends/winograd.py,neon/backends/winograd_conv.py,neon/backends/winograd4.py,neon/backends/winograd5.py,neon/backends/convolution.py
+exclude = .git,__init__.py,neon/backends/test_pool.py,neon/backends/kernel_specs.py,neon/backends/convnet-benchmarks.py,neon/backends/layer_gpu.py,neon/backends/nervanagpu.py,neon/backends/float_ew.py,neon/backends/make_kernels.py,neon/backends/kernels/cuda/pooling.py,neon/backends/cuda_batchnorm.py,neon/backends/winograd4.py,neon/backends/winograd.py,neon/backends/winograd_conv.py,neon/backends/winograd4.py,neon/backends/winograd5.py,neon/backends/convolution.py,neon/backends/conv_kernel_test.py
 max-line-length = 99
2 changes: 1 addition & 1 deletion setup.py
@@ -76,7 +76,7 @@
                             'backends/kernels/cubin/*.cubin',
                             'backends/kernels/maxas/*.pl',
                             'backends/kernels/maxas/MaxAs/*.pm',
-                            'data/loader/*.so']},
+                            '../loader/bin/*.so']},
      classifiers=['Development Status :: 3 - Alpha',
                   'Environment :: Console',
                   'Environment :: Console :: Curses',
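Note: per the commit message, this fixes the loader install location for sysinstall. setuptools resolves package_data globs relative to the owning package directory, so the new pattern reaches up out of the neon package to collect the loader's compiled shared objects from loader/bin. A minimal sketch of the stanza (paths inferred from this diff only, not the full file):

    from setuptools import setup

    setup(name='neon',
          packages=['neon'],
          package_data={'neon': ['backends/kernels/maxas/MaxAs/*.pm',
                                 '../loader/bin/*.so']})  # loader .so files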
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -173,12 +173,14 @@ def cleanup():
     # backend or use the NervanaObject.be global
     return be
 
+
 def get_backend_pair(device_id, dtype=np.float32, bench=False):
     from neon.backends.nervanagpu import NervanaGPU
     ng = NervanaGPU(default_dtype=dtype, bench=bench, device_id=device_id)
     nc = NervanaCPU(default_dtype=dtype)
     return (ng, nc)
 
+
 @pytest.fixture(scope='module')
 def backend_pair(request):
     ng, nc = get_backend_pair(device_id=request.config.getoption("--device_id"))
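Note: the two added blank lines look like pep8 fixes (pycodestyle E302, "expected 2 blank lines": top-level definitions need two blank lines before them). With the [flake8] settings above, the check can be run locally:

    flake8 tests/conftest.py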
