Skip to content

Commit 36d50de

Browse files
committed
Don't instrument gemm!() directly, it breaks inference
1 parent 737f246 commit 36d50de

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/gemm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ gemm_datatype_mappings = (
2727
)
2828
for (gemm, elt) in gemm_datatype_mappings
2929
@eval begin
30-
@inline @timeit_debug to function gemm!(transA::Val, transB::Val,
30+
@inline function gemm!(transA::Val, transB::Val,
3131
M::Int, N::Int, K::Int,
3232
alpha::$(elt), A::Ptr{$elt}, B::Ptr{$elt},
3333
beta::$(elt), C::Ptr{$elt})
@@ -55,4 +55,4 @@ for (gemm, elt) in gemm_datatype_mappings
5555
alpha, A, lda, B, ldb, beta, C, ldc)
5656
end
5757
end
58-
end
58+
end

src/impl/conv_im2col.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ which should eliminate any need for large allocations within this method.
5353
col_ptr = pointer(col)
5454
w_ptr = pointer(w)
5555
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
56-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
56+
@timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
5757
end
5858
return y
5959
end
@@ -99,7 +99,7 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
9999
col_ptr = pointer(col)
100100
dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1)
101101
dw_ptr = pointer(dw)
102-
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
102+
@timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
103103

104104
# Because we accumulate over batches in this loop, we must set `beta` equal
105105
# to `1.0` from this point on.
@@ -144,7 +144,7 @@ See the documentation for `conv_im2col!()` for explanation of other parameters.
144144
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
145145
w_ptr = pointer(w)
146146
col_ptr = pointer(col)
147-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
147+
@timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
148148
@timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
149149
end
150150
return dx
@@ -363,4 +363,4 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2},
363363
end
364364
end
365365
end
366-
end
366+
end

0 commit comments

Comments
 (0)