Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Adapt = "2, 3.2"
ChainRulesCore = "0.9.45, 0.10"
ChainRulesCore = "0.9.45, 0.10, 1"
Compat = "3.14"
Requires = "0.5, 1.0"
julia = "1.6"
Expand Down
4 changes: 2 additions & 2 deletions src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =

# Gradients
# Reverse rule for `batched_transpose` on 3-dimensional arrays.
#
# Returns the primal result `batched_transpose(A)` and a pullback. The pullback
# yields `NoTangent()` for the function itself and `batched_transpose` of the
# incoming cotangent for `A`.
function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
    # The cotangent may be handed over lazily; `unthunk` materialises it before
    # it is wrapped. Only this single definition is kept (a second, earlier
    # definition without `unthunk` would simply be overwritten).
    b_transpose_back(Δ) = (NoTangent(), batched_transpose(unthunk(Δ)))
    return batched_transpose(A), b_transpose_back
end
# Reverse rule for `batched_adjoint` on 3-dimensional arrays.
#
# Returns the primal result `batched_adjoint(A)` and a pullback. The pullback
# yields `NoTangent()` for the function itself and `batched_adjoint` of the
# incoming cotangent for `A`.
function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
    # The cotangent may be handed over lazily; `unthunk` materialises it before
    # it is wrapped. Only this single definition is kept (a second, earlier
    # definition without `unthunk` would simply be overwritten).
    b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(unthunk(Δ)))
    return batched_adjoint(A), b_adjoint_back
end

Expand Down
3 changes: 2 additions & 1 deletion src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ end
# Gradient, allowing that size(A,3)==1 means it's "broadcasted" out to size(B,3)

function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
function batched_mul_pullback(Δ)
function batched_mul_pullback(_Δ)
Δ = unthunk(_Δ)
Athunk = @thunk begin
tmp = batched_mul(Δ, batched_adjoint(B))
size(A,3) == 1 ? sum(tmp, dims=3) : tmp
Expand Down
8 changes: 4 additions & 4 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ for conv in [:conv, :depthwiseconv]
Δ = colmajor(Δ)
return (
NoTangent(),
@thunk($∇conv_data(Δ, w, cdims, kw...)),
@thunk($∇conv_filter(x, Δ, cdims, kw...)),
@thunk($∇conv_data(unthunk(Δ), w, cdims, kw...)),
@thunk($∇conv_filter(x, unthunk(Δ), cdims, kw...)),
NoTangent(),
)
end
Expand All @@ -323,8 +323,8 @@ for conv in [:conv, :depthwiseconv]
Δ = colmajor(Δ)
return (
NoTangent(),
@thunk($conv(Δ, w, cdims, kw...)),
@thunk($∇conv_filter(Δ, x, cdims, kw...)),
@thunk($conv(unthunk(Δ), w, cdims, kw...)),
@thunk($∇conv_filter(unthunk(Δ), x, cdims, kw...)),
NoTangent(),
)
end
Expand Down
2 changes: 1 addition & 1 deletion src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ end
# Reverse rule for the in-place `gather!(dst, src, idx)`.
#
# Returns the value produced by `gather!` and a pullback. The pullback returns a
# 4-tuple matching `(gather!, dst, src, idx)`: only the `src` slot carries a
# gradient, computed by `∇gather_src`; the other slots are `NoTangent()`.
function rrule(
    ::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray
)
    y = gather!(dst, src, idx)
    # Only the size of `src` is needed by the pullback, so capture that alone.
    src_size = size(src)
    function gather!_pullback(Δ)
        # Materialise a possibly lazy cotangent before passing it on.
        src_grad = ∇gather_src(unthunk(Δ), src_size, idx)
        return (NoTangent(), NoTangent(), src_grad, NoTangent())
    end
    return y, gather!_pullback
