Skip to content

Commit cd43395

Browse files
wweicMarisaKirisame
authored andcommitted
[Relay] Add list update to prelude (apache#2866)
1 parent e01b6c0 commit cd43395

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

python/tvm/relay/prelude.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,25 @@ def define_list_nth(self):
6262
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
6363
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
6464

65+
def define_list_update(self):
66+
"""Defines a function to update the nth element of a list and return the updated list.
67+
68+
update(l, i, v) : list[a] -> nat -> a -> list[a]
69+
"""
70+
self.update = GlobalVar("update")
71+
a = TypeVar("a")
72+
l = Var("l", self.l(a))
73+
n = Var("n", self.nat())
74+
v = Var("v", a)
75+
76+
y = Var("y")
77+
78+
z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
79+
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
80+
self.cons(self.hd(l), self.update(self.tl(l), y, v)))
81+
82+
self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
83+
6584
def define_list_map(self):
6685
"""Defines a function for mapping a function over a list's
6786
elements. That is, map(f, l) returns a new list where
@@ -470,6 +489,7 @@ def __init__(self, mod):
470489
self.define_nat_add()
471490
self.define_list_length()
472491
self.define_list_nth()
492+
self.define_list_update()
473493
self.define_list_sum()
474494

475495
self.define_tree_adt()

tests/python/relay/test_adt.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
hd = p.hd
2727
tl = p.tl
2828
nth = p.nth
29+
update = p.update
2930
length = p.length
3031
map = p.map
3132
foldl = p.foldl
@@ -148,6 +149,23 @@ def test_nth():
148149

149150
assert got == expected
150151

152+
def test_update():
153+
expected = list(range(10))
154+
l = nil()
155+
# create zero initialized list
156+
for i in range(len(expected)):
157+
l = cons(build_nat(0), l)
158+
159+
# set value
160+
for i, v in enumerate(expected):
161+
l = update(l, build_nat(i), build_nat(v))
162+
163+
got = []
164+
for i in range(len(expected)):
165+
got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
166+
167+
assert got == expected
168+
151169
def test_length():
152170
a = relay.TypeVar("a")
153171
assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])

0 commit comments

Comments
 (0)