Skip to content

Commit

Permalink
Crop assign (#3547)
Browse files Browse the repository at this point in the history
* crop_assign operator

* use crop_assign in data par

* remove legacy copy_slice_to and assign_slice_from

* update mshadow

* _crop_assign_scalar

* __setitem__ for NDArray

* fix lint errors
  • Loading branch information
pluskid authored Oct 18, 2016
1 parent 1c49718 commit cf00ca0
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 281 deletions.
1 change: 1 addition & 0 deletions example/module/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def get_iterator(args, kv):

devs = mx.cpu() if (args.gpus is None or args.gpus == '') else [
mx.gpu(int(i)) for i in args.gpus.split(',')]
logging.info('Starting with devices %s', devs)

mod = mx.mod.Module(net, context=devs)

Expand Down
25 changes: 0 additions & 25 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,31 +380,6 @@ class NDArray {
* due to different possible convention carried by copy function.
*/
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);

/*!
* \brief copy a slice along any axis.
* \param from the NDArray we want to slice from
* \param slice_dim the axis we want to perform slice in
* \param start the beginning of the slice
* \param end the ending of the slice
* \param to the pre-allocated NDArray to copy the slice to
* \param priority the priority of the task
*/
void CopySliceTo(const NDArray &from, int slice_dim, index_t start, index_t end,
NDArray *to, int priority = 0);

/*!
* \brief assign a slice along any axis.
* \param from the NDArray whose value is used for assigning
* \param slice_dim the axis we want to perform slice in
* \param start the beginning of the slice
* \param end the ending of the slice
* \param to the bigger NDArray whose slice we want to assign to
* \param priority the priority of the task
*/
void AssignSliceFrom(const NDArray &from, int slice_dim, index_t start, index_t end,
NDArray *to, int priority = 0);

/*!
* \brief Perform elementwise sum over each data from source, store result into out.
* \param source the ndarray we want to sum
Expand Down
2 changes: 1 addition & 1 deletion mshadow
15 changes: 14 additions & 1 deletion python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,20 @@ def _load_general(data, targets, major_axis):
else:
for slice_idx, d_dst in d_targets:
if axis >= 0:
d_src.copy_slice_to(axis, slice_idx.start, slice_idx.stop, d_dst)
# copy slice
shape = d_src.shape
begin = np.zeros(len(shape), dtype=int)
end = np.array(shape)
begin[axis] = slice_idx.start
end[axis] = slice_idx.stop
# pylint: disable=no-member,protected-access
if d_src.context == d_dst.context:
nd.crop(d_src, begin=tuple(begin), end=tuple(end), out=d_dst)
else:
# on different device, crop and then do cross device copy
d_dst_copy = nd.crop(d_src, begin=tuple(begin), end=tuple(end))
d_dst_copy.copyto(d_dst)
# pylint: enable=no-member,protected-access
else:
d_src.copyto(d_dst)

Expand Down
153 changes: 86 additions & 67 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,28 +208,83 @@ def __setstate__(self, state):
self.__dict__.update(state)

def __setitem__(self, in_slice, value):
"""Set ndarray value"""
"""Set ndarray value.
`value` can be a scalar, an `NDArray` or numpy array of compatible shape.
The following modes are supported:
- `array[:] = value`: set all the contents
- `array[i] = value`: set the i-th slice. If the array is of dimension
`(d1, d2, d3)`, it sets value of a slice of shape `(1, d2, d3)`.
- `array[i:j] = value`: similarly, if the array is of dimension
`(d1, d2, d3)`, it sets value of a slice of shape `(j-i, d2, d3)`.
Fully-dimensional indexing is also supported. For example, if array is
of shape `(d1, d2, d3)`, one can do
- `array[:, :, :] = value`: achieving the same effect of `array[:] = value`
- `array[:, i, j:k] = value`: each index could be a python slice or an int.
"""
# pylint: disable=too-many-branches
if not self.writable:
raise ValueError('trying to assign to a readonly NDArray')
if isinstance(in_slice, int):
sliced_arr = self._at(in_slice)
sliced_arr[:] = value
return
if not isinstance(in_slice, slice) or in_slice.step is not None:
raise ValueError('NDArray only support continuous slicing on axis 0')
if in_slice.start is not None or in_slice.stop is not None:
sliced_arr = self._slice(in_slice.start, in_slice.stop)
sliced_arr[:] = value
return
if isinstance(value, NDArray):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, numeric_types):
_internal._set_value(float(value), out=self)
elif isinstance(value, (np.ndarray, np.generic)):
self._sync_copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
if isinstance(in_slice, slice):
if in_slice.step is not None:
raise ValueError('NDArray only support continuous slicing on axis 0')
if in_slice.start is not None or in_slice.stop is not None:
sliced_arr = self._slice(in_slice.start, in_slice.stop)
sliced_arr[:] = value
return
if isinstance(value, NDArray):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, numeric_types):
_internal._set_value(float(value), out=self)
elif isinstance(value, (np.ndarray, np.generic)):
self._sync_copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
if isinstance(in_slice, tuple):
# multi-dimension indexing
my_shape = self.shape
assert len(in_slice) == len(my_shape)
for slice_i in in_slice:
assert isinstance(slice_i, (slice, int))
begin = [0 for _ in my_shape]
end = [x for x in my_shape]
for i, slice_i in enumerate(in_slice):
if isinstance(slice_i, int):
assert slice_i < my_shape[i]
begin[i] = slice_i
end[i] = slice_i + 1
if isinstance(slice_i, slice):
# only support continuous slicing
assert slice_i.step is None
begin[i] = slice_i.start or 0
end[i] = slice_i.stop or my_shape[i]
assert begin[i] < end[i]
assert end[i] <= my_shape[i]
begin = tuple(begin)
end = tuple(end)
if isinstance(value, NDArray):
value = value.as_in_context(self.context)
_internal._crop_assign(self, value, out=self,
begin=begin, end=end)
elif isinstance(value, numeric_types):
_internal._crop_assign_scalar(self, out=self,
begin=begin, end=end,
scalar=value)
elif isinstance(value, (np.ndarray, np.generic)):
value = array(value, ctx=self.context)
_internal._crop_assign(self, value, out=self,
begin=begin, end=end)
else:
raise TypeError('type %s not supported' % str(type(value)))
# pylint: enable=too-many-branches

def __getitem__(self, in_slice):
"""Get ndarray"""
Expand Down Expand Up @@ -281,54 +336,6 @@ def _slice(self, start, stop):
self.handle, start, stop, ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)

def copy_slice_to(self, axis, start, stop, target):
"""Copy a slice along an axis.
Unlike `slice_axis`, the source and target can live on different contexts.
Parameters
----------
axis : int
The axis along which to do slicing.
start : int
The starting index of the slice.
stop : int
The finishing index of the slice.
target : NDArray or Context
If an NDArray, must be pre-allocated with compatible shape.
If a Context, a new NDArray will be created.
Returns
-------
The sliced copy of the NDArray.
"""
if isinstance(target, Context):
shape = list(self.shape)
shape[axis] = stop - start
target = NDArray(_new_alloc_handle(shape, target, True, self.dtype))

assert isinstance(target, NDArray)
return _internal._copy_slice_to(self, axis, start, stop, out=target)

def assign_slice_from(self, axis, start, stop, source):
"""Assign a slice from an NDArray.
The source and target can live on different contexts.
Parameters
----------
axis : int
The axis along which to do slicing.
start : int
The starting index of the slice.
stop : int
The finishing index of the slice.
source : NDArray
The array whose content is used for the assignment.
"""
assert isinstance(source, NDArray)
return _internal._assign_slice_from(source, axis, start, stop, out=self)

def _at(self, idx):
"""Return a sub NDArray that shares memory with current one.
Expand Down Expand Up @@ -969,11 +976,23 @@ def concatenate(arrays, axis=0, always_copy=True):
assert shape_rest1 == arr.shape[0:axis]
assert shape_rest2 == arr.shape[axis+1:]
assert dtype == arr.dtype
ret = empty(shape_rest1 + (shape_axis,) + shape_rest2,
ctx=arrays[0].context, dtype=dtype)
ret_shape = shape_rest1 + (shape_axis,) + shape_rest2
ret = empty(ret_shape, ctx=arrays[0].context, dtype=dtype)

idx = 0
begin = [0 for _ in ret_shape]
end = list(ret_shape)
for arr in arrays:
ret.assign_slice_from(axis, idx, idx+arr.shape[axis], arr)
if axis == 0:
ret[idx:idx+arr.shape[0]] = arr
else:
begin[axis] = idx
end[axis] = idx+arr.shape[axis]
# pylint: disable=no-member,protected-access
_internal._crop_assign(ret, arr, out=ret,
begin=tuple(begin),
end=tuple(end))
# pylint: enable=no-member,protected-access
idx += arr.shape[axis]

return ret
Expand Down
Loading

0 comments on commit cf00ca0

Please sign in to comment.