Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions spec/async_await.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- https://github.com/ocamllabs/ocaml-effects-tutorial/blob/master/sources/solved/async_await.ml

local eff = require('src/eff')
local Eff, perform, handler = eff.Eff, eff.perform, eff.handler
local Eff, perform, handlers = eff.Eff, eff.perform, eff.handlers

local imut = require('spec/utils/imut')
local ref = require('spec/utils/ref')
Expand All @@ -16,15 +16,19 @@ local Done = function(a)
return { a, cls = "done" }
end

local AEff = Eff("AEff")
local Async = Eff("Async")
local async = function(f)
return perform(AEff{ f, cls = "async" })
return perform(Async(f))
end

local Yield = Eff("Yield")
local yield = function()
return perform(AEff{ cls = "yield" })
return perform(Yield())
end

local Await = Eff("Await")
local await = function(p)
return perform(AEff{ p, cls = "await" })
return perform(Await(p))
end

-- queue
Expand All @@ -36,14 +40,13 @@ end
local dequeue = function()
local f = table.remove(q, 1)
if f then
local m = f()
return m
return f()
end
end

local run = function(main)
local function fork(pr, main)
return handler(AEff,
return handlers(
function(v)
local pp = pr:get()
local l
Expand All @@ -61,17 +64,17 @@ local run = function(main)
pr(Done(v))
return dequeue()
end,
function(k, c)
if c.cls == "async" then
local f = c[1]

{Async, function(k, f)
local pr_ = ref(Waiting{})
enqueue(function() return k(pr_) end)
return fork(pr_, f)
elseif c.cls == "yield" then
end},
{Yield, function(k)
enqueue(function() return k() end)
return dequeue()
elseif c.cls == "await" then
local p = c[1]
end},
{Await, function(k, p)
local pp = p:get()

if pp.cls == "done" then
Expand All @@ -80,8 +83,7 @@ local run = function(main)
p(Waiting(imut.cons(k, pp[1])))
return dequeue()
end
end
end)(main)
end})(main)
end

return fork(ref(Waiting{}), main)
Expand Down
49 changes: 26 additions & 23 deletions spec/state2.lua
Original file line number Diff line number Diff line change
@@ -1,41 +1,44 @@
-- https://github.com/ocamllabs/ocaml-effects-tutorial/blob/master/sources/solved/state2.ml

local eff = require('src/eff')
local Eff, perform, handler = eff.Eff, eff.perform, eff.handler
local Eff, perform, handlers = eff.Eff, eff.perform, eff.handlers

local imut = require('spec/utils/imut')

local State = function()
local State = Eff("State")
local Get = Eff("Get")
local Put = Eff("Put")
local History = Eff("History")

local get = function()
return perform(State{ cls = "Get" })
return perform(Get())
end
local put = function(v)
return perform(State{ v, cls = "Put" })
return perform(Put(v))
end
local history = function()
return perform(State{ cls = "History" })
return perform(History())
end

local run = function(f, init)
local comp = handler(State,
function() return function() end end,
function(k, c)
if c.cls == "Get" then
return function(s, h)
return k(s)(s, h)
end
elseif c.cls == "Put" then
return function(_, h)
local s_ = c[1]
return k()(s_, imut.cons(s_, h))
end
elseif c.cls == "History" then
return function(s, h)
return k(imut.rev(h))(s, imut.cp(h))
end
end
end)(f)
local comp = handlers(
function() return function() end end,
{Get, function(k)
return function(s, h)
return k(s)(s, h)
end
end},
{Put, function(k, v)
return function(_, h)
return k()(v, imut.cons(v, h))
end
end},
{History, function(k)
return function(s, h)
return k(imut.rev(h))(s, imut.cp(h))
end
end}
)(f)

return comp(init, {})
end
Expand Down
97 changes: 89 additions & 8 deletions src/eff.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
local create = coroutine.create
local resume = coroutine.resume
local yield = coroutine.yield
local unpack = table.unpack or unpack
local unpack0 = table.unpack or unpack

local unpack = function(t)
if t and #t > 0 then
return unpack0(t)
end
end

local Eff
do
Expand Down Expand Up @@ -59,6 +65,27 @@ local is_eff_obj = function(obj)
return type(obj) == "table" and (obj.cls == Eff.cls or obj.cls == Resend.cls)
end

local function get_effh(eff, effeffhs)
eff = tostring(eff)

for i = 1, #effeffhs do
if tostring(effeffhs[i][1]) == eff then
return effeffhs[i][2]
end
end
end

local function handle_error_message(r)
if type(r) == "string" and
(r:match("attempt to yield from outside a coroutine")
or r:match("cannot resume dead coroutine"))
then
return error("continuation cannot be performed twice")
else
return error(r)
end
end

local handler
handler = function(eff, vh, effh)
local is_the_eff = function(it)
Expand All @@ -72,7 +99,7 @@ handler = function(eff, vh, effh)
local continue

local rehandle = function(arg, k)
return handler(eff, function(...) return continue(gr, ...) end, effh)(function()
return handler(eff, function(args) return continue(gr, unpack(args)) end, effh)(function()
return k(arg)
end)
end
Expand Down Expand Up @@ -108,14 +135,66 @@ handler = function(eff, vh, effh)
continue = function(co, arg)
local st, r = resume(co, arg)
if not st then
if type(r) == "string" and
(r:match("attempt to yield from outside a coroutine")
or r:match("cannot resume dead coroutine"))
then
return error("continuation cannot be performed twice")
return handle_error_message(r)
else
return handle(r)
end
end

return continue(gr, nil)
end
end

local handlers
handlers = function(vh, ...)
local effeffhs = {...}

return function(th)
local gr = create(th)

local handle
local continue

local rehandles = function(arg, k)
return handlers(function(...) return continue(gr, ...) end, unpack(effeffhs))(function()
return k(arg)
end)
end

handle = function(r)
if not is_eff_obj(r) then
return vh(r)
end

if r.cls == Eff.cls then
local effh = get_effh(r.eff, effeffhs)
if effh then
return effh(function(arg)
return continue(gr, arg)
end, unpack(r.arg))
else
return Resend(r, function(arg)
return continue(gr, arg)
end)
end
elseif r.cls == Resend.cls then
local effh = get_effh(r.eff, effeffhs)
if effh then
return effh(function(arg)
return rehandles(arg, r.continue)
end, unpack(r.arg))
else
return error(r)
return Resend(r, function(arg)
return rehandles(arg, r.continue)
end)
end
end
end

continue = function(co, arg)
local st, r = resume(co, arg)
if not st then
return handle_error_message(r)
else
return handle(r)
end
Expand All @@ -125,9 +204,11 @@ handler = function(eff, vh, effh)
end
end


return {
Eff = Eff,
perform = yield,
handler = handler,
handlers = handlers
}