Skip to content

Commit 56a576c

Browse files
authored
Merge pull request #337 from JuliaGPU/tb/var_type_stabilit
Fix type stability of Statistics.var with dims.
2 parents bd5a2a8 + ae1283d commit 56a576c

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

src/statistics.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using Statistics
22

33
Statistics._var(A::CuArray, corrected::Bool, mean, dims) =
4-
sum((A .- something(mean, Statistics.mean(A, dims=dims))).^2, dims=dims)/(prod(size(A)[[dims...]])-corrected)
4+
sum((A .- something(mean, Statistics.mean(A, dims=dims))).^2, dims=dims)/(prod(size(A)[dims])-corrected)
5+
6+
Statistics._var(A::CuArray, corrected::Bool, mean, dim::Integer) =
7+
sum((A .- something(mean, Statistics.mean(A, dims=dim))).^2, dims=dim)/(size(A)[dim]-corrected)
58

69
Statistics._var(A::CuArray, corrected::Bool, mean, ::Colon) =
710
sum((A .- something(mean, Statistics.mean(A))).^2)/(length(A)-corrected)

test/statistics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ end
1717
@testset "mean" begin
1818
@test testf(mean, rand(2,2))
1919
@test testf(mean, rand(2,2); dims=2)
20+
@test testf(mean, rand(2,2,2); dims=[1,3])
2021
@test testf(x->mean(sin, x), rand(2,2))
2122
@test testf(x->mean(sin, x; dims=2), rand(2,2))
23+
@test testf(x->mean(sin, x; dims=[1,3]), rand(2,2,2))
2224
end

0 commit comments

Comments
 (0)