|
| 1 | +local co = coroutine |
| 2 | +local errors = require('plenary.errors') |
| 3 | +local traceback_error = errors.traceback_error |
| 4 | + |
| 5 | +local M = {} |
| 6 | + |
| 7 | +---@class Future |
| 8 | +---Something that will give a value when run |
| 9 | + |
| 10 | +---Executes a future with a callback when it is done |
| 11 | +---@param future Future: the future to execute |
| 12 | +---@param callback function: the callback to call when done |
| 13 | +local execute = function(future, callback) |
| 14 | + assert(type(future) == "function", "type error :: expected func") |
| 15 | + local thread = co.create(future) |
| 16 | + |
| 17 | + local step |
| 18 | + step = function(...) |
| 19 | + local res = {co.resume(thread, ...)} |
| 20 | + local stat = res[1] |
| 21 | + local ret = {select(2, unpack(res))} |
| 22 | + |
| 23 | + if not stat then |
| 24 | + error(string.format("The coroutine failed with this message: %s", ret[1])) |
| 25 | + end |
| 26 | + |
| 27 | + if co.status(thread) == "dead" then |
| 28 | + (callback or function() end)(unpack(ret)) |
| 29 | + else |
| 30 | + assert(#ret == 1, "expected a single return value") |
| 31 | + local returned_future = ret[1] |
| 32 | + assert(type(returned_future) == "function", "type error :: expected func") |
| 33 | + returned_future(step) |
| 34 | + end |
| 35 | + end |
| 36 | + |
| 37 | + step() |
| 38 | +end |
| 39 | + |
| 40 | +---Creates an async function with a callback style function. |
| 41 | +---@param func function: A callback style function to be converted. The last argument must be the callback. |
| 42 | +---@param argc number: The number of arguments of func. Must be included. |
| 43 | +---@return function: Returns an async function |
| 44 | +M.wrap = function(func, argc) |
| 45 | + if type(func) ~= "function" then |
| 46 | + traceback_error("type error :: expected func, got " .. type(func)) |
| 47 | + end |
| 48 | + |
| 49 | + if type(argc) ~= "number" and argc ~= "vararg" then |
| 50 | + traceback_error("expected argc to be a number or string literal 'vararg'") |
| 51 | + end |
| 52 | + |
| 53 | + return function(...) |
| 54 | + local params = {...} |
| 55 | + |
| 56 | + local function future(step) |
| 57 | + if step then |
| 58 | + if type(argc) == "number" then |
| 59 | + params[argc] = step |
| 60 | + else |
| 61 | + table.insert(params, step) -- change once not optional |
| 62 | + end |
| 63 | + return func(unpack(params)) |
| 64 | + else |
| 65 | + return co.yield(future) |
| 66 | + end |
| 67 | + end |
| 68 | + return future |
| 69 | + end |
| 70 | +end |
| 71 | + |
| 72 | +---Return a new future that when run will run all futures concurrently. |
| 73 | +---@param futures table: the futures that you want to join |
| 74 | +---@return Future: returns a future |
| 75 | +M.join = M.wrap(function(futures, step) |
| 76 | + local len = #futures |
| 77 | + local results = {} |
| 78 | + local done = 0 |
| 79 | + |
| 80 | + if len == 0 then |
| 81 | + return step(results) |
| 82 | + end |
| 83 | + |
| 84 | + for i, future in ipairs(futures) do |
| 85 | + assert(type(future) == "function", "type error :: future must be function") |
| 86 | + |
| 87 | + local callback = function(...) |
| 88 | + results[i] = {...} |
| 89 | + done = done + 1 |
| 90 | + if done == len then |
| 91 | + step(results) |
| 92 | + end |
| 93 | + end |
| 94 | + |
| 95 | + future(callback) |
| 96 | + end |
| 97 | +end, 2) |
| 98 | + |
| 99 | +---Returns a future that when run will select the first future that finishes |
| 100 | +---@param futures table: The future that you want to select |
| 101 | +---@return Future |
| 102 | +M.select = M.wrap(function(futures, step) |
| 103 | + local selected = false |
| 104 | + |
| 105 | + for _, future in ipairs(futures) do |
| 106 | + assert(type(future) == "function", "type error :: future must be function") |
| 107 | + |
| 108 | + local callback = function(...) |
| 109 | + if not selected then |
| 110 | + selected = true |
| 111 | + step(...) |
| 112 | + end |
| 113 | + end |
| 114 | + |
| 115 | + future(callback) |
| 116 | + end |
| 117 | +end, 2) |
| 118 | + |
| 119 | +---Use this to either run a future concurrently and then do something else |
| 120 | +---or use it to run a future with a callback in a non async context |
| 121 | +---@param future Future |
| 122 | +---@param callback function |
| 123 | +M.run = function(future, callback) |
| 124 | + future(callback or function() end) |
| 125 | +end |
| 126 | + |
| 127 | +---Same as run but runs multiple futures |
| 128 | +---@param futures table |
| 129 | +---@param callback function |
| 130 | +M.run_all = function(futures, callback) |
| 131 | + M.run(M.join(futures), callback) |
| 132 | +end |
| 133 | + |
| 134 | +---Await a future, yielding the current function |
| 135 | +---@param future Future |
| 136 | +---@return any: returns the result of the future when it is done |
| 137 | +M.await = function(future) |
| 138 | + assert(type(future) == "function", "type error :: expected function to await") |
| 139 | + return future(nil) |
| 140 | +end |
| 141 | + |
| 142 | +---Same as await but can await multiple futures. |
| 143 | +---If the futures have libuv leaf futures they will be run concurrently |
| 144 | +---@param futures table |
| 145 | +---@return table: returns a table of results that each future returned. Note that if the future returns multiple values they will be packed into a table. |
| 146 | +M.await_all = function(futures) |
| 147 | + assert(type(futures) == "table", "type error :: expected table") |
| 148 | + return M.await(M.join(futures)) |
| 149 | +end |
| 150 | + |
| 151 | +---suspend a coroutine |
| 152 | +M.suspend = co.yield |
| 153 | + |
| 154 | +---create a async scope |
| 155 | +M.scope = function(func) |
| 156 | + M.run(M.future(func)) |
| 157 | +end |
| 158 | + |
| 159 | +--- Future a :: a -> (a -> ()) |
| 160 | +--- turns this signature |
| 161 | +--- ... -> Future a |
| 162 | +--- into this signature |
| 163 | +--- ... -> () |
| 164 | +M.void = function(async_func) |
| 165 | + return function(...) |
| 166 | + async_func(...)(function() end) |
| 167 | + end |
| 168 | +end |
| 169 | + |
| 170 | +---creates an async function |
| 171 | +---@param func function |
| 172 | +---@return function: returns an async function |
| 173 | +M.async = function(func) |
| 174 | + if type(func) ~= "function" then |
| 175 | + traceback_error("type error :: expected func, got " .. type(func)) |
| 176 | + end |
| 177 | + |
| 178 | + return function(...) |
| 179 | + local args = {...} |
| 180 | + local function future(step) |
| 181 | + if step == nil then |
| 182 | + return func(unpack(args)) |
| 183 | + else |
| 184 | + execute(future, step) |
| 185 | + end |
| 186 | + end |
| 187 | + return future |
| 188 | + end |
| 189 | +end |
| 190 | + |
| 191 | +---creates a future |
| 192 | +---@param func function |
| 193 | +---@return Future |
| 194 | +M.future = function(func) |
| 195 | + return M.async(func)() |
| 196 | +end |
| 197 | + |
| 198 | +---An async function that when awaited will await the scheduler to be able to call the api. |
| 199 | +M.scheduler = M.wrap(vim.schedule, 1) |
| 200 | + |
| 201 | +return M |
0 commit comments