Skip to content

Commit 5f1552b

Browse files
committed
use a compact bindings, remove the offset limitation
1 parent 9675907 commit 5f1552b

File tree

2 files changed

+86
-75
lines changed

2 files changed

+86
-75
lines changed

src/tapedfunction.jl

Lines changed: 83 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
## Instruction and TapedFunction
22

3+
# TODO add an introduction of the impl, caveats
4+
# caveates:
5+
# 1. global references are cached out (info msg)
6+
# 2. QuoteNode is evaluated at tape-recording time (compile time) (info msg)
7+
# 3. One allocation in each Instruction
8+
39
abstract type AbstractInstruction end
410
const RawTape = Vector{AbstractInstruction}
511

@@ -18,6 +24,7 @@ mutable struct TapedFunction{F, TapeType}
1824
tape::TapeType
1925
counter::Int
2026
bindings::Bindings
27+
slots::Dict{Int, Int} # slots indices in bindings
2128
retval::Int # 0 indicates the function has not returned
2229

2330
function TapedFunction{F, T}(f::F, args...; cache=false) where {F, T}
@@ -31,9 +38,9 @@ mutable struct TapedFunction{F, TapeType}
3138
return tf
3239
end
3340
ir = _infer(f, args_type)
34-
bindings, tape = translate!(RawTape(), ir)
41+
bindings, slots, tape = translate!(RawTape(), ir)
3542

36-
tf = new{F, T}(f, length(args), ir, tape, 1, bindings, 0)
43+
tf = new{F, T}(f, length(args), ir, tape, 1, bindings, slots, 0)
3744
TRCache[cache_key] = tf # set cache
3845
return tf
3946
end
@@ -43,7 +50,7 @@ mutable struct TapedFunction{F, TapeType}
4350

4451
function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1}
4552
new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape,
46-
tf.counter, tf.bindings, 0)
53+
tf.counter, tf.bindings, tf.slots, 0)
4754
end
4855

4956
TapedFunction(tf::TapedFunction{F, T}) where {F, T} = TapedFunction{F, T}(tf)
@@ -104,10 +111,10 @@ function (tf::TapedFunction)(args...; callback=nothing, continuation=false)
104111
# set args
105112
if tf.counter <= 1
106113
# The first slot in `bindings` is assumed to be `tf.func`.
107-
_update_var!(tf, 1, tf.func)
114+
haskey(tf.slots, 1) && _update_var!(tf, tf.slots[1], tf.func)
108115
for i in 1:length(args) # the subsequent slots are arguments
109116
slot = i + 1
110-
_update_var!(tf, slot, args[i])
117+
haskey(tf.slots, slot) && _update_var!(tf, tf.slots[slot], args[i])
111118
end
112119
end
113120

@@ -218,128 +225,132 @@ end
218225

219226

