Skip to content

Commit

Permalink
* added sets to replace Just, fixed sorting order
Browse files Browse the repository at this point in the history
  • Loading branch information
jjtolton committed Apr 16, 2018
1 parent 1135590 commit e90c0a0
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 20 deletions.
49 changes: 32 additions & 17 deletions naga/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
def identity(x):
return x


@decorator
def message(f):
def _(x, *args, **kwargs):
Expand Down Expand Up @@ -94,10 +95,17 @@ def __init__(self, f=identity):
self.default = f
self.arrities = {}

def __mul__(self, other):
return partial(apply, other)

def pattern(self, *argtypes):
@decorator
def _dispatch(f):
self.pattern_map = [(argtypes, f), *self.pattern_map]
sortkey = lambda x: (1 if len(x[0]) < 1 else
0 if not isinstance(x[0][0], set)
else -1)

self.pattern_map = sorted([(argtypes, f), *self.pattern_map], key=sortkey)
self.maxlen = argmax(self.pattern_map, key=lambda x: len(x[0]))
return self

Expand All @@ -116,7 +124,7 @@ def declare(self, f):
fn = f

fout = self.pattern(*[*[anns.get(arg, Dispatch._) for arg in args],
*varargs])(fn)
*varargs])(fn)
return fout

def __call__(self, *args, **kwargs):
Expand All @@ -126,32 +134,40 @@ def find(args, n=self.maxlen):
return self.default

for argtypes, fn in self.pattern_map:

for a, b in itertools.zip_longest(argtypes, args[:n]):
if a is None or b is None:
for argtype, arg in itertools.zip_longest(argtypes, args[:n]):
if argtype is None or arg is None:
break
if a is Dispatch._:
if argtype is Dispatch._:
continue
if a is Dispatch.star:
if argtype is Dispatch.star:
return fn
if isinstance(a, Dispatch.Just):
if a == b:
# if isinstance(argtype, Dispatch.Just):
# if argtype == arg:
# continue
# else:
# break
if isinstance(argtype, Dispatch.regex):
if argtype(arg):
continue
else:
break
if isinstance(a, Dispatch.regex):
if a(b):

if isinstance(argtype, Dispatch.pred):
if argtype(arg):
continue
else:
break

if isinstance(a, Dispatch.pred):
if a(b):
continue
else:
if isinstance(argtype, set):
try:
if arg not in argtype:
break
else:
continue
except TypeError:
break

if not isinstance(b, a):
if not isinstance(arg, argtype):
break
else:
return fn
Expand All @@ -161,7 +177,6 @@ def find(args, n=self.maxlen):
return find(args)(*args, **kwargs)



def reductions(fn, seq, default=nil):
"""generator version of reduce that returns 1 item at a time"""

Expand Down
63 changes: 63 additions & 0 deletions tests/performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from _operator import mul
from random import choice

from naga import Dispatch, reduce


@Dispatch
def fib(*args):
raise TypeError("Unsupported type(s)")


@fib.declare
def fib_base_case(n: {0, 1}):
return n


@fib.declare
def fib_recursive_case(n: int):
return fib(n - 2) + fib(n - 1)


@fib.declare
def make_fib_list(a: {list}, b: int):
return [fib(n) for n in range(b)]


@fib.declare
def fibsumlist(a: {sum}, b: int) -> sum:
return fib(list, b)


@fib.declare
def fibmul(a: {mul}, b: int):
return reduce(a, map(fib, range(1, b)))

def empty_if_none(x):
if x is None:
return []
return x


@fib.declare
def fib(choice: {choice}) -> empty_if_none:
return choice([None, 100])

@fib.declare
def fib(choice: {choice}, n: int) -> list:
for i in range(n):
yield fib(choice)

if __name__ == '__main__':
print(fib_base_case(0))
print(fib(0))
print(fib(1))
print(fib(10))
print(fib_recursive_case(10))
print(fib(list, 10))
print(make_fib_list(list, 10))
print(fib(sum, 10))
print(fibsumlist(sum, 10))
print(fib(mul, 10))
print(fibmul(mul, 10))
print(fib(choice, 10))
22 changes: 19 additions & 3 deletions tests/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import operator
import unittest
from functools import reduce, lru_cache, partial
from functools import reduce, lru_cache

from naga.tools import apply, merge, assoc, dissoc, merge_with, \
merge_with_default, assoc_in, update_in, terminal_dict, \
Expand Down Expand Up @@ -559,12 +559,28 @@ def str_to_int(
self.assertEqual(foo('hey'), 326)

@foo.declare
def split_dstring(d: dict, k: foo.Just('key')) -> (
lambda x: x.split('_')):
def split_dstring(d: dict, k: {'key'}) -> (
lambda x: x.split('_')):
return d[k]

self.assertEqual(foo({'key': 'hey_there'}, 'key'), ['hey', 'there'])

def test_set_notation_for_dispatch(self):

@Dispatch
def fib(): 'fib'

@fib.declare
def fib(n: {0, 1}):
return n

@fib.declare
def fib(n: int):
return fib(n - 2) + fib(n - 1)


self.assertEqual(fib(10), 55)


if __name__ == '__main__':
unittest.main()

0 comments on commit e90c0a0

Please sign in to comment.