Skip to content

Made the elvis operator a macro that checks each step of an expression #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 117 additions & 9 deletions elvis.nim
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import options
import std/[options, macros, genasts]

#true if float not 0 or NaN
template truthy*(val: float): bool = (val < 0 or val > 0)
Expand Down Expand Up @@ -26,17 +26,125 @@ template truthy*[T](val: seq[T]): bool = (val != @[])
template truthy*[T](val: Option[T]): bool = isSome(val)

# true if truthy and no exception.
template `?`*[T](val: T): bool = (try: truthy(val) except: false)
template `?`*[T](val: T): bool = (try: truthy(val) except CatchableError: false)

template truthy*[T](val: T): bool = not compiles(val.isNil())

# return left if truthy otherwise right
template `?:`*[T](l: T, r: T): T = (if ?l: l else: r)
proc flattenExpression(n: NimNode, result: var seq[NimNode]) =
## Navigates the tree, extracting each step into an expression, adding to `result`
case n.kind
of nnkCallKinds:
let cleanCall = n.copyNimTree()
case cleanCall[0].kind:
of nnkDotExpr:
cleanCall[0][0] = newEmptyNode()
result.add cleanCall
flattenExpression(n[0][0], result)
else:
cleanCall[1] = newEmptyNode()
result.add cleanCall
flattenExpression(n[1], result)

of nnkBracketExpr, nnkDotExpr:
let cleanCall = n.copyNimTree()
cleanCall[0] = newEmptyNode()
result.add cleanCall
flattenExpression(n[0], result)

else:
result.add n

proc flattenExpression(n: NimNode): seq[NimNode] =
## Navigates the tree, extracting each step into an expression, returning them
case n.kind
of nnkCallKinds:
if n.len > 1:
let cleanCall = n.copyNimTree()
case cleanCall[0].kind:
of nnkDotExpr:
cleanCall[0][0] = newEmptyNode()
result.add cleanCall
flattenExpression(n[0][0], result)
else:
cleanCall[1] = newEmptyNode()
result.add cleanCall
flattenExpression(n[1], result)
else:
result.add n

of nnkBracketExpr, nnkDotExpr:
let cleanCall = n.copyNimTree()
cleanCall[0] = newEmptyNode()
result.add cleanCall
flattenExpression(n[0], result)
else:
result.add n

proc replaceCheckedVal(expr, cached: NimNode) =
## Navigates the tree replacing the first Node of concern with the `cached` symbol
if cached != nil:
case expr.kind
of nnkBracketExpr:
expr[0] = cached
of nnkCallKinds:
if expr.len > 1:
case expr[0].kind
of nnkDotExpr:
expr[0][0] = cached
else:
expr[1] = cached
else:
discard

proc generateIfExpr(s: seq[NimNode], l, r: NimNode): NimNode =
var
lastArg: NimNode
lastExpr: NimNode

for i in countDown(s.high, 0): # iterate the flattened expression backwards as each step is one further left
let
expr = s[i]
argName = gensym(nskLet, "TruthyVar")
expr.replaceCheckedVal(lastArg)

let thisExpr =
genast(expr, argName, r):
let argName = expr
if truthy(argName):
discard # placerholder rewritten either by next expression or return expression
else:
r
if lastExpr.kind == nnkNilLit:
result = thisExpr
else:
lastExpr[1][0][1] = thisExpr

lastExpr = thisExpr
lastArg = argName

lastExpr[1][0][1] = genAst(l, r, lastArg): # Mimics the logic used prior
when l isnot Option and r is Option:
some(lastArg)
elif l is Option and r isnot Option:
lastArg.get()
else:
lastArg

result = genast(result, r):
try:
# We need to wrap the expression inside a `try` incase an exception is raised.
# Since we cache the value expressions that raise exceptions do not go right into `?`.
result
except CatchableError:
r

# return some left if truthy otherwise right
template `?:`*[T](l: T, r: Option[T]): Option[T] = (if ?l: some(l) else: r)

template `?:`*[T](l: Option[T], r: T): T = (if ?l.get(): l.get() else: r)
# return left if truthy otherwise right
macro `?:`*(l, r: untyped): untyped =
if l.kind == nnkInfix and l[0].eqIdent"?:": # We want the left hand evaluated first so make it a stmt list if it's an elvis operator
result = newStmtList(newCall("?:", newStmtList(l), r))
else:
var expr = flattenExpression(l)
result = expr.generateIfExpr(l, r)

# Assign only when left is not truthy
template `?=`*[T](l: T, r: T) = (if not(?l): l = r)
Expand All @@ -54,7 +162,7 @@ template `.?`*(left, right: untyped): untyped =
var tmp = left
if truthy(tmp): tmp.right
else: default(typeof(left.right))
except: default(typeof(left.right))
except CatchableError: default(typeof(left.right))

type Branch[T] = object
then, other: T
Expand Down
24 changes: 19 additions & 5 deletions tests.nim
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import elvis
import unittest
import tables
import options
import std/[unittest, tables, options, sequtils]

template `==`[T](left: Option[T], right: T): bool =
if isSome(left): left.get() == right else: false
Expand Down Expand Up @@ -78,7 +76,7 @@ suite "conditional access":
var s1 = @["one"]
var s2 = @["one"]
test "truthy getter": check(seq1[0].?len == 3)
test "falsey getter": check(seq1[1].?len == 0)
#test "falsey getter": check(seq1[1].?len == 0) # Make a custom `{}` or other operator that raises a `CatchableError` instead, defects are not to be caught.
test "truthy precedence": check(seq1[0].?len == 3)
test "nil check": check(nilObj.?data == nil)
test "falsy on ref": check(nilObj.?data.?val == 0)
Expand All @@ -103,6 +101,7 @@ suite "elvis number":
test "good left": check((1 ?: 2) == 1)
test "expr left": check(((1 - 1) ?: 1) == 1)


suite "elvis sequence":
test "empty left": check((seq0 ?: @[1]) == @[1])
test "good left": check((@[0] ?: @[1]) == @[0])
Expand All @@ -119,7 +118,7 @@ suite "elvis string":
suite "elvis except":
test "none left": check((tab1["two"] ?: 0) == 0)
test "good left": check((tab1["one"] ?: 0) == 1)

suite "coalesce option and option":
test "left some":
let a: Option[string] = some("a")
Expand Down Expand Up @@ -243,3 +242,18 @@ suite "short circuit raws":
proc getB(): string = raise newException(ValueError, "expensive operation")
expect ValueError:
discard getA() ?: getB()

suite "Check truthy on chaining":
test "filterit":
var a: seq[int] = @[]
var b = a.filterIt(it > 3)

check (a.filterIt(it>3)[0] ?: 3) == 3

type Story = object
name: string
body: string

var stories: seq[Story] = @[Story(name:"eh", body: "no")]

check (stories.filterIt(it.name == "asdf")[0] ?: stories[0]) == stories[0]