Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Specialized _var and _std functions #612

Merged
merged 4 commits into from
Mar 13, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
AbstractFFTs = "0.4, 0.5"
Expand Down
1 change: 1 addition & 0 deletions src/CuArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ include("mapreduce.jl")
include("accumulate.jl")
include("linalg.jl")
include("nnlib.jl")
include("statistics.jl")

# vendor libraries
include("blas/CUBLAS.jl")
Expand Down
13 changes: 13 additions & 0 deletions src/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import Statistics

Statistics._var(A::CuArray, corrected::Bool, mean, dims) =
sum((A .- something(mean, Statistics.mean(A, dims=dims))).^2, dims=dims)/(prod(size(A)[[dims...]])-corrected)

Statistics._var(A::CuArray, corrected::Bool, mean, ::Colon) =
sum((A .- something(mean, Statistics.mean(A))).^2)/(length(A)-corrected)

Statistics._std(A::CuArray, corrected::Bool, mean, dims) =
sqrt.(Statistics.var(A; corrected=corrected, mean=mean, dims=dims))

Statistics._std(A::CuArray, corrected::Bool, mean, ::Colon) =
sqrt.(Statistics.var(A; corrected=corrected, mean=mean, dims=:))
16 changes: 16 additions & 0 deletions test/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@testset "Statistics" begin

using CuArrays
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for this import.

using Statistics

@testset "std" begin
A = rand(10, 1, 2)
cuA = cu(A)

@test std(cuA[:, 1, 1]) ≈ std(A[:, 1, 1])
maleadt marked this conversation as resolved.
Show resolved Hide resolved
@test std(cuA) ≈ std(A)
@test std(cuA, corrected=true) ≈ std(cuA, corrected=true, dims=:) ≈ std(A, corrected=true)
@test collect(std(cuA, dims=1)) ≈ std(A, dims=1)
end
maleadt marked this conversation as resolved.
Show resolved Hide resolved

end