Skip to content

Insert call instructions (without any caching) #1276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft

Conversation

jumerckx
Copy link
Collaborator

No description provided.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

JuliaFormatter

[JuliaFormatter] reported by reviewdog 🐶

Reactant.jl/src/utils.jl

Lines 1105 to 1113 in 68e7079

traced_result = push_inst!(Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
call_epilogue,
push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[1], push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[2:end]...)))),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(Expr(:call, GlobalRef(Reactant, :get_args_from_finalize_function), finalize_function_result)),
))


[JuliaFormatter] reported by reviewdog 🐶

Reactant.jl/src/utils.jl

Lines 1117 to 1122 in 68e7079

push_inst!(Expr(
:call,
GlobalRef(Base, :setindex!),
GlobalRef(Reactant, :TRACE_CALLS),
should_trace_call,
))


[JuliaFormatter] reported by reviewdog 🐶

Reactant.jl/src/utils.jl

Lines 1125 to 1130 in 68e7079

push_inst!(Expr(
:call,
GlobalRef(Base, :setindex!),
GlobalRef(Reactant, :TRACE_CALLS),
TRACE_CALLS[],
))

@@ -398,7 +398,7 @@ Base.@nospecializeinfer function traced_type_inner(
}
end
error("Unsupported runtime $runtime")
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath || mode == TracedToTypes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath || mode == TracedToTypes
elseif mode == TracedTrack ||
mode == NoStopTracedTrack ||
mode == TracedSetPath ||
mode == TracedToTypes

@@ -444,7 +444,7 @@ Base.@nospecializeinfer function traced_type_inner(
}
end
error("Unsupported runtime $runtime")
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath || mode == TracedToTypes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath || mode == TracedToTypes
elseif mode == TracedTrack ||
mode == NoStopTracedTrack ||
mode == TracedSetPath ||
mode == TracedToTypes

kwargs...,
)
if mode == TracedToTypes
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return
return nothing

return nothing
end
if mode == TracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))

if !haskey(seen, prev)
return seen[prev] = prev
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))

if !haskey(seen, prev)
return seen[prev] = prev
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))

if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
!isnothing(path) && TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))

pr.traced_args_to_shardings,
pr.sym_visibility,
pr.args,
pr.N
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
pr.N
pr.N,


function deactivate_fnbody!(fnbody)
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Ops.deactivate_constant_context!(fnbody)
return Ops.deactivate_constant_context!(fnbody)

Ops.deactivate_constant_context!(fnbody)
end

function call_prologue(f, args, )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function call_prologue(f, args, )
function call_prologue(f, args)

Comment on lines +583 to +597
fnbody
) = result = TracedUtils.prepare_mlir_fn_args(
args,
f_name,
concretein,
true, # mutate_args
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names
)
mlir_caller_args = Reactant.MLIR.IR.Value[TracedUtils.get_mlir_data(x) for x in linear_args]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
fnbody
) = result = TracedUtils.prepare_mlir_fn_args(
args,
f_name,
concretein,
true, # mutate_args
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names
)
mlir_caller_args = Reactant.MLIR.IR.Value[TracedUtils.get_mlir_data(x) for x in linear_args]
fnbody,
) =
result = TracedUtils.prepare_mlir_fn_args(
args,
f_name,
concretein,
true, # mutate_args
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names,
)
mlir_caller_args = Reactant.MLIR.IR.Value[
TracedUtils.get_mlir_data(x) for x in linear_args
]

return result
end

function finalize_function(result, traced_args, linear_args, mlir_caller_args, seen_args, fnbody, func, mod, name, in_tys, inv_map, argprefix, traced_args_to_shardings, sym_visibility, args, N)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function finalize_function(result, traced_args, linear_args, mlir_caller_args, seen_args, fnbody, func, mod, name, in_tys, inv_map, argprefix, traced_args_to_shardings, sym_visibility, args, N)
function finalize_function(
result,
traced_args,
linear_args,
mlir_caller_args,
seen_args,
fnbody,
func,
mod,
name,
in_tys,
inv_map,
argprefix,
traced_args_to_shardings,
sym_visibility,
args,
N,
)

construct_function_without_args = false
output_shardings = nothing


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

Comment on lines +636 to +650