end
2 changes: 1 addition & 1 deletion src/padding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function rrule(::typeof(pad_constant), x::AbstractArray{T,N},
# Pullback closure for `pad_constant`.
#
# Rebuilds the padding specification with `gen_pad`, asks `size_and_center` for
# the index ranges of the un-padded region, and returns the cotangent restricted
# to that region as the gradient for `x`. All other slots are `NoTangent()`.
function pad_constant_pullback(Δ)
    p = gen_pad(pad, dims, N)
    # The first value returned by `size_and_center` is not needed here.
    _, center = size_and_center(x, p)
    # The slice is deferred with `@thunk`; the cotangent is materialised with
    # `unthunk` inside it, right before indexing.
    return (NoTangent(), @thunk(unthunk(Δ)[center...]), NoTangent(), NoTangent())
end
return y, pad_constant_pullback
end
Expand Down
2 changes: 1 addition & 1 deletion src/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ for pool in [:maxpool, :meanpool]
pullback = Symbol(pool, :_pullback)
# Generated reverse rule for the pooling function bound to `pool`.
#
# Runs the forward pooling and returns its output together with a pullback
# named by `pullback`. The pullback returns a gradient for `x` (computed by the
# interpolated `∇pool` from the cotangent, the forward output and the input) and
# `NoTangent()` for the function and for `pdims`.
@eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
    Ω = $pool(x, pdims; kw...)
    function $pullback(Δ)
        # Materialise a possibly lazy cotangent before calling the kernel.
        x_grad = $∇pool(unthunk(Δ), Ω, x, pdims; kw...)
        return (NoTangent(), x_grad, NoTangent())
    end
    return Ω, $pullback
end
end
4 changes: 2 additions & 2 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,12 @@ end
# Reverse rule for the in-place `scatter!(op, dst, src, idx)`.
#
# A copy of `dst` is taken before `scatter!` mutates it, because the gradient
# with respect to `dst` is computed from both the old and the new contents.
# The pullback returns a 5-tuple matching `(scatter!, op, dst, src, idx)`, with
# gradients in the `dst` and `src` slots and `NoTangent()` elsewhere.
function rrule(
    ::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray
)
    dst_old = copy(dst)
    scatter!(op, dst, src, idx)
    function scatter!_pullback(_Δ)
        # Materialise the cotangent once; both gradients below reuse it, so a
        # lazy cotangent is not evaluated twice.
        Δ = unthunk(_Δ)
        return (
            NoTangent(),
            NoTangent(),
            ∇scatter!_dst(op, Δ, dst_old, dst),
            ∇scatter!_src(op, Δ, dst, src, idx),
            NoTangent(),
        )
    end
    return dst, scatter!_pullback
end

# Reverse rule for the out-of-place `scatter(op, src, idx)`.
#
# Returns the scattered result and a pullback. The pullback returns a 4-tuple
# matching `(scatter, op, src, idx)`: only the `src` slot carries a gradient,
# computed by `∇scatter_src`; the other slots are `NoTangent()`.
function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray)
    y = scatter(op, src, idx)
    function scatter_pullback(Δ)
        # Materialise a possibly lazy cotangent before calling the kernel.
        src_grad = ∇scatter_src(op, unthunk(Δ), y, src, idx)
        return (NoTangent(), NoTangent(), src_grad, NoTangent())
    end
    return y, scatter_pullback
end
4 changes: 2 additions & 2 deletions src/softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

# Reverse rule for `softmax(xs; dims)`.
#
# Returns the forward result `y` and a pullback that computes the gradient for
# `xs` with `∇softmax`, reusing `y` from the forward pass. The function slot is
# `NoTangent()`.
function rrule(::typeof(softmax), xs; dims=1)
    y = softmax(xs; dims=dims)
    # Materialise a possibly lazy cotangent before calling `∇softmax`.
    softmax_pullback(Δ) = (NoTangent(), ∇softmax(unthunk(Δ), xs, y; dims=dims))
    return y, softmax_pullback
end

Expand Down Expand Up @@ -125,7 +125,7 @@ end

