Skip to content

Commit b9c246e

Browse files
committed
branching strategies are no longer variadic
1 parent 167e4d0 commit b9c246e

File tree

2 files changed

+35
-39
lines changed

2 files changed

+35
-39
lines changed

strategies/branch/core.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,15 @@ def write(brl, x, result):
3535

3636
return onaction(fn, write)
3737

38-
def multiplex(*fns):
38+
@curry
39+
def multiplex(fns, x):
3940
""" Multiplex many branching rules into one """
40-
def multiplex_brl(x):
41-
seen = set([])
42-
for brl in fns:
43-
for nx in brl(x):
44-
if nx not in seen:
45-
seen.add(nx)
46-
yield nx
47-
return multiplex_brl
41+
seen = set([])
42+
for brl in fns:
43+
for nx in brl(x):
44+
if nx not in seen:
45+
seen.add(nx)
46+
yield nx
4847

4948
@curry
5049
def condition(cond, fn, x):
@@ -70,33 +69,30 @@ def notempty(fn, x):
7069
if not yielded:
7170
yield x
7271

73-
def do_one(*fns):
72+
@curry
73+
def do_one(fns, x):
7474
""" Execute one of the branching rules """
75-
def do_one_brl(x):
76-
yielded = False
77-
for brl in fns:
78-
for nx in brl(x):
79-
yielded = True
80-
yield nx
81-
if yielded:
82-
raise StopIteration()
83-
return do_one_brl
75+
yielded = False
76+
for brl in fns:
77+
for nx in brl(x):
78+
yielded = True
79+
yield nx
80+
if yielded:
81+
raise StopIteration()
8482

85-
def chain(*fns):
83+
@curry
84+
def chain(fns, x):
8685
"""
8786
Compose a sequence of fns so that they apply to the expr sequentially
8887
"""
89-
def chain_brl(x):
90-
if not fns:
91-
yield x
92-
raise StopIteration()
93-
94-
head, tail = fns[0], fns[1:]
95-
for nx in head(x):
96-
for nnx in chain(*tail)(nx):
97-
yield nnx
88+
if not fns:
89+
yield x
90+
raise StopIteration()
9891

99-
return chain_brl
92+
head, tail = fns[0], fns[1:]
93+
for nx in head(x):
94+
for nnx in chain(tail)(nx):
95+
yield nnx
10096

10197

10298
@curry

strategies/branch/tests/test_core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_debug():
5151
assert '4' in log
5252

5353
def test_multiplex():
54-
brl = multiplex(posdec, branch5)
54+
brl = multiplex([posdec, branch5])
5555
assert set(brl(3)) == set([2])
5656
assert set(brl(7)) == set([6, 8])
5757
assert set(brl(5)) == set([4, 6])
@@ -76,11 +76,11 @@ def ident_if_even(x):
7676
assert set(brl(5)) == set([5])
7777

7878
def test_chain():
79-
assert list(chain()(2)) == [2] # identity
80-
assert list(chain(inc, inc)(2)) == [4]
81-
assert list(chain(branch5, inc)(4)) == [4]
82-
assert set(chain(branch5, inc)(5)) == set([5, 7])
83-
assert list(chain(inc, branch5)(5)) == [7]
79+
assert list(chain([], 2)) == [2] # identity
80+
assert list(chain([inc, inc])(2)) == [4]
81+
assert list(chain([branch5, inc])(4)) == [4]
82+
assert set(chain([branch5, inc])(5)) == set([5, 7])
83+
assert list(chain([inc, branch5])(5)) == [7]
8484

8585
def test_onaction():
8686
L = []
@@ -103,6 +103,6 @@ def bad(expr):
103103
raise ValueError()
104104
yield False
105105

106-
assert list(do_one(inc)(3)) == [4]
107-
assert list(do_one(inc, bad)(3)) == [4]
108-
assert list(do_one(inc, posdec)(3)) == [4]
106+
assert list(do_one([inc])(3)) == [4]
107+
assert list(do_one([inc, bad])(3)) == [4]
108+
assert list(do_one([inc, posdec])(3)) == [4]

0 commit comments

Comments
 (0)