feat: more coverage for NNlib functions #258

Merged: 8 commits, Nov 11, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -39,7 +39,7 @@ LinearAlgebra = "1.10"
NNlib = "0.9"
OrderedCollections = "1"
Preferences = "1.4"
ReactantCore = "0.1"
ReactantCore = "0.1.1"
Reactant_jll = "0.0.24"
Scratch = "1.2"
Statistics = "1.10"
126 changes: 73 additions & 53 deletions ext/ReactantNNlibExt.jl
@@ -3,6 +3,7 @@ module ReactantNNlibExt
using NNlib
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu

for (jlop, hloop) in (
@@ -20,38 +21,46 @@
end
end

# TODO handle non finite cases
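# Numerically stable softmax: shift by the per-slice maximum before exponentiating,
# then normalize by the sum along `dims`.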
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
max_ = NNlib.fast_maximum(x; dims)
#if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
#else
# _zero, _one, _inf = T(0), T(1), T(Inf)
# @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
#end
# XXX: Once reverse mode of if is properly supported, we can make it @trace
# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
# one_num = Reactant.promote_to(TracedRNumber{T}, 1)
# @trace if all(isfinite, max_)
@. out = exp(x - max_)
# else
# cond = max_ .== Inf
# true_pred = ifelse.(x .== Inf, one_num, zero_num)
# @. out = ifelse(cond, true_pred, exp(x - max_))
# end
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
return out ./= tmp
out ./= tmp
return out
end

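# logsoftmax via the shifted form: out = (x - max) - log(sum(exp(x - max))) along `dims`.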
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
max_ = NNlib.fast_maximum(x; dims)
# if all(isfinite, max_)
@fastmath out .= x .- max_
# XXX: Once reverse mode of if is properly supported, we can make it @trace
# inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
# @trace if all(isfinite, max_)
@. out = x - max_
# else
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
# @. out = ifelse(
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
# )
# cond = max_ .== Inf
# true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
# @. out = ifelse(cond, true_pred, x - max_)
# end
@fastmath log_ = log.(sum(exp, out; dims))
return out .-= log_
out .-= log_
return out
end

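# Dense convolution lowered to stablehlo.convolution, writing into the preallocated output `y`.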
function NNlib.conv(
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
function NNlib.conv!(
y::TracedRArray{T,N}, x::AnyTracedRArray, W::AnyTracedRArray, cdims::DenseConvDims
) where {T,N}
x = materialize_traced_array(x)
W = materialize_traced_array(W)
# StableHLO expects matching element types
x = T.(materialize_traced_array(x))
W = T.(materialize_traced_array(W))

kernel_size = NNlib.kernel_size(cdims)
padding = NNlib.padding(cdims)
@@ -77,33 +86,31 @@ function NNlib.conv(
pl, pr = padding[2i - 1], padding[2i]
d = dilation[i]
s = stride[i]

(size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
end
output_batch_dim = input_batch_dim
output_feature_dim = input_feature_dim
output_spatial_dims = input_spatial_dims

output_shape = (output_spatial_shapes..., size(W, kernel_output_dim), size(x, N))

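# StableHLO dimension numbers are 0-based, hence the `- 1` on each 1-based Julia dimension index.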
dimension_numbers = """
#stablehlo.conv<raw
input_batch_dimension = $(input_batch_dim - 1),
input_feature_dimension = $(input_feature_dim - 1),
input_spatial_dimensions = [$(join(input_spatial_dims .- 1, ", "))],
kernel_output_feature_dimension = $(kernel_output_dim - 1),
kernel_input_feature_dimension = $(kernel_input_dim - 1),
kernel_spatial_dimensions = [$(join(kernel_spatial_dims .- 1, ", "))],
output_batch_dimension = $( output_batch_dim - 1 ),
output_feature_dimension = $( output_feature_dim - 1),
output_spatial_dimensions = [$(join(output_spatial_dims .- 1, ", "))],
>"""
dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
#! format: off
dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
MLIR.IR.context(),
Int64(input_batch_dim - 1),
Int64(input_feature_dim - 1),
length(input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
Int64(kernel_input_dim - 1),
Int64(kernel_output_dim - 1),
length(kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
Int64(output_batch_dim - 1),
Int64(output_feature_dim - 1),
length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
)
#! format: on

padding = Reactant.MLIR.IR.DenseElementsAttribute(
reshape(collect(padding), (num_spatial_dims, 2))
)
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

weight = W.mlir_data
if !flipkernel
@@ -126,8 +133,8 @@ function NNlib.conv(
feature_group_count,
batch_group_count=1,
)

return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
y.mlir_data = Reactant.MLIR.IR.result(conv)
return y
end

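# Shared pooling helper: lowers to stablehlo.reduce_window, applying `f` over each window starting from `init`.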
function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
@@ -198,27 +205,39 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
end

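# Pooling reuses reduce_window: maximum with init = typemin(T) for maxpool!, and
# sum with init = zero(T) divided by the window size for meanpool!.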
function NNlib.maxpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
return reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
)
function NNlib.maxpool!(
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
y.mlir_data =
reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
).mlir_data
return y
end

function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
numel = prod(NNlib.kernel_size(pdims))
return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
T(numel)
function NNlib.meanpool!(
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T))
y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data
return y
end

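# Batched transpose/adjoint swap the two matrix dimensions of each batch slice;
# the adjoint additionally conjugates the result for non-real eltypes.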
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
y = permutedims(x, (2, 1, 3))
conj!(y)
return y
end

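# batched_mul! lowers to stablehlo.dot_general; a batch size of 1 on either operand is
# broadcast, and the batch dimension is moved to the front for the call and permuted back at the end.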
function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
function NNlib.batched_mul!(
res::TracedRArray{T1,3}, x::AnyTracedRArray{T2,3}, y::AnyTracedRArray{T3,3}
) where {T1,T2,T3}
if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
throw(
DimensionMismatch(
lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.",
lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_mul.",
),
)
end
@@ -227,7 +246,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}

B = max(size(x, 1), size(y, 1))
out_shape = (B, size(x, 2), size(y, 3))
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data)))
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))

if size(x, 1) != size(y, 1)
if size(x, 1) == 1
@@ -244,7 +263,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
res = TracedRArray{T,3}(
tmp = TracedRArray{T1,3}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
@@ -258,7 +277,8 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
),
size(resty),
)
return permutedims(res, (2, 3, 1))
res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
return res
end

function NNlib.pad_constant(
2 changes: 1 addition & 1 deletion lib/ReactantCore/Project.toml
@@ -1,7 +1,7 @@
name = "ReactantCore"
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.1.0"
version = "0.1.1"

[deps]
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
8 changes: 5 additions & 3 deletions lib/ReactantCore/src/ReactantCore.jl
@@ -15,6 +15,10 @@ end

MissingTracedValue() = MissingTracedValue(())

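# Symbols the `@trace` rewrite must not treat as capturable variables (see the
# "Certain Symbols are Reserved" note in the docstring below).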
const SPECIAL_SYMBOLS = [
:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
]

# Code generation
"""
@trace <expr>
@@ -79,7 +83,7 @@ You need to ensure that all branches have the same type.

### Certain Symbols are Reserved

Symbols like `nothing`, `missing` and `:` are not allowed as variables in `@trace` expressions. Certain cases might work, but they are not guaranteed to. For
Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. Certain cases might work, but they are not guaranteed to. For
example, the following will not work:

```julia
@@ -299,6 +303,4 @@ function error_if_return(expr)
end
end

const SPECIAL_SYMBOLS = [:(:), :nothing, :missing]

end
7 changes: 5 additions & 2 deletions src/TracedRArray.jl
@@ -538,8 +538,10 @@ function Base.mapreduce(
dims = [dims]
end

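# Element type produced by applying `f`; used both for the neutral element and for the
# traced arguments of the reduction block, instead of the input eltype `T`.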
op_in_T = Core.Compiler.return_type(f, Tuple{T})

if isnothing(init)
init = Base.reduce_empty(Base.BottomRF(op), Core.Compiler.return_type(f, Tuple{T}))
init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
else
init = init::T
end
@@ -561,7 +563,8 @@
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])

args = (
TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys)
TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, i)) for
(i, ty) in enumerate(in_tys)
)

res = MLIR.IR.block!(fnbody) do