Skip to content

Commit

Permalink
Fix ∇eachslice output array type
Browse files Browse the repository at this point in the history
  • Loading branch information
BioTurboNick committed Sep 16, 2024
1 parent 30f9b12 commit bcfaf79
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
if i1 === nothing # all slices are Zero!
return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
end
T = promote_type(eltype(dys[i1]), eltype(x))
T = promote_type(eltype.(dys)...)
# The whole point of this gradient is that we can allocate one `dx` array:
dx = similar(x, T, axes(x))
for i in axes(x, dim)
Expand All @@ -282,8 +282,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
end
∇eachslice(dys::AbstractZero, x::AbstractArray, vd::Val{dim}) where {dim} = dys

_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx)))
_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx)
_zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx)))

function rrule(::typeof(∇eachslice), dys, x, vd::Val)
function ∇∇eachslice(dz_raw)
Expand Down

0 comments on commit bcfaf79

Please sign in to comment.