@@ -35,10 +35,12 @@ depthwiseconv_im2col!
35
35
# We do a separate convolution for each channel in x, as we must
36
36
for c_in in 1 : channels_in (cdims)
37
37
# Walk each pointer forward as we process each input channel
38
- col_ptr = pointer (col, (c_in- 1 )* M* K+ 1 )
39
- w_ptr = pointer (w, (c_in- 1 )* K* N+ 1 )
40
- y_ptr = pointer (y, ((batch_idx - 1 )* channels_in (cdims) + c_in - 1 )* M* N + 1 )
41
- gemm! (Val (false ), Val (false ), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
38
+ GC. @preserve col, w, y, begin
39
+ col_ptr = pointer (col, (c_in- 1 )* M* K+ 1 )
40
+ w_ptr = pointer (w, (c_in- 1 )* K* N+ 1 )
41
+ y_ptr = pointer (y, ((batch_idx - 1 )* channels_in (cdims) + c_in - 1 )* M* N + 1 )
42
+ gemm! (Val (false ), Val (false ), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
43
+ end
42
44
end
43
45
end
44
46
return y
@@ -71,11 +73,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
71
73
# We do a separate convolution for each channel in x, as we must
72
74
for c_in in 1 : channels_in (cdims)
73
75
# Walk each pointer forward as we process each input channel
74
- col_ptr = pointer (col, (c_in - 1 )* M* K + 1 )
75
- dy_ptr = pointer (dy, (batch_idx - 1 )* N* K* channels_in (cdims) + (c_in - 1 )* K* N + 1 )
76
- dw_ptr = pointer (dw, (c_in - 1 )* M* N + 1 )
77
-
78
- gemm! (Val (true ), Val (false ), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
76
+ GC. @preserve col, dw, dy, begin
77
+ col_ptr = pointer (col, (c_in - 1 )* M* K + 1 )
78
+ dy_ptr = pointer (dy, (batch_idx - 1 )* N* K* channels_in (cdims) + (c_in - 1 )* K* N + 1 )
79
+ dw_ptr = pointer (dw, (c_in - 1 )* M* N + 1 )
80
+ gemm! (Val (true ), Val (false ), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
81
+ end
79
82
end
80
83
81
84
# Because we accumulate over batches in this loop, we must set `beta` equal
@@ -107,13 +110,15 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
107
110
@inbounds for batch_idx in 1 : size (dx)[end ]
108
111
# We do a separate convolution for each channel in x, as we must
109
112
for cidx in 1 : channels_in (cdims)
110
- # Walk each pointer forward as we process each input channel
111
- dy_ptr = pointer (dy, (batch_idx - 1 )* M* K* channels_in (cdims)+ (cidx - 1 )* K* M + 1 )
112
- w_ptr = pointer (w, (cidx - 1 )* K* N + 1 )
113
- col_ptr = pointer (col, (cidx - 1 )* M* N + 1 )
114
- gemm! (Val (false ), Val (true ), M, N, K, alpha, dy_ptr, w_ptr, T (0 ), col_ptr)
113
+ GC. @preserve col, w, dy, begin
114
+ # Walk each pointer forward as we process each input channel
115
+ dy_ptr = pointer (dy, (batch_idx - 1 )* M* K* channels_in (cdims)+ (cidx - 1 )* K* M + 1 )
116
+ w_ptr = pointer (w, (cidx - 1 )* K* N + 1 )
117
+ col_ptr = pointer (col, (cidx - 1 )* M* N + 1 )
118
+ gemm! (Val (false ), Val (true ), M, N, K, alpha, dy_ptr, w_ptr, T (0 ), col_ptr)
119
+ end
115
120
end
116
121
@timeit_debug to " col2im!" col2im! (view (dx, :, :, :, :, batch_idx), col, cdims)
117
122
end
118
123
return dx
119
- end
124
+ end
0 commit comments