From b6320eb14979ecd7cb549486789bbb11d13efa13 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 23 Apr 2022 15:24:54 -0700 Subject: [PATCH] Improve type stability of `cat_pullback` --- Project.toml | 2 +- src/rulesets/Base/array.jl | 2 +- test/rulesets/Base/array.jl | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5ec4d83f5..3cceb9250 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.2" +version = "1.28.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 6af3cde2a..662719350 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -357,7 +357,7 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims) if d in cdims d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d]) else - d > ndimsX ? 1 : (:) + d > ndimsX ? 1 : 1:sizeX[d] end end for d in cdims diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 87c352ab1..6ac9c80e8 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -233,6 +233,8 @@ end test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any} test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,)) + # inference on exotic array types + test_rrule(cat, @SArray(rand(3, 2, 1)), @SArray(rand(3, 2, 1)); fkwargs=(dims=Val(2),)) end @testset "hvcat" begin