Skip to content

Commit

Permalink
Improve type stability of cat_pullback
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir committed Apr 24, 2022
1 parent c294a9a commit b6320eb
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.28.2"
version = "1.28.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
if d in cdims
d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d])
else
d > ndimsX ? 1 : (:)
d > ndimsX ? 1 : 1:sizeX[d]
end
end
for d in cdims
Expand Down
2 changes: 2 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ end
test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any}

test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,))
# inference on exotic array types
test_rrule(cat, @SArray(rand(3, 2, 1)), @SArray(rand(3, 2, 1)); fkwargs=(dims=Val(2),))
end

@testset "hvcat" begin
Expand Down

0 comments on commit b6320eb

Please sign in to comment.