Skip to content

Commit 44910bf

Browse files
authored
fix: apply init values after reduce (#896)
1 parent 63b384a commit 44910bf

File tree

3 files changed

+55
-102
lines changed

3 files changed

+55
-102
lines changed

src/Ops.jl

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ end
391391
end
392392

393393
# shape ops
394-
function reshape(x::TracedRArray, dims...; kwargs...)
394+
function reshape(x::TracedRArray, dims::Integer...; kwargs...)
395395
return reshape(x, collect(dims); kwargs...)
396396
end
397397

@@ -2392,9 +2392,9 @@ end
23922392
"""
23932393
reduce(
23942394
x::TracedRArray{T},
2395-
init_values::TracedRNumber{T},
2395+
init_values::Union{Nothing,TracedRNumber{T}},
23962396
dimensions::Vector{Int},
2397-
fn::Function,
2397+
fn::Function;
23982398
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
23992399
)
24002400
@@ -2433,18 +2433,36 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24332433
"""
24342434
@noinline function reduce(
24352435
x::TracedRArray{T},
2436-
init_values::TracedRNumber{T},
2436+
init_values::Union{TracedRNumber{T},Nothing},
24372437
dimensions::Vector{Int},
2438-
fn::Function,
2438+
fn::Function;
24392439
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
24402440
) where {T}
2441+
elT = T
2442+
if init_values === nothing
2443+
if fn === min || fn === Base.FastMath.min_fast
2444+
init = typemax(elT)
2445+
elseif fn === max || fn === Base.FastMath.max_fast
2446+
init = typemin(elT)
2447+
else
2448+
init = Base.reduce_empty(Base.BottomRF(fn), elT)
2449+
end
2450+
2451+
initT = unwrapped_eltype(typeof(init))
2452+
if initT != elT # Bool, etc. reductions
2453+
elT = promote_type(initT, elT)
2454+
x = elT.(x)
2455+
end
2456+
init_values = Reactant.TracedUtils.promote_to(TracedRNumber{elT}, init)
2457+
end
2458+
24412459
reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
24422460

2443-
result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
2461+
result_type = mlir_type(TracedRArray{elT,length(reduced_shape)}, reduced_shape)
24442462

24452463
sample_inputs = [
2446-
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2447-
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2464+
Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
2465+
Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
24482466
]
24492467

24502468
func =
@@ -2458,14 +2476,8 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24582476
return_dialect=:stablehlo,
24592477
).f
24602478
@assert MLIR.IR.nregions(func) == 1
2461-
fn_name = String(
2462-
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
2463-
)
2464-
ftype_attr = MLIR.IR.attr(func, "function_type")
2465-
ftype = MLIR.IR.Type(ftype_attr)
2466-
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error (
2467-
"$fn return type is not tensor<i1>"
2468-
)
2479+
ftype = MLIR.IR.Type(MLIR.IR.attr(func, "function_type"))
2480+
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(elT)) "$fn return type is not tensor<i1>"
24692481
fn = MLIR.IR.Region()
24702482
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
24712483
MLIR.IR.rmfromparent!(func)
@@ -2483,7 +2495,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24832495
),
24842496
)
24852497

2486-
return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
2498+
return TracedRArray{elT,length(reduced_shape)}((), res, reduced_shape)
24872499
end
24882500

24892501
end # module Ops

src/TracedRArray.jl

Lines changed: 14 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -468,100 +468,29 @@ function Base.mapreduce(
468468
dims=:,
469469
init=nothing,
470470
) where {T,N}
471-
A = materialize_traced_array(A)
471+
inp = broadcast(f, materialize_traced_array(A))
472472

473-
if dims isa Int
474-
dims = [dims]
475-
end
476-
477-
op_in_T = Core.Compiler.return_type(f, Tuple{T})
473+
dims isa Number && (dims = (dims,))
478474