(;
func2,
f_name,
traced_result,
ret,
linear_args,
in_tys,
linear_results,
num_partitions,
is_sharded,
unique_meshes,
mutated_args,
global_device_ids
) = TracedUtils.finalize_mlir_fn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
(;
func2,
f_name,
traced_result,
ret,
linear_args,
in_tys,
linear_results,
num_partitions,
is_sharded,
unique_meshes,
mutated_args,
global_device_ids
) = TracedUtils.finalize_mlir_fn(
(; func2, f_name, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = TracedUtils.finalize_mlir_fn(

args,
N,
concretein,
toscalar
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
toscalar
toscalar,

Comment on lines +703 to +704
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
)
)


end

function call_epilogue(f, args, traced_result, linear_args, f_name, ret, linear_results, mlir_caller_args, argprefix, resprefix, resargprefix)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function call_epilogue(f, args, traced_result, linear_args, f_name, ret, linear_results, mlir_caller_args, argprefix, resprefix, resargprefix)
function call_epilogue(
f,
args,
traced_result,
linear_args,
f_name,
ret,
linear_results,
mlir_caller_args,
argprefix,
resprefix,
resargprefix,
)

callee=MLIR.IR.FlatSymbolRefAttribute(f_name),
)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

else
Core.println("Not tracing call to $fn.")
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@jumerckx
Copy link
Collaborator Author

jumerckx commented May 14, 2025

Status:

Test Summary:           | Pass  Fail  Error  Broken  Total      Time
Reactant.jl Tests       | 1480     5     27       1   1513  45m55.8s

(https://github.com/EnzymeAD/Reactant.jl/actions/runs/15022467542/job/42214855196?pr=1276#step:10:6661)

@@ -131,6 +131,11 @@ function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
return x
end

function set_mlir_data!(x::T, data) where T
@warn "Setting mlir data on a $T is a no-op."
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return x
function set_mlir_data!(x::T, data) where {T}

jumerckx added 2 commits May 17, 2025 22:20
…en calling `make_tracer` on the `AnyTracedRArray`.
Comment on lines +851 to +855
Core.println("Found method from module $(method.module) with name $(method.name), TRACE_CALLS[] = $(TRACE_CALLS[])")
trace_call_within = TRACE_CALLS[] && !(
has_ancestor(method.module, Reactant.TracedRNumberOverrides) ||
has_ancestor(method.module, Reactant.TracedRArrayOverrides) ||
has_ancestor(method.module, Core)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Core.println("Found method from module $(method.module) with name $(method.name), TRACE_CALLS[] = $(TRACE_CALLS[])")
trace_call_within = TRACE_CALLS[] && !(
has_ancestor(method.module, Reactant.TracedRNumberOverrides) ||
has_ancestor(method.module, Reactant.TracedRArrayOverrides) ||
has_ancestor(method.module, Core)
Core.println(
"Found method from module $(method.module) with name $(method.name), TRACE_CALLS[] = $(TRACE_CALLS[])",
)
trace_call_within =
TRACE_CALLS[] && !(
has_ancestor(method.module, Reactant.TracedRNumberOverrides) ||
has_ancestor(method.module, Reactant.TracedRArrayOverrides) ||
has_ancestor(method.module, Core)

has_ancestor(method.module, Reactant.TracedRArrayOverrides) ||
has_ancestor(method.module, Core)
)
if TRACE_CALLS[] && !(!(fn <: Function) || sizeof(fn) != 0 || fn <: Base.BroadcastFunction)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if TRACE_CALLS[] && !(!(fn <: Function) || sizeof(fn) != 0 || fn <: Base.BroadcastFunction)
if TRACE_CALLS[] &&
!(!(fn <: Function) || sizeof(fn) != 0 || fn <: Base.BroadcastFunction)

Comment on lines +1047 to +1050
ocres = if TRACE_CALLS[] && !(!(fn <: Function) || sizeof(fn) != 0 || fn <: Base.BroadcastFunction)
push!(code_info.slotnames, :tryfinallystate)
push!(code_info.slotflags, zero(UInt8))
tryfinally_slot = Core.SlotNumber(length(code_info.slotnames))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
ocres = if TRACE_CALLS[] && !(!(fn <: Function) || sizeof(fn) != 0 || fn <: Base.BroadcastFunction)
push!(code_info.slotnames, :tryfinallystate)
push!(code_info.slotflags, zero(UInt8))
tryfinally_slot = Core.SlotNumber(length(code_info.slotnames))
ocres =
if TRACE_CALLS[] &&
!(!(fn <: Function) || sizeof(fn) != 0 || fn <: Base.BroadcastFunction)
push!(code_info.slotnames, :tryfinallystate)
push!(code_info.slotflags, zero(UInt8))
tryfinally_slot = Core.SlotNumber(length(code_info.slotnames))
push!(code_info.slotnames, :ocres)
push!(code_info.slotflags, zero(UInt8))
ocres_slot = Core.SlotNumber(length(code_info.slotnames))
push_inst!(Core.NewvarNode(ocres_slot))
cached_or_nothing = push_inst!(
Expr(:call, get_cache, fn_args[1], fn_args[2:end]...)
)
is_not_cached = push_inst!(
Expr(:call, GlobalRef(Base, :isnothing), cached_or_nothing)
)
# TODO: conditional jump to cached block
# cached_dest = 0
# push_inst!(Core.GotoIfNot(is_not_cached, cached_dest))

Comment on lines +1052 to +1055
push!(code_info.slotnames, :ocres)
push!(code_info.slotflags, zero(UInt8))
ocres_slot = Core.SlotNumber(length(code_info.slotnames))
push_inst!(Core.NewvarNode(ocres_slot))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
push!(code_info.slotnames, :ocres)
push!(code_info.slotflags, zero(UInt8))
ocres_slot = Core.SlotNumber(length(code_info.slotnames))
push_inst!(Core.NewvarNode(ocres_slot))
prologue_result = push_inst!(
Expr(
:call,
GlobalRef(Reactant, :call_prologue),
fn_args[1],
push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[2:end]...)),
),
)

ocres_slot = Core.SlotNumber(length(code_info.slotnames))
push_inst!(Core.NewvarNode(ocres_slot))


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
catch_dest = length(overdubbed_code) + 12
enter = push_inst!(@static if VERSION < v"1.11"
Expr(:enter, catch_dest)
else
Core.EnterNode(catch_dest)
end)

Comment on lines +1155 to +1162
finalize_function_result = push_inst!(Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
GlobalRef(Reactant, :finalize_function),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(Expr(:call, GlobalRef(Reactant, :get_args_for), QuoteNode(:finalize_function), prologue_result))
))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
finalize_function_result = push_inst!(Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
GlobalRef(Reactant, :finalize_function),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(Expr(:call, GlobalRef(Reactant, :get_args_for), QuoteNode(:finalize_function), prologue_result))
))
finalize_function_result = push_inst!(
Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
GlobalRef(Reactant, :finalize_function),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(
Expr(
:call,
GlobalRef(Reactant, :get_args_for),
QuoteNode(:finalize_function),
prologue_result,
),
),
),
)

