
Commit 6d5b0b5

Revert "fix: apply init values after reduce (#896)"
This reverts commit 44910bf.
1 parent 5897c74 commit 6d5b0b5

File tree: src/Ops.jl, src/TracedRArray.jl, test/basic.jl

3 files changed: +102 -55 lines

src/Ops.jl

Lines changed: 17 additions & 29 deletions
@@ -391,7 +391,7 @@ end
 end
 
 # shape ops
-function reshape(x::TracedRArray, dims::Integer...; kwargs...)
+function reshape(x::TracedRArray, dims...; kwargs...)
     return reshape(x, collect(dims); kwargs...)
 end
 
@@ -2392,9 +2392,9 @@ end
 """
     reduce(
         x::TracedRArray{T},
-        init_values::Union{Nothing,TracedRNumber{T}},
+        init_values::TracedRNumber{T},
         dimensions::Vector{Int},
-        fn::Function;
+        fn::Function,
         location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
     )
@@ -2433,36 +2433,18 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
 """
 @noinline function reduce(
     x::TracedRArray{T},
-    init_values::Union{TracedRNumber{T},Nothing},
+    init_values::TracedRNumber{T},
     dimensions::Vector{Int},
-    fn::Function;
+    fn::Function,
     location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
 ) where {T}
-    elT = T
-    if init_values === nothing
-        if fn === min || fn === Base.FastMath.min_fast
-            init = typemax(elT)
-        elseif fn === max || fn === Base.FastMath.max_fast
-            init = typemin(elT)
-        else
-            init = Base.reduce_empty(Base.BottomRF(fn), elT)
-        end
-
-        initT = unwrapped_eltype(typeof(init))
-        if initT != elT # Bool, etc. reductions
-            elT = promote_type(initT, elT)
-            x = elT.(x)
-        end
-        init_values = Reactant.TracedUtils.promote_to(TracedRNumber{elT}, init)
-    end
-
     reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
 
-    result_type = mlir_type(TracedRArray{elT,length(reduced_shape)}, reduced_shape)
+    result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
 
     sample_inputs = [
-        Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
-        Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
+        Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
     ]
 
     func =
@@ -2476,8 +2458,14 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
             return_dialect=:stablehlo,
         ).f
     @assert MLIR.IR.nregions(func) == 1
-    ftype = MLIR.IR.Type(MLIR.IR.attr(func, "function_type"))
-    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(elT)) "$fn return type is not tensor<i1>"
+    fn_name = String(
+        MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
+    )
+    ftype_attr = MLIR.IR.attr(func, "function_type")
+    ftype = MLIR.IR.Type(ftype_attr)
+    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error (
+        "$fn return type is not tensor<i1>"
+    )
     fn = MLIR.IR.Region()
     MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
     MLIR.IR.rmfromparent!(func)
@@ -2495,7 +2483,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
         ),
     )
 
-    return TracedRArray{elT,length(reduced_shape)}((), res, reduced_shape)
+    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
 end
 
 end # module Ops
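
With the revert in place, `Ops.reduce` again requires an explicit `TracedRNumber{T}` init value as a positional argument; the `nothing` branch that derived an identity element (and promoted the element type for Bool-like reductions) is gone. A minimal usage sketch under that assumption, where `x` stands for some existing `TracedRArray{Float32,2}` inside a traced function, and the `Reactant.`-qualified paths are an assumption about the module layout:

# Hypothetical call shape of the restored signature:
#   Ops.reduce(x, init_values, dimensions, fn; location=...)
init = Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Float32}, 0.0f0)
partial = Reactant.Ops.reduce(x, init, [2], +)  # reduce over dimension 2, seeded with `init`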

src/TracedRArray.jl

Lines changed: 85 additions & 14 deletions
@@ -468,29 +468,100 @@ function Base.mapreduce(
     dims=:,
     init=nothing,
 ) where {T,N}
-    inp = broadcast(f, materialize_traced_array(A))
+    A = materialize_traced_array(A)
 
-    dims isa Number && (dims = (dims,))
+    if dims isa Int
+        dims = [dims]
+    end
+
+    op_in_T = Core.Compiler.return_type(f, Tuple{T})
 
-    if init !== nothing && typeof(init) != unwrapped_eltype(inp)
-        inp = typeof(init).(inp)
+    if init === nothing
+        if op === min
+            init = typemax(op_in_T)
+        elseif op === max
+            init = typemin(op_in_T)
+        else
+            init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
+        end
+
+        if typeof(init) != op_in_T
+            op_in_T = typeof(init)
+            A = typeof(init).(A)
+        end
     end
 
-    rdims = dims == (:) ? collect(Int64, 1:N) : collect(Int64, dims)
+    init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
+
+    inp = [broadcast(f, A).mlir_data]
 
-    reduction_result = Ops.reduce(inp, nothing, rdims, op)
+    rdims = Int64[]
 
-    reduction_result = if dims != (:)
-        Ops.reshape(reduction_result, Int64[i ∈ rdims ? 1 : size(A, i) for i in 1:N])
+    if dims == (:)
+        for i in 0:(N - 1)
+            push!(rdims, i)
+        end
     else
-        TracedRNumber{unwrapped_eltype(reduction_result)}((), reduction_result.mlir_data)
+        for i in dims
+            push!(rdims, i - 1)
+        end
+    end
+
+    in_tys = [
+        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
+        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
+    ]
+
+    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
+
+    args = (
+        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
+        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
+    )
+
+    resty = MLIR.IR.block!(fnbody) do
+        tmp = TracedUtils.broadcast_to_size(op(args...), ())
+        Ops.return_(tmp)
+        return eltype(MLIR.IR.type(tmp.mlir_data))
     end
 
-    init === nothing && return reduction_result
-    return broadcast(op, reduction_result, init)
+    toonedims = Int[]
+    outdims = Int[]
+    for i in 1:N
+        tmp = if in(i - 1, rdims)
+            1
+        else
+            sz = size(A, i)
+            push!(outdims, sz)
+            sz
+        end
+        push!(toonedims, tmp)
+    end
+
+    TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
+
+    body = MLIR.IR.Region()
+    push!(body, fnbody)
+    red = MLIR.Dialects.stablehlo.reduce(
+        inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
+    )
+
+    red = MLIR.IR.result(red, 1)
+    redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
+
+    if dims != (:)
+        red = Ops.reshape(TracedRArray(red), toonedims...)
+    else
+        if length(outdims) == 0
+            red = TracedRNumber{redT}((), red)
+        else
+            red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
+        end
+    end
+    return red
 end
 
-function Base._mapreducedim!(
+function Base.mapreducedim!(
     @nospecialize(f),
     @nospecialize(op),
     @nospecialize(R::AnyTracedRArray),
@@ -502,9 +573,9 @@ function Base._mapreducedim!(
         @assert sR == 1
         return i
     end
-    isempty(A) && return R
     tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
-    R .= op.(R, tmp)
+    # set_mlir_data!(R, get_mlir_data(tmp))
+    R .= op.(R, tmp) # match native Julia's behavior
    return R
 end
 
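
The restored `Base.mapreduce` picks its own identity when `init === nothing`: `typemax` for `min`, `typemin` for `max`, and `Base.reduce_empty` otherwise, switching the element type if that identity has a different type. A standalone sketch of that selection rule in plain Julia (the helper name `default_init` is invented for illustration and is not part of Reactant):

# Mirrors the init-selection branch of the restored mapreduce, without any Reactant types.
function default_init(op, ::Type{T}) where {T}
    op === min && return typemax(T)                 # identity for min
    op === max && return typemin(T)                 # identity for max
    return Base.reduce_empty(Base.BottomRF(op), T)  # e.g. zero(T) for +, one(T) for *
end

default_init(+, Float32)  # 0.0f0
default_init(min, Int32)  # typemax(Int32) == 2147483647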

test/basic.jl

Lines changed: 0 additions & 12 deletions
@@ -939,18 +939,6 @@ end
     )
 end
 
-@testset "mapreduce with init" begin
-    x = reshape(collect(Float32, 1:12), 3, 4)
-    x_ra = Reactant.to_rarray(x)
-
-    init = 3.0
-    init_ra = Reactant.to_rarray(init; track_numbers=Number)
-
-    fn(x, init; kwargs...) = sum(x; init, kwargs...)
-
-    @test @jit(fn(x_ra, init_ra; dims=2)) ≈ fn(x, init; dims=2)
-end
-
 @testset "map!" begin
     x = randn(Float32, 2, 3)
     y = zeros(Float32, 2, 3)
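
For reference, native Julia seeds each reduced slice with `init`, which is the behavior the removed test compared against. A hand-worked example (values computed here, not taken from the commit):

x = reshape(collect(Float32, 1:12), 3, 4)

sum(x; dims=2)              # 3×1 Matrix{Float32}: 22.0, 26.0, 30.0
sum(x; dims=2, init=3.0f0)  # each row sum seeded with 3: 25.0, 29.0, 33.0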
