Skip to content

Commit 6a176f8

Browse files
authored
Merge pull request #4 from Nymphium/multi-handler
🎉 add nulti-effect handler
2 parents 8557fd2 + f1c6d77 commit 6a176f8

File tree

3 files changed

+133
-47
lines changed

3 files changed

+133
-47
lines changed

spec/async_await.lua

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- https://github.com/ocamllabs/ocaml-effects-tutorial/blob/master/sources/solved/async_await.ml
22

33
local eff = require('src/eff')
4-
local Eff, perform, handler = eff.Eff, eff.perform, eff.handler
4+
local Eff, perform, handlers = eff.Eff, eff.perform, eff.handlers
55

66
local imut = require('spec/utils/imut')
77
local ref = require('spec/utils/ref')
@@ -16,15 +16,19 @@ local Done = function(a)
1616
return { a, cls = "done" }
1717
end
1818

19-
local AEff = Eff("AEff")
19+
local Async = Eff("Async")
2020
local async = function(f)
21-
return perform(AEff{ f, cls = "async" })
21+
return perform(Async(f))
2222
end
23+
24+
local Yield = Eff("Yield")
2325
local yield = function()
24-
return perform(AEff{ cls = "yield" })
26+
return perform(Yield())
2527
end
28+
29+
local Await = Eff("Await")
2630
local await = function(p)
27-
return perform(AEff{ p, cls = "await" })
31+
return perform(Await(p))
2832
end
2933

3034
-- queue
@@ -36,14 +40,13 @@ end
3640
local dequeue = function()
3741
local f = table.remove(q, 1)
3842
if f then
39-
local m = f()
40-
return m
43+
return f()
4144
end
4245
end
4346

4447
local run = function(main)
4548
local function fork(pr, main)
46-
return handler(AEff,
49+
return handlers(
4750
function(v)
4851
local pp = pr:get()
4952
local l
@@ -61,17 +64,17 @@ local run = function(main)
6164
pr(Done(v))
6265
return dequeue()
6366
end,
64-
function(k, c)
65-
if c.cls == "async" then
66-
local f = c[1]
67+
68+
{Async, function(k, f)
6769
local pr_ = ref(Waiting{})
6870
enqueue(function() return k(pr_) end)
6971
return fork(pr_, f)
70-
elseif c.cls == "yield" then
72+
end},
73+
{Yield, function(k)
7174
enqueue(function() return k() end)
7275
return dequeue()
73-
elseif c.cls == "await" then
74-
local p = c[1]
76+
end},
77+
{Await, function(k, p)
7578
local pp = p:get()
7679

7780
if pp.cls == "done" then
@@ -80,8 +83,7 @@ local run = function(main)
8083
p(Waiting(imut.cons(k, pp[1])))
8184
return dequeue()
8285
end
83-
end
84-
end)(main)
86+
end})(main)
8587
end
8688

8789
return fork(ref(Waiting{}), main)

spec/state2.lua

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,44 @@
11
-- https://github.com/ocamllabs/ocaml-effects-tutorial/blob/master/sources/solved/state2.ml
22

33
local eff = require('src/eff')
4-
local Eff, perform, handler = eff.Eff, eff.perform, eff.handler
4+
local Eff, perform, handlers = eff.Eff, eff.perform, eff.handlers
55

66
local imut = require('spec/utils/imut')
77

88
local State = function()
9-
local State = Eff("State")
9+
local Get = Eff("Get")
10+
local Put = Eff("Put")
11+
local History = Eff("History")
12+
1013
local get = function()
11-
return perform(State{ cls = "Get" })
14+
return perform(Get())
1215
end
1316
local put = function(v)
14-
return perform(State{ v, cls = "Put" })
17+
return perform(Put(v))
1518
end
1619
local history = function()
17-
return perform(State{ cls = "History" })
20+
return perform(History())
1821
end
1922

2023
local run = function(f, init)
21-
local comp = handler(State,
22-
function() return function() end end,
23-
function(k, c)
24-
if c.cls == "Get" then
25-
return function(s, h)
26-
return k(s)(s, h)
27-
end
28-
elseif c.cls == "Put" then
29-
return function(_, h)
30-
local s_ = c[1]
31-
return k()(s_, imut.cons(s_, h))
32-
end
33-
elseif c.cls == "History" then
34-
return function(s, h)
35-
return k(imut.rev(h))(s, imut.cp(h))
36-
end
37-
end
38-
end)(f)
24+
local comp = handlers(
25+
function() return function() end end,
26+
{Get, function(k)
27+
return function(s, h)
28+
return k(s)(s, h)
29+
end
30+
end},
31+
{Put, function(k, v)
32+
return function(_, h)
33+
return k()(v, imut.cons(v, h))
34+
end
35+
end},
36+
{History, function(k)
37+
return function(s, h)
38+
return k(imut.rev(h))(s, imut.cp(h))
39+
end
40+
end}
41+
)(f)
3942

