Skip to content

Commit 2a8eb6e

Browse files
committed
Record captured locals per lambda
1 parent afce50e commit 2a8eb6e

File tree

3 files changed

+125
-33
lines changed

3 files changed

+125
-33
lines changed

src/closure_conversion.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct ClosureConversionCtx{GraphType} <: AbstractLoweringContext
66
end
77

88
function add_lambda_local!(ctx::ClosureConversionCtx, id)
9-
push!(ctx.lambda_bindings.locals, id)
9+
init_lambda_binding(ctx.lambda_bindings, id)
1010
end
1111

1212
# Convert `ex` to `type` by calling `convert(type, ex)` when necessary.

src/linear_ir.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ function compile_conditional(ctx, ex, false_label)
399399
end
400400

401401
function add_lambda_local!(ctx::LinearIRContext, id)
402-
push!(ctx.lambda_bindings.locals, id)
402+
init_lambda_binding(ctx.lambda_bindings, id)
403403
end
404404

405405
# Lowering of exception handling must ensure that
@@ -941,11 +941,14 @@ function compile_lambda(outer_ctx, ex)
941941
end
942942
end
943943
# Sorting the lambda locals is required to remove dependence on Dict iteration order.
944-
for id in sort(collect(ex.lambda_bindings.locals))
945-
info = lookup_binding(ctx.bindings, id)
946-
@assert info.kind == :local
947-
push!(slots, Slot(info.name, :local, false))
948-
slot_rewrites[id] = length(slots)
944+
for (id, lbinfo) in sort(collect(pairs(ex.lambda_bindings.bindings)), by=first)
945+
if !lbinfo.is_captured
946+
info = lookup_binding(ctx.bindings, id)
947+
if info.kind == :local
948+
push!(slots, Slot(info.name, :local, false))
949+
slot_rewrites[id] = length(slots)
950+
end
951+
end
949952
end
950953
for (i,arg) in enumerate(children(static_parameters))
951954
@assert kind(arg) == K"BindingId"

src/scope_analysis.jl

Lines changed: 115 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,49 @@ function NameKey(ex::SyntaxTree)
123123
NameKey(ex.name_val, get(ex, :scope_layer, _lowering_internal_layer))
124124
end
125125

126-
struct CaptureInfo
126+
# Metadata about how a binding is used within some enclosing lambda
127+
struct LambdaBindingInfo
128+
is_captured::Bool
127129
is_read::Bool
128130
is_assigned::Bool
129131
is_called::Bool
130132
end
131133

134+
LambdaBindingInfo() = LambdaBindingInfo(false, false, false, false)
135+
136+
function LambdaBindingInfo(parent::LambdaBindingInfo;
137+
is_captured = nothing,
138+
is_read = nothing,
139+
is_assigned = nothing,
140+
is_called = nothing)
141+
LambdaBindingInfo(
142+
isnothing(is_captured) ? parent.is_captured : is_captured,
143+
isnothing(is_read) ? parent.is_read : is_read,
144+
isnothing(is_assigned) ? parent.is_assigned : is_assigned,
145+
isnothing(is_called) ? parent.is_called : is_called,
146+
)
147+
end
148+
132149
struct LambdaBindings
133-
# Local bindings within the lambda
134-
locals::Set{IdTag}
135-
captures::Dict{IdTag,CaptureInfo}
150+
# Bindings used within the lambda
151+
bindings::Dict{IdTag,LambdaBindingInfo}
136152
end
137153

138-
LambdaBindings() = LambdaBindings(Set{IdTag}(), Dict{IdTag,CaptureInfo}())
154+
function init_lambda_binding(binds::LambdaBindings, id; kws...)
155+
@assert !haskey(binds.bindings, id)
156+
binds.bindings[id] = LambdaBindingInfo(LambdaBindingInfo(); kws...)
157+
end
158+
159+
function update_lambda_binding!(binds::LambdaBindings, id; kws...)
160+
binfo = binds.bindings[id]
161+
binds.bindings[id] = LambdaBindingInfo(binfo; kws...)
162+
end
163+
164+
function update_lambda_binding!(ctx::AbstractLoweringContext, id; kws...)
165+
update_lambda_binding!(last(ctx.scope_stack).lambda_bindings, id; kws...)
166+
end
167+
168+
LambdaBindings() = LambdaBindings(Dict{IdTag,LambdaBindings}())
139169

140170

141171
struct ScopeInfo
@@ -230,17 +260,19 @@ function add_lambda_args(ctx, var_ids, args, args_kind)
230260
"static parameter name not distinct from function argument"
231261
throw(LoweringError(arg, msg))
232262
end
233-
var_ids[varkey] = init_binding(ctx, varkey, args_kind;
234-
is_nospecialize=getmeta(arg, :nospecialize, false))
263+
id = init_binding(ctx, varkey, args_kind;
264+
is_nospecialize=getmeta(arg, :nospecialize, false))
265+
var_ids[varkey] = id
235266
elseif ka != K"BindingId" && ka != K"Placeholder"
236267
throw(LoweringError(arg, "Unexpected lambda arg kind"))
237268
end
238269
end
239270
end
240271

