Skip to content

Commit

Permalink
Add masked assign of broadcastable tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Apr 19, 2020
1 parent 38000f5 commit f48b409
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 2 deletions.
45 changes: 45 additions & 0 deletions src/tensor/selectors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,57 @@ template masked_axis_fill_impl[T](t: var Tensor[T], mask: Tensor[bool] or openAr
func masked_axis_fill*[T](t: var Tensor[T], mask: Tensor[bool], axis: int, value: T or Tensor[T]) =
## Take a 1D boolean mask tensor with size equal to the `t.shape[axis]`
## The axis index that are set to true in the mask will be filled with `value`
##
## Limitation:
## If value is a Tensor, only filling via broadcastable tensors is supported at the moment
## for example if filling axis of a tensor `t` of shape [4, 3] the corresponding shapes are valid
## [4, 3].masked_axis_fill(mask = [1, 3], axis = 1, value = [4, 1])
##
## with values
## t = [[ 4, 99, 2],
## [ 3, 4, 99],
## [ 1, 8, 7],
## [ 8, 6, 8]].toTensor()
## mask = [false, true, true]
## value = [[10],
## [20],
## [30],
## [40]].toTensor()
##
## result = [[ 4, 10, 10],
## [ 3, 20, 20],
## [ 1, 30, 30],
## [ 8, 40, 40]].toTensor()
# TODO: support filling with a multidimensional tensor
let mask = mask.squeeze() # make 1D if coming from unreduced axis aggregation like sum
# TODO: squeeze exactly depending on axis to prevent accepting invalid values
masked_axis_fill_impl(t, mask, axis, value)

func masked_axis_fill*[T](t: var Tensor[T], mask: openArray[bool], axis: int, value: T or Tensor[T]) =
## Take a 1D boolean mask tensor with size equal to the `t.shape[axis]`
## The axis index that are set to true in the mask will be filled with `value`
##
## Limitation:
## If value is a Tensor, only filling via broadcastable tensors is supported at the moment
## for example if filling axis of a tensor `t` of shape [4, 3] the corresponding shapes are valid
## [4, 3].masked_axis_fill(mask = [1, 3], axis = 1, value = [4, 1])
##
## with values
## t = [[ 4, 99, 2],
## [ 3, 4, 99],
## [ 1, 8, 7],
## [ 8, 6, 8]].toTensor()
## mask = [false, true, true]
## value = [[10],
## [20],
## [30],
## [40]].toTensor()
##
## result = [[ 4, 10, 10],
## [ 3, 20, 20],
## [ 1, 30, 30],
## [ 8, 40, 40]].toTensor()
# TODO: support filling with a multidimensional tensor
masked_axis_fill_impl(t, mask, axis, value)

# Apply N-D mask along an axis
Expand Down
71 changes: 69 additions & 2 deletions tests/manual_checks/fancy_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,77 @@ def masked_axis_fill_value():
print('--------------------------')
y = x.copy()
print('y[y.sum(axis = 1) > 50, :] = -100')
y[y.sum(axis = 0) > 50, :] = -100
y[y.sum(axis = 1) > 50, :] = -100
print(y)
print('--------------------------')

def masked_axis_fill_tensor_invalid_1():
# ValueError: shape mismatch:
# value array of shape (4,) could not be broadcast
# to indexing result of shape (2,4)
print('Masked axis fill with tensor - invalid numpy syntax')
print('--------------------------')
x = np.array([[ 4, 99, 2],
[ 3, 4, 99],
[ 1, 8, 7],
[ 8, 6, 8]])

print(x)
print('--------------------------')
y = x.copy()
print('y[:, y.sum(axis = 0) > 50] = np.array([10, 20, 30, 40])')
y[:, y.sum(axis = 0) > 50] = np.array([10, 20, 30, 40])
print(y)

def masked_axis_fill_tensor_valid_1():
print('Masked axis fill with tensor - 1d tensor broadcasting')
print('--------------------------')
x = np.array([[ 4, 99, 2],
[ 3, 4, 99],
[ 1, 8, 7],
[ 8, 6, 8]])

print(x)
print('--------------------------')
y = x.copy()
print('y[:, y.sum(axis = 0) > 50] = np.array([[10], [20], [30], [40]])')
y[:, y.sum(axis = 0) > 50] = np.array([[10], [20], [30], [40]])
print(y)
print('--------------------------')
y = x.copy()
print('y[y.sum(axis = 1) > 50, :] = np.array([-10, -20, -30])')
y[y.sum(axis = 1) > 50, :] = np.array([-10, -20, -30])
print(y)
print('--------------------------')

def masked_axis_fill_tensor_valid_2():
print('Masked axis fill with tensor - multidimensional tensor')
print('--------------------------')
x = np.array([[ 4, 99, 2],
[ 3, 4, 99],
[ 1, 8, 7],
[ 8, 6, 8]])

print(x)
print('--------------------------')
y = x.copy()
print('y[:, y.sum(axis = 0) > 50] = np.array([[10, 50], [20, 60], [30, 70], [40, 80]])')
y[:, y.sum(axis = 0) > 50] = np.array([[10, 50],
[20, 60],
[30, 70],
[40, 80]])
print(y)
print('--------------------------')
y = x.copy()
print('y[y.sum(axis = 1) > 50, :] = np.array([-10, -20, -30], [-40, -50, -60])')
y[y.sum(axis = 1) > 50, :] = np.array([[-10, -20, -30],
[-40, -50, -60]])
print(y)
print('--------------------------')

# index_fill()
# masked_fill()
masked_axis_fill_value()
# masked_axis_fill_value()
masked_axis_fill_tensor_invalid_1()
# masked_axis_fill_tensor_valid_1()
# masked_axis_fill_tensor_valid_2()
58 changes: 58 additions & 0 deletions tests/tensor/test_fancy_indexing.nim
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,61 @@ suite "Fancy indexing":
[ 8, 6, 8]].toTensor()