# Reverse rule for `logsoftmax(xs; dims)`.
#
# Returns the forward result `y` and a pullback that computes the gradient for
# `xs` with `∇logsoftmax`, reusing `y` from the forward pass. The function slot
# is `NoTangent()`.
function rrule(::typeof(logsoftmax), xs; dims=1)
    y = logsoftmax(xs; dims=dims)
    # Materialise a possibly lazy cotangent before calling `∇logsoftmax`.
    logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(unthunk(Δ), xs, y; dims=dims))
    return y, logsoftmax_pullback
end

Expand Down
8 changes: 4 additions & 4 deletions src/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ end

# Reverse rule for `upsample_nearest(x, s)`.
#
# Returns the upsampled array and a pullback. The pullback returns a 3-tuple
# matching `(upsample_nearest, x, s)`: the gradient for `x` comes from
# `∇upsample_nearest`; the function and `s` slots are `NoTangent()`.
function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
    Ω = upsample_nearest(x, s)
    function upsample_nearest_pullback(Δ)
        # Materialise a possibly lazy cotangent before calling the kernel.
        return (NoTangent(), ∇upsample_nearest(unthunk(Δ), s), NoTangent())
    end
    return Ω, upsample_nearest_pullback
end

Expand Down Expand Up @@ -203,7 +203,7 @@ end
# Reverse rule for `upsample_linear(x; size)`.
#
# Returns the upsampled array and a pullback. The pullback calls
# `∇upsample_linear` with the cotangent and the first-dimension size of the
# original input, and returns `NoTangent()` for the function slot.
function rrule(::typeof(upsample_linear), x; size)
    Ω = upsample_linear(x; size=size)
    # `size` is shadowed by the keyword argument, hence the qualified call.
    # Only this number is needed by the pullback, so the closure captures it
    # instead of `x`.
    insize = Base.size(x, 1)
    function upsample_linear_pullback(Δ)
        # Materialise a possibly lazy cotangent before calling the kernel.
        return (NoTangent(), ∇upsample_linear(unthunk(Δ); size=insize))
    end
    return Ω, upsample_linear_pullback
end
Expand Down Expand Up @@ -368,7 +368,7 @@ end
# Reverse rule for `upsample_bilinear(x; size)`.
#
# Returns the upsampled array and a pullback. The pullback calls
# `∇upsample_bilinear` with the cotangent and the sizes of the first two
# dimensions of the original input, and returns `NoTangent()` for the function
# slot.
function rrule(::typeof(upsample_bilinear), x; size)
    Ω = upsample_bilinear(x; size=size)
    # `size` is shadowed by the keyword argument, hence the qualified calls.
    # Only these sizes are needed by the pullback, so the closure captures them
    # instead of `x`.
    insize = (Base.size(x, 1), Base.size(x, 2))
    function upsample_bilinear_pullback(Δ)
        # Materialise a possibly lazy cotangent before calling the kernel.
        return (NoTangent(), ∇upsample_bilinear(unthunk(Δ); size=insize))
    end
    return Ω, upsample_bilinear_pullback
end
Expand Down Expand Up @@ -518,7 +518,7 @@ end
# Reverse rule for `upsample_trilinear(x; size)`.
#
# Returns the upsampled array and a pullback. The pullback calls
# `∇upsample_trilinear` with the cotangent and the sizes of the first three
# dimensions of the original input, and returns `NoTangent()` for the function
# slot.
function rrule(::typeof(upsample_trilinear), x; size)
    Ω = upsample_trilinear(x; size=size)
    # `size` is shadowed by the keyword argument, hence the qualified calls.
    # Only these sizes are needed by the pullback, so the closure captures them
    # instead of `x`.
    insize = (Base.size(x, 1), Base.size(x, 2), Base.size(x, 3))
    function upsample_trilinear_pullback(Δ)
        # Materialise a possibly lazy cotangent before calling the kernel.
        return (NoTangent(), ∇upsample_trilinear(unthunk(Δ); size=insize))
    end
    return Ω, upsample_trilinear_pullback
end
Expand Down