220227
## Translation: CodeInfo -> Tape
221-
#=
222-
Now, we use a Vector{Any} as the bindings to store all the
223-
variables used in a taped function. These variables are categorized
224-
into 3 kinds: 1. slot values, 2. literal data, 3. ssa values.
225-
226-
- bindings[1:CNT_SLOT-1] is for slot values
227-
- bindings[CNT_SLOT] holds the current used maximum literal index
228-
- bindings[CNT_SLOT+1:CNT_SLOT+CNT_LITE] is for literal data
229-
- bindings[CNT_SLOT+CNT_LITE+1:end] is for ssa values
230-
=#
231-
const CNT_SLOT = 200
232-
const CNT_LITE = 1000
233-
const OFFSET_VAR = CNT_SLOT + CNT_LITE
234-
235-
function bind_var!(var_literal, bindings::Bindings, ir::Core.CodeInfo) # for literal constants
236-
last_idx = bindings[CNT_SLOT]
237-
idx = last_idx + 1
238-
@assert idx < OFFSET_VAR
239-
bindings[CNT_SLOT] = idx
240-
bindings[idx] = var_literal
228+
229+
struct TempBindings
230+
data::Bindings
231+
book::Dict{Any, Int}
232+
end
233+
234+
function bind_var!(var_literal, tbind::TempBindings, ir::Core.CodeInfo)
235+
# for literal constants
236+
push!(tbind.data, var_literal)
237+
idx = length(tbind.data)
241238
return idx
242239
end
243-
bind_var!(var::GlobalRef, bindings::Bindings, ir::Core.CodeInfo) =
244-
bind_var!(getproperty(var.mod, var.name), bindings, ir)
245-
bind_var!(var::QuoteNode, bindings::Bindings, ir::Core.CodeInfo) =
246-
bind_var!(eval(var), bindings, ir) # staging out value of `var::QuoteNode`
247-
bind_var!(var::Core.TypedSlot, bindings::Bindings, ir::Core.CodeInfo) =
248-
(@assert var.id < CNT_SLOT; bind_var!(var.id, bindings, ir.slottypes[var.id]))
249-
bind_var!(var::Core.SlotNumber, bindings::Bindings, ir::Core.CodeInfo) =
250-
(@assert var.id < CNT_SLOT; bind_var!(var.id, bindings, ir.slottypes[var.id]))
251-
bind_var!(var::Core.SSAValue, bindings::Bindings, ir::Core.CodeInfo) =
252-
bind_var!(var.id + OFFSET_VAR, bindings, ir.ssavaluetypes[var.id])
253-
bind_var!(var::Int, boxes::Bindings, c::Core.Const) =
254-
bind_var!(var, boxes, _loose_type(Type{c.val}))
255-
bind_var!(var::Int, boxes::Bindings, c::Core.PartialStruct) =
256-
bind_var!(var, boxes, _loose_type(c.typ))
257-
function bind_var!(var::Int, bindings::Bindings, ::Type{T}) where T
258-
# here var is the unified index
259-
var > length(bindings) && resize!(bindings, var + 10)
260-
return var
240+
function bind_var!(var::GlobalRef, tbind::TempBindings, ir::Core.CodeInfo)
241+
in(var.mod, (Base, Core)) ||
242+
@info "evaluating GlobalRef $var at compile time"
243+
bind_var!(getproperty(var.mod, var.name), tbind, ir)
244+
end
245+
function bind_var!(var::QuoteNode, tbind::TempBindings, ir::Core.CodeInfo)
246+
@info "evaluating QuoteNode $var at compile time"
247+
bind_var!(eval(var), tbind, ir)
248+
end
249+
function bind_var!(var::Core.TypedSlot, tbind::TempBindings, ir::Core.CodeInfo)
250+
get!(tbind.book, var, allocate_binding!(var, tbind, ir.slottypes[var.id]))
251+
end
252+
function bind_var!(var::Core.SlotNumber, tbind::TempBindings, ir::Core.CodeInfo)
253+
get!(tbind.book, var, allocate_binding!(var, tbind, ir.slottypes[var.id]))
254+
end
255+
function bind_var!(var::Core.SSAValue, tbind::TempBindings, ir::Core.CodeInfo)
256+
get!(tbind.book, var, allocate_binding!(var, tbind, ir.ssavaluetypes[var.id]))
257+
end
258+
259+
allocate_binding!(var, tbind::TempBindings, c::Core.Const) =
260+
allocate_binding!(var, tbind, _loose_type(Type{c.val}))
261+
allocate_binding!(var, tbind::TempBindings, c::Core.PartialStruct) =
262+
allocate_binding!(var, tbind, _loose_type(c.typ))
263+
function allocate_binding!(var, tbind::TempBindings, ::Type{T}) where T
264+
# we may use the type info (T) here
265+
push!(tbind.data, nothing)
266+
idx = length(tbind.data)
267+
return idx
261268
end
262269

