Skip to content

Commit 4dbe9b1

Browse files
authored
update for 1.10 (#205)
1 parent fae23e7 commit 4dbe9b1

File tree

2 files changed

+79
-18
lines changed

2 files changed

+79
-18
lines changed

src/overdub.jl

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,11 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = get_world_counter
118118
method_instance === nothing && return nothing
119119
method_signature = method.sig
120120
static_params = Any[raw_static_params...]
121-
code_info = Core.Compiler.retrieve_code_info(method_instance)
121+
@static if VERSION >= v"1.10.0-DEV.873"
122+
code_info = Core.Compiler.retrieve_code_info(method_instance, world)
123+
else
124+
code_info = Core.Compiler.retrieve_code_info(method_instance)
125+
end
122126
isa(code_info, CodeInfo) || return nothing
123127
code_info = copy_code_info(code_info)
124128
verbose_lineinfo!(code_info, S)
@@ -598,13 +602,37 @@ const OVERDUB_FALLBACK = begin
598602
end
599603

600604
# `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)`
601-
function __overdub_generator__(self, context_type, args::Tuple)
605+
function __overdub_generator__(world::UInt, source, self, context_type, args)
606+
if nfields(args) > 0
607+
is_builtin = args[1] <: Core.Builtin
608+
is_invoke = args[1] === typeof(Core.invoke)
609+
if !is_builtin || is_invoke
610+
try
611+
untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args))
612+
reflection = reflect(untagged_args, world)
613+
if isa(reflection, Reflection)
614+
result = overdub_pass!(reflection, context_type, is_invoke)
615+
isa(result, Expr) && return result
616+
return reflection.code_info
617+
end
618+
catch err
619+
errmsg = "ERROR COMPILING $args IN CONTEXT $(context_type): \n" #* sprint(showerror, err)
620+
errmsg *= "\n" .* repr("text/plain", stacktrace(catch_backtrace()))
621+
return quote
622+
error($errmsg)
623+
end
624+
end
625+
end
626+
end
627+
return copy_code_info(OVERDUB_FALLBACK)
628+
end
629+
function __overdub_generator__(self, context_type, args)
602630
if nfields(args) > 0
603631
is_builtin = args[1] <: Core.Builtin
604632
is_invoke = args[1] === typeof(Core.invoke)
605633
if !is_builtin || is_invoke
606634
try
607-
untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,)
635+
untagged_args = ntuple(i->untagtype(args[i], context_type), nfields(args))
608636
reflection = reflect(untagged_args)
609637
if isa(reflection, Reflection)
610638
result = overdub_pass!(reflection, context_type, is_invoke)
@@ -638,6 +666,18 @@ if VERSION >= v"1.4.0-DEV.304"
638666
end
639667

640668
let line = @__LINE__, file = @__FILE__
669+
@static if VERSION >= v"1.10.0-DEV.873"
670+
@eval (@__MODULE__) begin
671+
function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
672+
$(Expr(:meta, :generated_only))
673+
$(Expr(:meta, :generated, __overdub_generator__))
674+
end
675+
function $Cassette.recurse($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
676+
$(Expr(:meta, :generated_only))
677+
$(Expr(:meta, :generated, __overdub_generator__))
678+
end
679+
end
680+
else
641681
@eval (@__MODULE__) begin
642682
function $Cassette.overdub($OVERDUB_CONTEXT_NAME::$Cassette.Context, $OVERDUB_ARGUMENTS_NAME...)
643683
$(Expr(:meta, :generated_only))
@@ -666,6 +706,7 @@ let line = @__LINE__, file = @__FILE__
666706
true)))
667707
end
668708
end
709+
end
669710
end
670711

671712
@doc """

test/misctests.jl

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ function rosenbrock(x::Vector{Float64})
1717
end
1818

1919
x = rand(2)
20-
@inferred(overdub(RosCtx(), rosenbrock, x))
20+
if VERSION < v"1.9"
21+
@inferred (overdub(RosCtx(), rosenbrock, x))
22+
end
2123

2224
messages = String[]
2325
Cassette.prehook(::RosCtx, f, args...) = push!(messages, string("calling ", f, args))
@@ -79,16 +81,29 @@ empty!(pres)
7981
empty!(posts)
8082

8183
@overdub(ctx, Core._apply(+, (x1, x2), (x2 * x3, x3)))
82-
@test pres == [(tuple, (x1, x2)),
83-
(*, (x2, x3)),
84-
(Base.mul_int, (x2, x3)),
85-
(tuple, (x2*x3, x3)),
86-
(+, (x1, x2, x2*x3, x3))]
87-
@test posts == [((x1, x2), tuple, (x1, x2)),
88-
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
89-
(*(x2, x3), *, (x2, x3)),
90-
((x2*x3, x3), tuple, (x2*x3, x3)),
91-
(+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))]
84+
if !(v"1.9" <= VERSION < v"1.10")
85+
@test pres == [(tuple, (x1, x2)),
86+
(*, (x2, x3)),
87+
(Base.mul_int, (x2, x3)),
88+
(tuple, (x2*x3, x3)),
89+
(+, (x1, x2, x2*x3, x3))]
90+
@test posts == [((x1, x2), tuple, (x1, x2)),
91+
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
92+
(*(x2, x3), *, (x2, x3)),
93+
((x2*x3, x3), tuple, (x2*x3, x3)),
94+
(+(x1, x2, x2*x3, x3), +, (x1, x2, x2*x3, x3))]
95+
else
96+
@test pres == [(tuple, (x1, x2)),
97+
(*, (x2, x3)),
98+
(Base.mul_int, (x2, x3)),
99+
(tuple, (x2*x3, x3)),
100+
(Core._apply, (+, (x1, x2), (x2*x3, x3)))]
101+
@test posts == [((x1, x2), tuple, (x1, x2)),
102+
(Base.mul_int(x2, x3), Base.mul_int, (x2, x3)),
103+
(*(x2, x3), *, (x2, x3)),
104+
((x2*x3, x3), tuple, (x2*x3, x3)),
105+
(+(x1, x2, x2*x3, x3), Core._apply, (+, (x1, x2), (x2*x3, x3)))]
106+
end
92107

93108
println("done (took ", time() - before_time, " seconds)")
94109

@@ -386,7 +401,10 @@ else
386401
@inferred(overdub(InferCtx(), rand, Float32, 1))
387402
end
388403
end
389-
@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1)))
404+
405+
if VERSION < v"1.9"
406+
@inferred(overdub(InferCtx(), broadcast, +, rand(1), rand(1)))
407+
end
390408
@inferred(overdub(InferCtx(), () -> kwargtest(42; foo = 1, bar = 2)))
391409

392410
println("done (took ", time() - before_time, " seconds)")
@@ -427,9 +445,11 @@ ctx = InvokeCtx(metadata=Any[])
427445
@test overdub(ctx, invoker, 3) === 9
428446
# This is kind of fragile and may break for unrelated reasons - the main thing
429447
# we're testing here is that we properly trace through the `invoke` call.
430-
@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type,
431-
Val{2}, Core.apply_type, Base.literal_pow, *,
432-
Base.mul_int]
448+
if VERSION < v"1.9"
449+
@test ctx.metadata == Any[Core.apply_type, Core.invoke, Core.apply_type,
450+
Val{2}, Core.apply_type, Base.literal_pow, *,
451+
Base.mul_int]
452+
end
433453

434454
println("done (took ", time() - before_time, " seconds)")
435455

0 commit comments

Comments
 (0)