@@ -53,8 +53,8 @@ function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where
     return out
 end

-function NNlib.conv(
-    x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
+function NNlib.conv!(
+    y::TracedRArray{T,N}, x::AnyTracedRArray, W::AnyTracedRArray, cdims::DenseConvDims
 ) where {T,N}
     x = materialize_traced_array(x)
     W = materialize_traced_array(W)
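The hunk above swaps the out-of-place `NNlib.conv` overload for an in-place `NNlib.conv!` overload that writes into a preallocated traced output `y`. A minimal sketch of why overloading only the bang method suffices, assuming NNlib's stock allocating wrapper; the `output_size`/`channels_out` helpers and the example shapes are plain NNlib usage, not part of this commit:

using NNlib

# NNlib's generic conv allocates the output buffer and forwards to conv!,
# so a traced overload of the in-place method covers both entry points.
x = rand(Float32, 32, 32, 3, 4)        # W × H × C_in × batch
w = rand(Float32, 5, 5, 3, 8)          # K × K × C_in × C_out
cdims = DenseConvDims(x, w; padding=2)

y = similar(x, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, 4))
NNlib.conv!(y, x, w, cdims)            # the method this hunk now overloads for traced arrays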
@@ -83,33 +83,31 @@ function NNlib.conv(
         pl, pr = padding[2i - 1], padding[2i]
         d = dilation[i]
         s = stride[i]
-
-        (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
+        return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
     end

     output_batch_dim = input_batch_dim
     output_feature_dim = input_feature_dim
     output_spatial_dims = input_spatial_dims

-    output_shape = (output_spatial_shapes..., size(W, kernel_output_dim), size(x, N))
-
-    dimension_numbers = """
-        #stablehlo.conv<raw
-            input_batch_dimension = $(input_batch_dim - 1),
-            input_feature_dimension = $(input_feature_dim - 1),
-            input_spatial_dimensions = [$(join(input_spatial_dims .- 1, ", "))],
-            kernel_output_feature_dimension = $(kernel_output_dim - 1),
-            kernel_input_feature_dimension = $(kernel_input_dim - 1),
-            kernel_spatial_dimensions = [$(join(kernel_spatial_dims .- 1, ", "))],
-            output_batch_dimension = $(output_batch_dim - 1),
-            output_feature_dimension = $(output_feature_dim - 1),
-            output_spatial_dimensions = [$(join(output_spatial_dims .- 1, ", "))],
-        >"""
-    dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
+    #! format: off
+    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(input_batch_dim - 1),
+        Int64(input_feature_dim - 1),
+        length(input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
+        Int64(kernel_input_dim - 1),
+        Int64(kernel_output_dim - 1),
+        length(kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
+        Int64(output_batch_dim - 1),
+        Int64(output_feature_dim - 1),
+        length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
+    )
+    #! format: on

     padding = Reactant.MLIR.IR.DenseElementsAttribute(
         reshape(collect(padding), (num_spatial_dims, 2))
     )
-    result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
+    result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

     weight = W.mlir_data
     if !flipkernel
@@ -132,8 +130,8 @@ function NNlib.conv(
         feature_group_count,
         batch_group_count=1,
     )
-
-    return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
+    y.mlir_data = Reactant.MLIR.IR.result(conv)
+    return y
 end

 function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
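Illustrative end-to-end usage (not part of the commit): with the traced `conv!` overload in place, an NNlib convolution can be compiled through Reactant roughly as below. The `ConcreteRArray` constructor and `@compile` entry point are assumed from Reactant's public API and may differ across versions.

using Reactant, NNlib

x = Reactant.ConcreteRArray(rand(Float32, 28, 28, 3, 2))
w = Reactant.ConcreteRArray(rand(Float32, 3, 3, 3, 8))

f(x, w) = NNlib.conv(x, w; pad=1)

# Tracing f lowers the convolution to a stablehlo.convolution op via the NNlib.conv! overload above.
f_compiled = Reactant.@compile f(x, w)
y = f_compiled(x, w)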