4043
return comp(init, {})
4144
end

src/eff.lua

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
local create = coroutine.create
22
local resume = coroutine.resume
33
local yield = coroutine.yield
4-
local unpack = table.unpack or unpack
4+
local unpack0 = table.unpack or unpack
5+
6+
local unpack = function(t)
7+
if t and #t > 0 then
8+
return unpack0(t)
9+
end
10+
end
511

612
local Eff
713
do
@@ -59,6 +65,27 @@ local is_eff_obj = function(obj)
5965
return type(obj) == "table" and (obj.cls == Eff.cls or obj.cls == Resend.cls)
6066
end
6167

68+
local function get_effh(eff, effeffhs)
69+
eff = tostring(eff)
70+
71+
for i = 1, #effeffhs do
72+
if tostring(effeffhs[i][1]) == eff then
73+
return effeffhs[i][2]
74+
end
75+
end
76+
end
77+
78+
local function handle_error_message(r)
79+
if type(r) == "string" and
80+
(r:match("attempt to yield from outside a coroutine")
81+
or r:match("cannot resume dead coroutine"))
82+
then
83+
return error("continuation cannot be performed twice")
84+
else
85+
return error(r)
86+
end
87+
end
88+
6289
local handler
6390
handler = function(eff, vh, effh)
6491
local is_the_eff = function(it)
@@ -72,7 +99,7 @@ handler = function(eff, vh, effh)
7299
local continue
73100

74101
local rehandle = function(arg, k)
75-
return handler(eff, function(...) return continue(gr, ...) end, effh)(function()
102+
return handler(eff, function(args) return continue(gr, unpack(args)) end, effh)(function()
76103
return k(arg)
77104
end)
78105
end
@@ -108,14 +135,66 @@ handler = function(eff, vh, effh)
108135
continue = function(co, arg)
109136
local st, r = resume(co, arg)
110137
if not st then
111-
if type(r) == "string" and
112-
(r:match("attempt to yield from outside a coroutine")
113-
or r:match("cannot resume dead coroutine"))
114-
then
115-
return error("continuation cannot be performed twice")
138+
return handle_error_message(r)
139+
else
140+
return handle(r)
141+
end
142+
end
143+
144+
return continue(gr, nil)
145+
end
146+
end
147+
148+
local handlers
149+
handlers = function(vh, ...)
150+
local effeffhs = {...}
151+
152+
return function(th)
153+
local gr = create(th)
154+
155+
local handle
156+
local continue
157+
158+
local rehandles = function(arg, k)
159+
return handlers(function(...) return continue(gr, ...) end, unpack(effeffhs))(function()
160+
return k(arg)
161+
end)
162+
end
163+
164+
handle = function(r)
165+
if not is_eff_obj(r) then
166+
return vh(r)
167+
end
168+
169+
if r.cls == Eff.cls then
170+
local effh = get_effh(r.eff, effeffhs)
171+
if effh then
172+
return effh(function(arg)
173+
return continue(gr, arg)
174+
end, unpack(r.arg))
175+
else
176+
return Resend(r, function(arg)
177+
return continue(gr, arg)
178+
end)
179+
end
180+
elseif r.cls == Resend.cls then
181+
local effh = get_effh(r.eff, effeffhs)
182+
if effh then
183+
return effh(function(arg)
184+
return rehandles(arg, r.continue)
185+
end, unpack(r.arg))
116186
else
117-
return error(r)
187+
return Resend(r, function(arg)
188+
return rehandles(arg, r.continue)
189+
end)
118190
end
191+
end
192+
end
193+
194+
continue = function(co, arg)
195+
local st, r = resume(co, arg)
196+
if not st then
197+
return handle_error_message(r)
119198
else
120199
return handle(r)
121200
end
@@ -125,9 +204,11 @@ handler = function(eff, vh, effh)
125204
end
126205
end
127206

207+
128208
return {
129209
Eff = Eff,
130210
perform = yield,
131211
handler = handler,
212+
handlers = handlers
132213
}
133214

0 commit comments

Comments
 (0)