-
-
Notifications
You must be signed in to change notification settings - Fork 193
/
Copy pathasync.lua
215 lines (183 loc) · 5.53 KB
/
async.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
local M = {}
--- @class Gitsigns.AsyncTask
--- @field _current Gitsigns.AsyncTask
local Task = {}
-- Handle for an object currently running on the event loop.
-- The coroutine is paused while this is active.
-- Must provide methods close() and is_closing()
--
-- Handle gets updated on each call to a wrapped functions, so provide access
-- to it via a proxy
--- Store all the async threads in a weak table so we don't prevent them from
--- being garbage collected
--- @type table<thread,Gitsigns.AsyncTask>
local handles = setmetatable({}, { __mode = 'k' })
--- Returns whether the current execution context is async.
local function running()
local current = coroutine.running()
return current and handles[current] ~= nil
end
--- @param handle any
--- @return boolean
local function is_Async_T(handle)
return handle
and type(handle) == 'table'
and vim.is_callable(handle.close)
and vim.is_callable(handle.is_closing)
end
--- Analogous to uv.close
--- @param cb function
function Task:close(cb)
-- Close anything running on the event loop
if self._current and not self._current:is_closing() then
self._current:close(cb)
end
end
--- @param co thread
--- @return Gitsigns.AsyncTask
function Task.new(co)
local handle = setmetatable({}, { __index = Task })
handles[co] = handle
return handle
end
--- Analogous to uv.is_closing
--- @return boolean
function Task:is_closing()
return self._current and self._current:is_closing()
end
--- @param func function
--- @param callback? fun(...: any)
--- @param ... any
--- @return Gitsigns.AsyncTask
local function run(func, callback, ...)
local co = coroutine.create(func)
local handle = Task.new(co)
local function step(...)
local ret = { coroutine.resume(co, ...) }
local stat = ret[1]
if not stat then
local co_err = ret[2] --- @type string
error(debug.traceback(co, string.format('The async coroutine failed: %s', co_err)))
elseif coroutine.status(co) == 'dead' then
if callback then
callback(unpack(ret, 2, table.maxn(ret)))
end
else
--- @type fun(...: any): any
local fn = ret[2]
assert(type(fn) == 'function', 'type error :: expected func')
local r = fn(step)
if is_Async_T(r) then
--- @cast r Gitsigns.AsyncTask
handle._current = r
end
end
end
step(...)
return handle
end
--- Must be called from an async context.
--- @param argc integer
--- @param func function
--- @param ... any
--- @return any ...
function M.await(argc, func, ...)
assert(running(), 'Not in an async context')
local args, nargs = { ... }, select('#', ...)
-- Always run the wrapped functions in xpcall and re-raise the error in the
-- coroutine. This makes pcall work as normal.
local stat, ret = coroutine.yield(function(callback)
args[argc] = function(...)
callback(true, { ... })
end
nargs = math.max(nargs, argc)
xpcall(func, function(err)
callback(false, { err, debug.traceback() })
end, unpack(args, 1, nargs))
end)
if not stat then
--- @type string, string
local err, traceback = ret[1], ret[2]
error(string.format('Wrapped function failed: %s\n%s', err, traceback))
end
return unpack(ret, 1, table.maxn(ret))
end
--- @param argc integer
--- @param func function
--- @param ... any
--- @return any ...
function M.wait_sync(argc, func, ...)
local nargs, args = select('#', ...), { ... }
local done = false
local ret = nil
args[argc] = function(...)
ret = { ... }
done = true
end
nargs = math.max(nargs, argc)
func(unpack(args, 1, nargs))
vim.wait(10000, function()
return done
end)
if not done then
error('Timeout waiting for async function')
end
assert(ret)
return unpack(ret, 1, table.maxn(ret))
end
--- @generic R
--- @param argc integer The number of arguments of func. Must be included.
--- @param func fun(): R
--- @param ... any
--- @return R
function M.sync(argc, func, ...)
return M.wait_sync(argc + 1, M.create(argc, func), ...)
end
--- Creates an async function with a callback style function.
--- @param argc integer The number of arguments of func. Must be included.
--- @param func function A callback style function to be converted. The last argument must be the callback.
--- @return function: Returns an async function
function M.awrap(argc, func)
assert(type(func) == 'function')
assert(type(argc) == 'number')
return function(...)
return M.await(argc, func, ...)
end
end
--- create([argc, ] func)
---
--- Use this to create a function which executes in an async context but
--- called from a non-async context. Inherently this cannot return anything
--- since it is non-blocking
---
--- If argc is not provided, then the created async function cannot be continued
---
--- @generic F: function
--- @param argc_or_func F|integer
--- @param func? F
--- @return F
function M.create(argc_or_func, func)
local argc --- @type integer
if type(argc_or_func) == 'function' then
assert(not func)
func = argc_or_func
elseif type(argc_or_func) == 'number' then
assert(type(func) == 'function')
argc = argc_or_func
end
--- @cast func function
--- @param ... any
--- @return any ...
return function(...)
local callback = argc and select(argc + 1, ...) or nil
assert(not callback or type(callback) == 'function')
return run(func, callback, unpack({ ... }, 1, argc))
end
end
--- An async function that when called will yield to the Neovim scheduler to be
--- able to call the API.
M.scheduler = M.awrap(1, vim.schedule)
function M.run(func, ...)
return run(func, nil, ...)
end
return M