479-
if init === nothing
480-
if op === min
481-
init = typemax(op_in_T)
482-
elseif op === max
483-
init = typemin(op_in_T)
484-
else
485-
init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
486-
end
487-
488-
if typeof(init) != op_in_T
489-
op_in_T = typeof(init)
490-
A = typeof(init).(A)
491-
end
475+
if init !== nothing && typeof(init) != unwrapped_eltype(inp)
476+
inp = typeof(init).(inp)
492477
end
493478

494-
init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
495-
496-
inp = [broadcast(f, A).mlir_data]
479+
rdims = dims == (:) ? collect(Int64, 1:N) : collect(Int64, dims)
497480

498-
rdims = Int64[]
481+
reduction_result = Ops.reduce(inp, nothing, rdims, op)
499482

500-
if dims == (:)
501-
for i in 0:(N - 1)
502-
push!(rdims, i)
503-
end
483+
reduction_result = if dims != (:)
484+
Ops.reshape(reduction_result, Int64[i ∈ rdims ? 1 : size(A, i) for i in 1:N])
504485
else
505-
for i in dims
506-
push!(rdims, i - 1)
507-
end
508-
end
509-
510-
in_tys = [
511-
MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
512-
MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
513-
]
514-
515-
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
516-
517-
args = (
518-
TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
519-
TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
520-
)
521-
522-
resty = MLIR.IR.block!(fnbody) do
523-
tmp = TracedUtils.broadcast_to_size(op(args...), ())
524-
Ops.return_(tmp)
525-
return eltype(MLIR.IR.type(tmp.mlir_data))
486+
TracedRNumber{unwrapped_eltype(reduction_result)}((), reduction_result.mlir_data)
526487
end
527488

528-
toonedims = Int[]
529-
outdims = Int[]
530-
for i in 1:N
531-
tmp = if in(i - 1, rdims)
532-
1
533-
else
534-
sz = size(A, i)
535-
push!(outdims, sz)
536-
sz
537-
end
538-
push!(toonedims, tmp)
539-
end
540-
541-
TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
542-
543-
body = MLIR.IR.Region()
544-
push!(body, fnbody)
545-
red = MLIR.Dialects.stablehlo.reduce(
546-
inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
547-
)
548-
549-
red = MLIR.IR.result(red, 1)
550-
redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
551-
552-
if dims != (:)
553-
red = Ops.reshape(TracedRArray(red), toonedims...)
554-
else
555-
if length(outdims) == 0
556-
red = TracedRNumber{redT}((), red)
557-
else
558-
red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
559-
end
560-
end
561-
return red
489+
init === nothing && return reduction_result
490+
return broadcast(op, reduction_result, init)
562491
end
563492

564-
function Base.mapreducedim!(
493+
function Base._mapreducedim!(
565494
@nospecialize(f),
566495
@nospecialize(op),
567496
@nospecialize(R::AnyTracedRArray),
@@ -573,9 +502,9 @@ function Base.mapreducedim!(
573502
@assert sR == 1
574503
return i
575504
end
505+
isempty(A) && return R
576506
tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
577-
# set_mlir_data!(R, get_mlir_data(tmp))
578-
R .= op.(R, tmp) # match native Julia's behavior
507+
R .= op.(R, tmp)
579508
return R
580509
end
581510

test/basic.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,18 @@ end
939939
)
940940
end
941941

942+
@testset "mapreduce with init" begin
943+
x = reshape(collect(Float32, 1:12), 3, 4)
944+
x_ra = Reactant.to_rarray(x)
945+
946+
init = 3.0
947+
init_ra = Reactant.to_rarray(init; track_numbers=Number)
948+
949+
fn(x, init; kwargs...) = sum(x; init, kwargs...)
950+
951+
@test @jit(fn(x_ra, init_ra; dims=2)) ≈ fn(x, init; dims=2)
952+
end
953+
942954
@testset "map!" begin
943955
x = randn(Float32, 2, 3)
944956
y = zeros(Float32, 2, 3)

0 commit comments

Comments
 (0)