Skip to content

Commit

Permalink
Compiler: turn proc pointer into proc literal in some cases (crystal-…
Browse files Browse the repository at this point in the history
…lang#9824)

* Compiler: turn proc pointer into proc literal in some cases

* Correctly detect closure using ProcPointer
  • Loading branch information
asterite committed Oct 23, 2020
1 parent 84aab66 commit 442952f
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 26 deletions.
34 changes: 34 additions & 0 deletions spec/compiler/codegen/proc_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -847,4 +847,38 @@ describe "Code gen: proc" do
end
))
end

it "closures var on ->var.call (#8584)" do
run(%(
def bar(x)
x
end
struct Foo
def initialize
@value = 1
end
def value
bar(@value)
@value
end
end
def get_proc_a
foo = Foo.new
->foo.value
end
def get_proc_b
foo = Foo.new
->{ foo.value }
end
proc_a = get_proc_a
proc_b = get_proc_b
proc_b.call
proc_a.call
)).to_i.should eq(1)
end
end
2 changes: 1 addition & 1 deletion spec/compiler/semantic/closure_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ describe "Semantic: closure" do
foo = Foo.new
LibC.foo(->foo.bar)
),
"can't send closure to C function (closured vars: self)"
"can't send closure to C function (closured vars: foo)"
end

it "errors if sending closured proc pointer to C (3)" do
Expand Down
9 changes: 1 addition & 8 deletions src/compiler/crystal/codegen/codegen.cr
Original file line number Diff line number Diff line change
Expand Up @@ -592,14 +592,7 @@ module Crystal

if obj = node.obj
accept obj

# If obj is a primitive like an integer we need to pass
# the variable as is (without loading it)
if obj.is_a?(Var) && obj.type.is_a?(PrimitiveType)
call_self = context.vars[obj.name].pointer
else
call_self = @last
end
call_self = @last
elsif owner.passed_as_self?
call_self = llvm_self
end
Expand Down
10 changes: 10 additions & 0 deletions src/compiler/crystal/semantic/ast.cr
Original file line number Diff line number Diff line change
Expand Up @@ -783,4 +783,14 @@ module Crystal
Unreachable.new
end
end

class ProcLiteral
# If this ProcLiteral was created from expanding a ProcPointer,
# this holds the reference to it.
property proc_pointer : ProcPointer?
end

class ProcPointer
property expanded : ASTNode?
end
end
4 changes: 4 additions & 0 deletions src/compiler/crystal/semantic/bindings.cr
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ module Crystal
property! call : Call

def map_type(type)
if self.expanded
return type
end

return nil unless call.type?

arg_types = call.args.map &.type.virtual_type
Expand Down
54 changes: 38 additions & 16 deletions src/compiler/crystal/semantic/cleanup_transformer.cr
Original file line number Diff line number Diff line change
Expand Up @@ -500,34 +500,56 @@ module Crystal

def check_args_are_not_closure(node, message)
node.args.each do |arg|
case arg
when ProcLiteral
if arg.def.closure?
vars = ClosuredVarsCollector.collect arg.def
unless vars.empty?
message += " (closured vars: #{vars.join ", "})"
end
check_arg_is_not_closure(node, message, arg)
end
end

arg.raise message
end
when ProcPointer
if arg.obj.try &.type?.try &.passed_as_self?
def check_arg_is_not_closure(node, message, arg)
case arg
when Expressions
arg.expressions.each do |exp|
check_arg_is_not_closure(node, message, exp)
end
when ProcLiteral
if proc_pointer = arg.proc_pointer
case proc_pointer.obj
when Var
arg.raise "#{message} (closured vars: #{proc_pointer.obj})"
when InstanceVar
arg.raise "#{message} (closured vars: self)"
end
return
end

owner = arg.call.target_def.owner
if owner.passed_as_self?
arg.raise "#{message} (closured vars: self)"
if arg.def.closure?
vars = ClosuredVarsCollector.collect arg.def
unless vars.empty?
message += " (closured vars: #{vars.join ", "})"
end
else
# nothing to do

arg.raise message
end
when ProcPointer
if arg.obj.try &.type?.try &.passed_as_self?
arg.raise "#{message} (closured vars: self)"
end

owner = arg.call.target_def.owner
if owner.passed_as_self?
arg.raise "#{message} (closured vars: self)"
end
else
# nothing to do
end
end

def transform(node : ProcPointer)
super

if expanded = node.expanded
return transform(expanded)
end

if call = node.call?
result = call.transform(self)

Expand Down
6 changes: 5 additions & 1 deletion src/compiler/crystal/semantic/fix_missing_types.cr
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ class Crystal::FixMissingTypes < Crystal::Visitor
end

def visit(node : ProcPointer)
node.call?.try &.accept self
if expanded = node.expanded
expanded.accept(self)
else
node.call?.try &.accept self
end
false
end

Expand Down
34 changes: 34 additions & 0 deletions src/compiler/crystal/semantic/literal_expander.cr
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,40 @@ module Crystal
end
end

# Expand this:
#
# ```
# ->foo.bar(X, Y)
# ```
#
# To this:
#
# ```
# tmp = foo
# ->(x : X, y : Y) { tmp.bar(x, y) }
# ```
def expand(node : ProcPointer)
obj = node.obj.not_nil!

temp_var = new_temp_var.at(obj)
assign = Assign.new(temp_var, obj)
obj = temp_var

def_args = node.args.map do |arg|
Arg.new(@program.new_temp_var_name, restriction: arg).at(arg)
end

call_args = def_args.map do |def_arg|
Var.new(def_arg.name).at(def_arg).as(ASTNode)
end

body = Call.new(obj, node.name, call_args).at(node)
proc_literal = ProcLiteral.new(Def.new("->", def_args, body)).at(node)
proc_literal.proc_pointer = node

Expressions.new([assign, proc_literal])
end

private def regex_new_call(node, value)
Call.new(Path.global("Regex").at(node), "new", value, regex_options(node)).at(node)
end
Expand Down
7 changes: 7 additions & 0 deletions src/compiler/crystal/semantic/main_visitor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,13 @@ module Crystal
def visit(node : ProcPointer)
obj = node.obj

# If it's something like `->foo.bar` we turn it into a closure
# where `foo` is assigned to a temporary variable.
if obj.is_a?(Var) || obj.is_a?(InstanceVar) || obj.is_a?(ClassVar)
expand(node)
return false
end

if obj
obj.accept self
end
Expand Down

0 comments on commit 442952f

Please sign in to comment.