Skip to content

Commit 6b987ee

Browse files
authored
Merge pull request #110 from FluxML/sf/gc_preserve
Preserve pointers before sending them into `gemm!()`
2 parents 71f0127 + 6d01875 commit 6b987ee

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

src/impl/conv_im2col.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ which should eliminate any need for large allocations within this method.
5050
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
5151
# doesn't like us putting it on the inside.
5252
@timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
53-
col_ptr = pointer(col)
54-
w_ptr = pointer(w)
55-
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
56-
@timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
53+
# Keep `col`, `w` and `y` rooted while raw pointers into them are handed to `gemm!()`.
# NOTE: `GC.@preserve` takes its variables separated by spaces.  Written with commas,
# the whole line parses as a single tuple argument and no variable is preserved.
GC.@preserve col w y begin
    col_ptr = pointer(col)
    w_ptr = pointer(w)
    # Linear offset into `y`: skip M*N elements for every earlier batch index.
    y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
    @timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
end
5759
end
5860
return y
5961
end
@@ -96,10 +98,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
9698
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
9799
# doesn't like us putting it on the inside.
98100
@timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
99-
col_ptr = pointer(col)
100-
dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1)
101-
dw_ptr = pointer(dw)
102-
@timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
101+
# Keep `col`, `dw` and `dy` rooted while raw pointers into them are handed to `gemm!()`.
# NOTE: `GC.@preserve` takes its variables separated by spaces.  Written with commas,
# the whole line parses as a single tuple argument and no variable is preserved.
GC.@preserve col dw dy begin
    col_ptr = pointer(col)
    # Linear offset into `dy`: skip K*N elements for every earlier batch index.
    dy_ptr = pointer(dy, (batch_idx - 1)*K*N + 1)
    dw_ptr = pointer(dw)
    @timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
end
103107

104108
# Because we accumulate over batches in this loop, we must set `beta` equal
105109
# to `1.0` from this point on.
@@ -141,10 +145,12 @@ See the documentation for `conv_im2col!()` for explanation of other parameters.
141145
K = channels_out(cdims)
142146

143147
@inbounds for batch_idx in 1:size(dx, 5)
144-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
145-
w_ptr = pointer(w)
146-
col_ptr = pointer(col)
147-
@timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
148+
# Keep `col`, `w` and `dy` rooted while raw pointers into them are handed to `gemm!()`.
# NOTE: `GC.@preserve` takes its variables separated by spaces.  Written with commas,
# the whole line parses as a single tuple argument and no variable is preserved.
GC.@preserve col w dy begin
    # Linear offset into `dy`: skip M*K elements for every earlier batch index.
    dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
    w_ptr = pointer(w)
    col_ptr = pointer(col)
    @timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
end
148154
@timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
149155
end
150156
return dx

src/impl/depthwiseconv_im2col.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ depthwiseconv_im2col!
3535
# We do a separate convolution for each channel in x, as we must
3636
for c_in in 1:channels_in(cdims)
3737
# Walk each pointer forward as we process each input channel
38-
col_ptr = pointer(col, (c_in-1)*M*K+1)
39-
w_ptr = pointer(w, (c_in-1)*K*N+1)
40-
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
41-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
38+
# Keep `col`, `w` and `y` rooted while raw pointers into them are handed to `gemm!()`.
# NOTE: `GC.@preserve` takes its variables separated by spaces.  Written with commas,
# the whole line parses as a single tuple argument and no variable is preserved.
GC.@preserve col w y begin
    # Each pointer is advanced by a per-channel offset (`c_in`); `y` is additionally
    # offset by the batch index.
    col_ptr = pointer(col, (c_in-1)*M*K+1)
    w_ptr = pointer(w, (c_in-1)*K*N+1)
    y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
    gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
end
4244
end
4345
end
4446
return y
@@ -71,11 +73,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
7173
# We do a separate convolution for each channel in x, as we must
7274
for c_in in 1:channels_in(cdims)
7375
# Walk each pointer forward as we process each input channel
74-
col_ptr = pointer(col, (c_in - 1)*M*K + 1)
75-
dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
76-
dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
77-
78-
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
76+
# Keep `col`, `dw` and `dy` rooted while raw pointers into them are handed to `gemm!()`.
# NOTE: `GC.@preserve` takes its variables separated by spaces.  Written with commas,
# the whole line parses as a single tuple argument and no variable is preserved.
GC.@preserve col dw dy begin
    # Each pointer is advanced by a per-channel offset (`c_in`); `dy` is additionally
    # offset by the batch index.
    col_ptr = pointer(col, (c_in - 1)*M*K + 1)
    dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
    dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
    gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
end
7982
end
8083

8184
# Because we accumulate over batches in this loop, we must set `beta` equal
@@ -107,13 +110,15 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
107110
@inbounds for batch_idx in 1:size(dx)[end]
108111
# We do a separate convolution for each channel in x, as we must
109112
for cidx in 1:channels_in(cdims)
110-
# Walk each pointer forward as we process each input channel
111-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
112-
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
113-
col_ptr = pointer(col, (cidx - 1)*M*N + 1)
114-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
113+
# Keep `col`, `w` and `dy` rooted while raw pointers into them are handed to `gemm!()`.
# NOTE: `GC.@preserve` takes its variables separated by spaces.  Written with commas,
# the whole line parses as a single tuple argument and no variable is preserved.
GC.@preserve col w dy begin
    # Walk each pointer forward as we process each input channel
    dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
    w_ptr = pointer(w, (cidx - 1)*K*N + 1)
    col_ptr = pointer(col, (cidx - 1)*M*N + 1)
    gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
end
115120
end
116121
@timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
117122
end
118123
return dx
119-
end
124+
end

0 commit comments

Comments
 (0)