263270
function translate!(tape::RawTape, ir::Core.CodeInfo)
264271
bindings = Bindings()
265-
resize!(bindings, OFFSET_VAR + 10)
266-
bindings[CNT_SLOT] = CNT_SLOT
272+
bcache = Dict{Any, Int}()
273+
tbind = TempBindings(bindings, bcache)
274+
slots = Dict{Int, Int}()
267275

268276
for (idx, line) in enumerate(ir.code)
269277
isa(line, Core.Const) && (line = line.val) # unbox Core.Const
270278
isconst = isa(ir.ssavaluetypes[idx], Core.Const)
271-
ins = translate!!(Core.SSAValue(idx), line, bindings, isconst, ir)
279+
ins = translate!!(Core.SSAValue(idx), line, tbind, isconst, ir)
272280
push!(tape, ins)
273281
end
274-
return (bindings, tape)
282+
for (k, v) in bcache
283+
isa(k, Union{Core.TypedSlot, Core.SlotNumber}) && (slots[k.id] = v)
284+
end
285+
return (bindings, slots, tape)
275286
end
276287

277288
const IRVar = Union{Core.SSAValue, Core.SlotNumber}
278289

279-
function _const_instruction(var::IRVar, v, bindings::Bindings, ir)
290+
function _const_instruction(var::IRVar, v, tbind::TempBindings, ir)
280291
if isa(var, Core.SSAValue)
281-
box = bind_var!(var, bindings, ir)
282-
bindings[box] = v
292+
box = bind_var!(var, tbind, ir)
293+
tbind.data[box] = v
283294
return NOOPInstruction()
284295
end
285-
return Instruction(identity, (bind_var!(v, bindings, ir),), bind_var!(var, bindings, ir))
296+
return Instruction(identity, (bind_var!(v, tbind, ir),), bind_var!(var, tbind, ir))
286297
end
287298

288299
function translate!!(var::IRVar, line::Core.NewvarNode,
289-
bindings::Bindings, isconst::Bool, @nospecialize(ir))
300+
tbind::TempBindings, isconst::Bool, @nospecialize(ir))
290301
# use a no-op to ensure the 1-to-1 mapping from ir.code to instructions on tape.
291302
return NOOPInstruction()
292303
end
293304

294305
function translate!!(var::IRVar, line::GlobalRef,
295-
bindings::Bindings, isconst::Bool, ir)
306+
tbind::TempBindings, isconst::Bool, ir)
296307
if isconst
297308
v = ir.ssavaluetypes[var.id].val
298-
return _const_instruction(var, v, bindings, ir)
309+
return _const_instruction(var, v, tbind, ir)
299310
end
300311
func() = getproperty(line.mod, line.name)
301-
return Instruction(func, (), bind_var!(var, bindings, ir))
312+
return Instruction(func, (), bind_var!(var, tbind, ir))
302313
end
303314

304315
function translate!!(var::IRVar, line::Core.SlotNumber,
305-
bindings::Bindings, isconst::Bool, ir)
316+
tbind::TempBindings, isconst::Bool, ir)
306317
if isconst
307318
v = ir.ssavaluetypes[var.id].val
308-
return _const_instruction(var, v, bindings, ir)
319+
return _const_instruction(var, v, tbind, ir)
309320
end
310321
func = identity
311-
input = (bind_var!(line, bindings, ir),)
312-
output = bind_var!(var, bindings, ir)
322+
input = (bind_var!(line, tbind, ir),)
323+
output = bind_var!(var, tbind, ir)
313324
return Instruction(func, input, output)
314325
end
315326

316327
function translate!!(var::IRVar, line::Core.TypedSlot,
317-
bindings::Bindings, isconst::Bool, ir)
318-
input_box = bind_var!(Core.SlotNumber(line.id), bindings, ir)
319-
return Instruction(identity, (input_box,), bind_var!(var, bindings, ir))
328+
tbind::TempBindings, isconst::Bool, ir)
329+
input_box = bind_var!(Core.SlotNumber(line.id), tbind, ir)
330+
return Instruction(identity, (input_box,), bind_var!(var, tbind, ir))
320331
end
321332