push_inst!(Expr(:call, GlobalRef(Reactant, :get_args_for), QuoteNode(:finalize_function), prologue_result))
))

# TODO: save cache
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
# TODO: save cache
# TODO: save cache


# TODO: save cache

# TODO: unconditional jump over cached block.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
# TODO: unconditional jump over cached block.
# TODO: unconditional jump over cached block.


# TODO: unconditional jump over cached block.

# TODO: cached block
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
# TODO: cached block
# TODO: cached block

Comment on lines +1170 to +1184
# TODO: common final handling
traced_result = push_inst!(Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
call_epilogue,
push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[1], push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[2:end]...)))),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(Expr(:call, GlobalRef(Reactant, :get_args_from_finalize_function), finalize_function_result)),
))
traced_result
else
traced_result = push_inst!(Expr(:call, oc, fn_args[2:end]...))
traced_result
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
# TODO: common final handling
traced_result = push_inst!(Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
call_epilogue,
push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[1], push_inst!(Expr(:call, GlobalRef(Core, :tuple), fn_args[2:end]...)))),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(Expr(:call, GlobalRef(Reactant, :get_args_from_finalize_function), finalize_function_result)),
))
traced_result
else
traced_result = push_inst!(Expr(:call, oc, fn_args[2:end]...))
traced_result
end
# TODO: common final handling
traced_result = push_inst!(
Expr(
:call,
GlobalRef(Core, :_apply_iterate),
Base.iterate,
call_epilogue,
push_inst!(
Expr(
:call,
GlobalRef(Core, :tuple),
fn_args[1],
push_inst!(
Expr(:call, GlobalRef(Core, :tuple), fn_args[2:end]...)
),
),
),
push_inst!(Expr(:call, GlobalRef(Core, :tuple), traced_result)),
push_inst!(
Expr(
:call,
GlobalRef(Reactant, :get_args_from_finalize_function),
finalize_function_result,
),
),
),
)
traced_result
else
traced_result = push_inst!(Expr(:call, oc, fn_args[2:end]...))
traced_result
end

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant