Skip to content

Commit

Permalink
Stash dispatcher, working with types in macros is a minefield nim-lan…
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Apr 19, 2020
1 parent bde1880 commit 927f1de
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/private/ast_utils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import macros


proc hasType*(x: NimNode, t: static[string]): bool {. compileTime .} =
## Compile-time type checking
sameType(x, bindSym(t))
Expand All @@ -29,6 +28,10 @@ proc isBool*(x: NimNode): bool {. compileTime .} =
## Compile-time type checking
hasType(x, "bool")

proc isOpenarray*(x: NimNode): bool {. compileTime .} =
## Compile-time type checking
hasType(x, "array") or hasType(x, "seq") or hasType(x, "openArray")

proc isAllInt*(slice_args: NimNode): bool {. compileTime .} =
## Compile-time type checking
result = true
Expand Down
40 changes: 37 additions & 3 deletions src/tensor/private/p_accessors_macros_read.nim
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import ../../private/ast_utils,
./p_checks, ./p_accessors,
sequtils, macros

from ../init_cpu import toTensor

template slicerImpl*[T](result: AnyTensor[T]|var AnyTensor[T], slices: ArrayOfSlices): untyped =
## Slicing routine

Expand Down Expand Up @@ -131,15 +133,24 @@ proc getFancySelector*(ast: NimNode, axis: var int, selector: var NimNode): Fanc
result = None
var foundNonSpanOrEllipsis = false

template checkNonSpan(): untyped {.dirty.} =
doAssert not foundNonSpanOrEllipsis,
"Fancy indexing is only compatible with full spans `_` on non-indexed dimensions" &
" and/or ellipsis `...`"

let tensorBoolType = nnkBracketExpr.newTree(bindSym"Tensor", bindSym"bool")

var i = 0
while i < ast.len:
let cur = ast[i]

echo cur.treerepr
echo cur.getType().treerepr

if cur.eqIdent"Span":
discard
elif cur.kind == nnkBracket:
doAssert not foundNonSpanOrEllipsis,
"Fancy indexing is only compatible with full spans `_` on non-indexed dimensions" &
" and/or ellipsis `...`"
checkNonSpan()
axis = i
if cur[0].kind == nnkIntLit:
result = FancyIndex
Expand All @@ -151,6 +162,29 @@ proc getFancySelector*(ast: NimNode, axis: var int, selector: var NimNode): Fanc
else:
# byte, char, enums are all represented by integers in the VM
error "Fancy indexing is only possible with integers or booleans"
elif cur.isOpenarray:
# Only check the instantiation type, the overload will ake care of conversion
checkNonSpan()
axis = i
let curAsTensor = newCall(bindSym"toTensor", cur)
if sameType(curAsTensor, tensorBoolType):
let full = i == 0 and ast.len == 1
result = if full: FancyMaskFull else: FancyMaskAxis
selector = cur
else:
result = FancyIndex
selector = cur
elif sameType(cur, tensorBoolType):
checkNonSpan()
axis = i
let full = i == 0 and ast.len == 1
result = if full: FancyMaskFull else: FancyMaskAxis
selector = cur
elif sameType(cur, bindSym"Tensor"):
checkNonSpan()
axis = i
result = FancyIndex
selector = cur
else:
if result != None:
doAssert cur.eqIdent"..." and i == ast.len - 1
Expand Down
19 changes: 19 additions & 0 deletions tests/manual_checks/fancy_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,22 @@ def index_select():
print('--------------------------')
print('x[[1, 3], :]')
print(x[[1, 3], :])

def masked_select():
print('Masked select')
print('--------------------------')
x = np.array([[ 4, 99, 2],
[ 3, 4, 99],
[ 1, 8, 7],
[ 8, 6, 8]])

print(x)
print('--------------------------')
print('x[:, np.sum(x, axis = 0) > 50]')
print(x[:, np.sum(x, axis = 0) > 50])
print('--------------------------')
print('x[np.sum(x, axis = 1) > 50, :]')
print(x[np.sum(x, axis = 1) > 50, :])

index_select()
masked_select()
19 changes: 19 additions & 0 deletions tests/tensor/test_fancy_indexing.nim
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,22 @@ suite "Fancy indexing":
[8, 6, 8]].toTensor()

check: r == exp

test "Masked selection via fancy indexing":
block: # print('x[:, np.sum(x, axis = 0) > 50]')
let r = x[_, x.sum(axis = 0) >. 50]

let exp = [[99, 2],
[ 4, 99],
[ 8, 7],
[ 6, 8]].toTensor()

check: r == exp

block: # print('x[np.sum(x, axis = 1) > 50, :]')
let r = x[x.sum(axis = 1) >. 50, _]

let exp = [[4, 99, 2],
[3, 4, 99]].toTensor()

check: r == exp

0 comments on commit 927f1de

Please sign in to comment.