241-
# Analyze identifier usage within a scope, adding all newly discovered
242-
# identifiers to ctx.bindings and returning a lookup table from identifier
243-
# names to their variable IDs
272+
# Analyze identifier usage within a scope
273+
# * Allocate a new binding for each identifier which the scope introduces.
274+
# * Record the identifier=>binding mapping in a lookup table
275+
# * Return a `ScopeInfo` with the mapping plus additional scope metadata
244276
function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
245277
lambda_args=nothing, lambda_static_parameters=nothing)
246278
parentscope = isempty(ctx.scope_stack) ? nothing : ctx.scope_stack[end]
@@ -251,8 +283,13 @@ function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
251283
assignments, locals, destructured_args, globals,
252284
used, used_bindings, alias_bindings = find_scope_vars(ex)
253285

254-
# Create new lookup table for variables in this scope which differ from the
255-
# parent scope.
286+
# Construct a mapping from identifiers to bindings
287+
#
288+
# This will contain a binding ID for each variable which is introduced by
289+
# the scope, including
290+
# * Explicit locals
291+
# * Explicit globals
292+
# * Implicit locals created by assignment
256293
var_ids = Dict{NameKey,IdTag}()
257294

258295
if !isnothing(lambda_args)
@@ -272,8 +309,9 @@ function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
272309
end
273310
elseif var_kind(ctx, varkey) === :static_parameter
274311
throw(LoweringError(e, "local variable name `$(varkey.name)` conflicts with a static parameter"))
312+
else
313+
var_ids[varkey] = init_binding(ctx, varkey, :local)
275314
end
276-
var_ids[varkey] = init_binding(ctx, varkey, :local)
277315
end
278316

279317
# Add explicit globals
@@ -353,24 +391,67 @@ function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
353391
end
354392
end
355393

356-
for varkey in used
357-
if lookup_var(ctx, varkey) === nothing
358-
# Add other newly discovered identifiers as globals
359-
init_binding(ctx, varkey, :global)
360-
end
361-
end
362-
394+
#--------------------------------------------------
395+
# At this point we've discovered all the bindings defined in this scope and
396+
# added them to `var_ids`.
397+
#
398+
# Next we record information about how the new bindings relate to the
399+
# enclosing lambda
400+
# * All non-globals are recorded (kind :local and :argument will later be turned into slots)
401+
# * Captured variables are detected and recorded
363402
lambda_bindings = is_outer_lambda_scope ? LambdaBindings() : parentscope.lambda_bindings
403+
364404
for id in values(var_ids)
365-
vk = var_kind(ctx, id)
366-
if vk === :local
367-
push!(lambda_bindings.locals, id)
405+
binfo = lookup_binding(ctx, id)
406+
if !binfo.is_ssa && binfo.kind !== :global
407+
init_lambda_binding(lambda_bindings, id)
368408
end
369409
end
410+
411+
# FIXME: This assumes used bindings are internal to the lambda and cannot
412+
# be from the environment, and also assumes they are assigned. That's
413+
# correct for now but in general we should go by the same code path that
414+
# identifiers do.
370415
for id in used_bindings
371-
info = lookup_binding(ctx, id)
372-
if !info.is_ssa && info.kind == :local
373-
push!(lambda_bindings.locals, id)
416+
binfo = lookup_binding(ctx, id)
417+
if !binfo.is_ssa && binfo.kind !== :global
418+
if !haskey(lambda_bindings.bindings, id)
419+
init_lambda_binding(lambda_bindings, id, is_read=true, is_assigned=true)
420+
end
421+
end
422+
end
423+
424+
for varkey in used
425+
id = haskey(var_ids, varkey) ? var_ids[varkey] : lookup_var(ctx, varkey)
426+
if id === nothing
427+
# Identifiers which are used but not defined in some scope are
428+
# newly discovered global bindings
429+
init_binding(ctx, varkey, :global)
430+
elseif !in_toplevel_thunk
431+
binfo = lookup_binding(ctx, id)
432+
if binfo.kind !== :global
433+
if !haskey(lambda_bindings.bindings, id)
434+
# Used vars from a scope *outside* the current lambda are captured
435+
init_lambda_binding(lambda_bindings, id, is_captured=true, is_read=true)
436+
else
437+
update_lambda_binding!(lambda_bindings, id, is_read=true)
438+
end
439+
end
440+
end
441+
end
442+
443+
if !in_toplevel_thunk
444+
for (varkey,_) in assignments
445+
id = haskey(var_ids, varkey) ? var_ids[varkey] : lookup_var(ctx, varkey)
446+
binfo = lookup_binding(ctx, id)
447+
if binfo.kind !== :global
448+
if !haskey(lambda_bindings.bindings, id)
449+
# Assigned vars from a scope *outside* the current lambda are captured
450+
init_lambda_binding(lambda_bindings, id, is_captured=true, is_assigned=true)
451+
else
452+
update_lambda_binding!(lambda_bindings, id, is_assigned=true)
453+
end
454+
end
374455
end
375456
end
376457

@@ -410,6 +491,14 @@ function maybe_update_bindings!(ctx, ex)
410491
throw(LoweringError(ex, "unsupported `const` declaration on local variable"))
411492
end
412493
update_binding!(ctx, id; is_const=true)
494+
elseif k == K"call"
495+
name = ex[1]
496+
if kind(name) == K"BindingId"
497+
id = name.var_id
498+
if haskey(last(ctx.scope_stack).lambda_bindings.bindings, id)
499+
update_lambda_binding!(ctx, id, is_called=true)
500+
end
501+
end
413502
end
414503
nothing
415504
end

0 commit comments

Comments
 (0)