
Commit 3102dfa

refactor: directly overload inplace conv routine from NNlib
1 parent 6deb95d commit 3102dfa
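
Background for this change (not part of the commit, and hedged as my reading of NNlib): the out-of-place NNlib.conv generally allocates the output array and then forwards to the in-place NNlib.conv!, so overloading conv! for traced arrays covers both call paths. A minimal sketch of that fallback pattern with ordinary Arrays; DenseConvDims and output_size are NNlib's, the concrete sizes are illustrative:

    using NNlib

    x = rand(Float32, 28, 28, 3, 4)   # width x height x channels_in x batch
    w = rand(Float32, 3, 3, 3, 8)     # kernel_w x kernel_h x channels_in x channels_out
    cdims = DenseConvDims(x, w; padding=1)

    # conv allocates the result and forwards to conv! ...
    y1 = NNlib.conv(x, w, cdims)

    # ... so providing a conv! method (as this commit does for TracedRArray)
    # is enough to intercept both entry points.
    y2 = similar(x, Float32, NNlib.output_size(cdims)..., 8, 4)
    NNlib.conv!(y2, x, w, cdims)
    y1 ≈ y2   # true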

File tree: 1 file changed, +20 -22 lines changed


ext/ReactantNNlibExt.jl

Lines changed: 20 additions & 22 deletions
@@ -53,8 +53,8 @@ function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where
     return out
 end

-function NNlib.conv(
-    x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
+function NNlib.conv!(
+    y::TracedRArray{T,N}, x::AnyTracedRArray, W::AnyTracedRArray, cdims::DenseConvDims
 ) where {T,N}
     x = materialize_traced_array(x)
     W = materialize_traced_array(W)
@@ -83,33 +83,31 @@ function NNlib.conv(
         pl, pr = padding[2i - 1], padding[2i]
         d = dilation[i]
         s = stride[i]
-
-        (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
+        return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
     end
     output_batch_dim = input_batch_dim
     output_feature_dim = input_feature_dim
     output_spatial_dims = input_spatial_dims

-    output_shape = (output_spatial_shapes..., size(W, kernel_output_dim), size(x, N))
-
-    dimension_numbers = """
-        #stablehlo.conv<raw
-            input_batch_dimension = $(input_batch_dim - 1),
-            input_feature_dimension = $(input_feature_dim - 1),
-            input_spatial_dimensions = [$(join(input_spatial_dims .- 1, ", "))],
-            kernel_output_feature_dimension = $(kernel_output_dim - 1),
-            kernel_input_feature_dimension = $(kernel_input_dim - 1),
-            kernel_spatial_dimensions = [$(join(kernel_spatial_dims .- 1, ", "))],
-            output_batch_dimension = $( output_batch_dim - 1 ),
-            output_feature_dimension = $( output_feature_dim - 1),
-            output_spatial_dimensions = [$(join(output_spatial_dims .- 1, ", "))],
-        >"""
-    dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
+    #! format: off
+    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(input_batch_dim - 1),
+        Int64(input_feature_dim - 1),
+        length(input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
+        Int64(kernel_input_dim - 1),
+        Int64(kernel_output_dim - 1),
+        length(kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
+        Int64(output_batch_dim - 1),
+        Int64(output_feature_dim - 1),
+        length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
+    )
+    #! format: on

     padding = Reactant.MLIR.IR.DenseElementsAttribute(
         reshape(collect(padding), (num_spatial_dims, 2))
     )
-    result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
+    result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

     weight = W.mlir_data
     if !flipkernel
@@ -132,8 +130,8 @@ function NNlib.conv(
         feature_group_count,
         batch_group_count=1,
     )
-
-    return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
+    y.mlir_data = Reactant.MLIR.IR.result(conv)
+    return y
 end

 function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
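
Side note (not part of the diff): the per-dimension spatial size computed in the comprehension above is the standard convolution output-length formula. A tiny worked check, with out_len being an illustrative helper rather than anything defined in the extension:

    # (in + pad_lo + pad_hi - dilation*(K - 1) - 1) ÷ stride + 1,
    # i.e. the same expression returned inside the comprehension above.
    out_len(in_len, K; pl=0, pr=0, d=1, s=1) = (in_len + pl + pr - d * (K - 1) - 1) ÷ s + 1

    out_len(28, 3; pl=1, pr=1)   # 28 -- "same" padding for a 3-wide kernel
    out_len(28, 3; s=2)          # 13 -- strided, no padding

Since the refactor builds the op's result type from size(y), the preallocated output passed to conv! is what ultimately fixes the shape of the emitted convolution.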