322333
function translate!!(var::IRVar, line::Core.GotoIfNot,
323-
bindings::Bindings, isconst::Bool, ir)
324-
cond = bind_var!(line.cond, bindings, ir)
334+
tbind::TempBindings, isconst::Bool, ir)
335+
cond = bind_var!(line.cond, tbind, ir)
325336
return CondGotoInstruction(cond, line.dest)
326337
end
327338

328339
function translate!!(var::IRVar, line::Core.GotoNode,
329-
bindings::Bindings, isconst::Bool, @nospecialize(ir))
340+
tbind::TempBindings, isconst::Bool, @nospecialize(ir))
330341
return GotoInstruction(line.label)
331342
end
332343

333344
function translate!!(var::IRVar, line::Core.ReturnNode,
334-
bindings::Bindings, isconst::Bool, ir)
335-
return ReturnInstruction(bind_var!(line.val, bindings, ir))
345+
tbind::TempBindings, isconst::Bool, ir)
346+
return ReturnInstruction(bind_var!(line.val, tbind, ir))
336347
end
337348

338349
_canbeoptimized(v) = isa(v, DataType) || isprimitivetype(typeof(v))
339350
function translate!!(var::IRVar, line::Expr,
340-
bindings::Bindings, isconst::Bool, ir::Core.CodeInfo)
351+
tbind::TempBindings, isconst::Bool, ir::Core.CodeInfo)
341352
head = line.head
342-
_bind_fn = (x) -> bind_var!(x, bindings, ir)
353+
_bind_fn = (x) -> bind_var!(x, tbind, ir)
343354
if head === :new
344355
args = map(_bind_fn, line.args)
345356
return Instruction(__new__, args |> Tuple, _bind_fn(var))
@@ -349,7 +360,7 @@ function translate!!(var::IRVar, line::Expr,
349360
# optimised function calls, we will evaluate the function at compile-time and cache results.
350361
if isconst
351362
v = ir.ssavaluetypes[var.id].val
352-
_canbeoptimized(v) && return _const_instruction(var, v, bindings, ir)
363+
_canbeoptimized(v) && return _const_instruction(var, v, tbind, ir)
353364
end
354365
args = map(_bind_fn, line.args)
355366
# args[1] is the function
@@ -367,7 +378,7 @@ function translate!!(var::IRVar, line::Expr,
367378
lhs = line.args[1]
368379
rhs = line.args[2] # the right hand side, maybe a Expr, or a var, or ...
369380
if Meta.isexpr(rhs, (:new, :call))
370-
return translate!!(lhs, rhs, bindings, false, ir)
381+
return translate!!(lhs, rhs, tbind, false, ir)
371382
else # rhs is a single value
372383
if isconst
373384
v = ir.ssavaluetypes[var.id].val
@@ -381,7 +392,7 @@ function translate!!(var::IRVar, line::Expr,
381392
end
382393
end
383394

384-
function translate!!(var, line, bindings, ir)
395+
function translate!!(var, line, tbind, ir)
385396
@error "Unknown IR code: " typeof(var) var typeof(line) line
386397
throw(ErrorException("Unknown IR code"))
387398
end

src/tapedtask.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ function Base.copy(t::TapedTask; args=())
171171
args
172172
else
173173
if t.tf.counter > 1
174-
# the task is running, we find the
175-
# real args from the copied bindings
176-
tf.bindings[2:(length(t.args) + 1)]
174+
# the task is running, we can find the used args with tf.slots and tf.bindinds,
175+
# but we may lost the unused args and it doesn't matter
176+
[]
177177
else
178178
# the task is not started yet, but no args is given
179179
tape_copy.(t.args)

0 commit comments

Comments
 (0)