check: y == exp

test "Masked axis assign tensor via fancy indexing - invalid Numpy syntaxes":
block: # y[:, y.sum(axis = 0) > 50] = np.array([10, 20, 30, 40])
var y = x.clone()

expect(IndexError):
y[_, y.sum(axis = 0) >. 50] = [10, 20, 30, 40].toTensor()

test "Masked axis assign broadcastable 1d tensor via fancy indexing":
block: # y[:, y.sum(axis = 0) > 50] = np.array([[10], [20], [30], [40]])
var y = x.clone()
y[_, y.sum(axis = 0) >. 50] = [[10], [20], [30], [40]].toTensor()

let exp = [[ 4, 10, 10],
[ 3, 20, 20],
[ 1, 30, 30],
[ 8, 40, 40]].toTensor()

check: y == exp

block: # y[y.sum(axis = 1) > 50, :] = np.array([-10, -20, -30])
var y = x.clone()
y[y.sum(axis = 1) >. 50, _] = [[-10, -20, -30]].toTensor()

let exp = [[-10, -20, -30],
[-10, -20, -30],
[ 1, 8, 7],
[ 8, 6, 8]].toTensor()

check: y == exp

# TODO - only broadcastable tensor assign are supported at the moment
# test "Masked axis assign multidimensional tensor via fancy indexing":
# block: # y[:, y.sum(axis = 0) > 50] = np.array([[10, 50], [20, 60], [30, 70], [40, 80]])
# var y = x.clone()
# y[_, y.sum(axis = 0) >. 50] = [[10, 50],
# [20, 60],
# [30, 70],
# [40, 80]].toTensor()
#
# let exp = [[ 4, 10, 50],
# [ 3, 20, 60],
# [ 1, 30, 70],
# [ 8, 40, 80]].toTensor()
#
# check: y == exp
#
# block: # y[y.sum(axis = 1) > 50, :] = np.array([-10, -20, -30], [-40, -50, -60])
# var y = x.clone()
# y[y.sum(axis = 1) >. 50, _] = [[-10, -20, -30],
# [-40, -50, -60]].toTensor()
#
# let exp = [[-10, -20, -30],
# [-40, -50, -60],
# [ 1, 8, 7],
# [ 8, 6, 8]].toTensor()
#
# check: y == exp

0 comments on commit f48b409

Please sign in to comment.