Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Crown421 committed Apr 14, 2023
1 parent b1f66ba commit 3ef2c11
Showing 1 changed file with 77 additions and 72 deletions.
149 changes: 77 additions & 72 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ start_time = now()

#-----------------------------------------------------------------------# utils
n = 1000
x, y, z = rand(Bool, n), randn(n), rand(1:10, n)
x, y, z = rand(Bool, n), randn(n), rand(1:10, n)
x2, y2, z2 = rand(Bool, n), randn(n), rand(1:10, n)
xs, ys, zs = vcat(x, x2), vcat(y, y2), vcat(z, z2)

p = 5
xmat, ymat, zmat = rand(Bool, n, p), randn(n, p), rand(1:10, n, p)
xmat, ymat, zmat = rand(Bool, n, p), randn(n, p), rand(1:10, n, p)
xmat2, ymat2, zmat2 = rand(Bool, n, p), randn(n, p), rand(1:10, n, p)


Expand Down Expand Up @@ -47,10 +47,10 @@ end
#-----------------------------------------------------------------------# CallFun
@testset "CallFun" begin
i = 0
o = fit!(CallFun(Mean(), x -> i+=1), y)
o = fit!(CallFun(Mean(), x -> i += 1), y)
@test value(o) mean(y)
@test i == n
@test (mergevals(CallFun(Mean(), x->nothing), y, y2)...)
@test (mergevals(CallFun(Mean(), x -> nothing), y, y2)...)
end
#-----------------------------------------------------------------------# CountMinSketch
@testset "CountMinSketch" begin
Expand Down Expand Up @@ -112,7 +112,7 @@ end
@test OnlineStats.pdf(o, 0.0) 0.3989422804014327
@test OnlineStats.pdf(o, -1.0) 0.24197072451914337
@test OnlineStats.cdf(o, 0.0) 0.5
@test (OnlineStats.cdf(o, -1.0), 0.15865525393145702; atol=.001)
@test (OnlineStats.cdf(o, -1.0), 0.15865525393145702; atol=0.001)
end
@testset "FitMultinomial" begin
o = FitMultinomial(5)
Expand All @@ -127,18 +127,23 @@ end
end
@testset "FitMvNormal" begin
@test value(FitMvNormal(2)) == (zeros(2), Matrix(I, 2, 2))
a, b = mergevals(FitMvNormal(2), OnlineStatsBase.eachrow([y y2]), OnlineStatsBase.eachrow([y2 y]))
@test a[1] b[1]
@test a[2] b[2]
a, b = mergestats(FitMvNormal(2), OnlineStatsBase.eachrow([y y2]), OnlineStatsBase.eachrow([y2 y]))
@test value(a)[1] value(b)[1]
@test value(a)[2] value(b)[2]

@test all(mean(a) .≈ mean(ys))
@test all(var(a) .≈ var(ys))
@test cov(a)[1] cov(a)[4] var(ys)
@test cov(a)[2] cov(a)[3]
end
end
#-----------------------------------------------------------------------# FastNode
@testset "FastNode" begin
X, Y = ymat, rand(1:3, 1000)
X2, Y2 = ymat, rand(1:3, 1000)
data = zip(eachrow(X), Y)
data = zip(eachrow(X), Y)
data2 = zip(eachrow(X2), Y2)
o = fit!(FastNode(5, 3), data)
o = fit!(FastNode(5, 3), data)
o2 = fit!(FastNode(5, 3), data2)
merge!(o, o2)
fit!(o2, data)
Expand All @@ -159,24 +164,24 @@ end
#-----------------------------------------------------------------------# FastTree
@testset "FastTree" begin
X, Y = OnlineStats.fakedata(FastNode, 10^4, 10)
o = fit!(FastTree(10; splitsize=100), zip(eachrow(X),Y))
@test classify(o, X[1,:]) [1, 2]
o = fit!(FastTree(10; splitsize=100), zip(eachrow(X), Y))
@test classify(o, X[1, :]) [1, 2]
@test all(0 .< classify(o, X) .< 3)
@test OnlineStats.nkeys(o) == 2
@test OnlineStats.nvars(o) == 10
@test mean(classify(o, X) .== Y) > .5
@test mean(classify(o, X) .== Y) > 0.5

# Issue 116
Random.seed!(218)
X,Y = OnlineStats.fakedata(FastNode, 10^4, 1)
fit!(FastTree(1, splitsize=100), zip(eachrow(X),Y))
X, Y = OnlineStats.fakedata(FastNode, 10^4, 1)
fit!(FastTree(1, splitsize=100), zip(eachrow(X), Y))
end
#-----------------------------------------------------------------------# FastForest
@testset "FastForest" begin
X, Y = OnlineStats.fakedata(FastNode, 10^4, 10)
o = fit!(FastForest(10; splitsize=500, λ = .7), zip(eachrow(X), Y))
o = fit!(FastForest(10; splitsize=500, λ=0.7), zip(eachrow(X), Y))
@test classify(o, randn(10)) in 1:2
@test mean(classify(o, X) .== Y) > .5
@test mean(classify(o, X) .== Y) > 0.5
end
#-----------------------------------------------------------------------------# GeometricMean
@testset "GeometricMean" begin
Expand All @@ -186,36 +191,36 @@ end
end
#-----------------------------------------------------------------------# HeatMap
@testset "HeatMap" begin
data1 = zip(ymat[:,1], ymat[:,2])
data2 = zip(ymat2[:,1], ymat2[:,2])
@test ==(mergevals(HeatMap(-5:.1:5, -5:.1:5), data1, data2)...)
data1 = zip(ymat[:, 1], ymat[:, 2])
data2 = zip(ymat2[:, 1], ymat2[:, 2])
@test ==(mergevals(HeatMap(-5:0.1:5, -5:0.1:5), data1, data2)...)
@test nobs(HeatMap(data1)) == length(data1)
end
#-----------------------------------------------------------------------# Hist
@testset "Hist" begin
@test ==(mergevals(Hist(-5:.1:5), y, y2)...)
@test ==(mergevals(Hist(-5:0.1:5), y, y2)...)
@testset "Hist compared to StatsBase.Histogram" begin
for edges in (-5:5, collect(-5:5), [-5, -3.5, 0, 1, 4, 5.5])
for data in (y, -6:.75:6)
w = fit(Histogram, data, edges, closed = :left).weights
w2 = fit(Histogram, data, edges, closed = :right).weights
@test fit!(Hist(edges, Number; closed=false, left=true), data).counts == w
for data in (y, -6:0.75:6)
w = fit(Histogram, data, edges, closed=:left).weights
w2 = fit(Histogram, data, edges, closed=:right).weights
@test fit!(Hist(edges, Number; closed=false, left=true), data).counts == w
@test fit!(Hist(edges, Number; closed=false, left=false), data).counts == w2
end
end
end
o = fit!(Hist(-5:.1:5), y)
o = fit!(Hist(-5:0.1:5), y)
for (v1, v2) in zip(extrema(o), extrema(y))
@test (v1, v2; atol=.1)
@test (v1, v2; atol=0.1)
end
@test (mean(o), mean(y); atol=.1)
@test (var(o), var(y); atol=.2)
@test (mean(o), mean(y); atol=0.1)
@test (var(o), var(y); atol=0.2)

# merge unequal bins
r1, r2 = -5:.2:5, -5:.1:5
r1, r2 = -5:0.2:5, -5:0.1:5
@test merge!(fit!(Hist(r1), y), fit!(Hist(r2), y2)) == fit!(Hist(r1), vcat(y, y2))
@test OnlineStats.pdf(fit!(Hist(-5:.1:5), y), 0) > 0
@test OnlineStats.pdf(fit!(Hist(-5:.1:5), y), 100) == 0
@test OnlineStats.pdf(fit!(Hist(-5:0.1:5), y), 0) > 0
@test OnlineStats.pdf(fit!(Hist(-5:0.1:5), y), 100) == 0
end
#-----------------------------------------------------------------------# KHist
@testset "KHist" begin
Expand All @@ -225,29 +230,29 @@ end
@test mean(o) mean(y)
@test var(o) var(y)
@test median(o) median(y)
@test quantile(o) quantile(y, [0, .25, .5, .75, 1])
@test quantile(o) quantile(y, [0, 0.25, 0.5, 0.75, 1])
@test std(o) std(y)
@test extrema(o) == extrema(y)

@test_throws Exception KHist(1)

for (a, b) in [
mergevals(KHist(2000), y, y2),
mergevals(KHist(3), y, y2),
mergevals(KHist(2000, Float32), Float32.(y), Float32.(y2))
]
mergevals(KHist(2000), y, y2),
mergevals(KHist(3), y, y2),
mergevals(KHist(2000, Float32), Float32.(y), Float32.(y2))
]
@test all(ac bc for (ac, bc) in zip(a.centers, b.centers))
@test all(an == bn for (an, bn) in zip(a.counts, b.counts))
end

data = randn(10_000)
o = fit!(KHist(50), data)
@test OnlineStats.pdf(o, -10) == 0.0
@test (OnlineStats.pdf(o, 0.0), 0.3989422804014327, atol=.5)
@test (OnlineStats.pdf(o, 0.0), 0.3989422804014327, atol=0.5)
@test OnlineStats.pdf(o, 10) == 0.0
f = ecdf(o)
@test f(-10) == 0.0
@test (f(0.0), .5; atol=.1)
@test (f(0.0), 0.5; atol=0.1)
@test f(10) == 1.0
# Issue 182
@test f(maximum(data)) == 1.0
Expand Down Expand Up @@ -292,14 +297,14 @@ end
# Issue 116
@test std(KahanVariance()) == 1
@test std(fit!(KahanVariance(), 1)) == 1
@test std(fit!(KahanVariance(), [1, 2])) == sqrt(.5)
@test std(fit!(KahanVariance(), [1, 2])) == sqrt(0.5)
end
#-----------------------------------------------------------------------# KMeans
@testset "KMeans" begin
o = fit!(KMeans(2), eachrow(ymat))
sort!(o, rev=true)
@test o.value[1].n o.value[2].n

x = [repeat([[1.0, 1.0]], 3); repeat([[-1.0, -1.0]], 3)]
o = fit!(KMeans(2), (ξ for ξ x))
@test classify(o, x[1]) classify(o, x[4])
Expand All @@ -310,10 +315,10 @@ end

o = fit!(LinReg(), zip(eachrow(ymat), y))
@test coef(o) ymat \ y
@test coef(o, .1) (ymat'ymat ./ n + .1I) \ ymat'y ./ n
@test coef(o, .1:.1:.5) (ymat'ymat ./ n + Diagonal(.1:.1:.5)) \ ymat'y ./ n
@test coef(o, 0.1) (ymat'ymat ./ n + 0.1I) \ ymat'y ./ n
@test coef(o, 0.1:0.1:0.5) (ymat'ymat ./ n + Diagonal(0.1:0.1:0.5)) \ ymat'y ./ n
@test predict(o, ymat) == ymat * o.β
@test predict(o, ymat[1,:]) == dot(ymat[1,:], o.β)
@test predict(o, ymat[1, :]) == dot(ymat[1, :], o.β)
end
#-----------------------------------------------------------------------# LinRegBuilder
@testset "LinRegBuilder" begin
Expand All @@ -322,23 +327,23 @@ end
o = fit!(LinRegBuilder(), eachrow(ymat))
for i in 1:5
data = ymat[:, setdiff(1:5, i)]
@test coef(o; y=i) [data ones(n)] \ ymat[:,i]
@test coef(o, .1; y=i, bias=false) (data'data ./ n + .1*I) \ data'ymat[:,i] ./ n
@test coef(o; y=i) [data ones(n)] \ ymat[:, i]
@test coef(o, 0.1; y=i, bias=false) (data'data ./ n + 0.1 * I) \ data'ymat[:, i] ./ n
end

o2 = fit!(LinReg(), zip(eachrow(ymat[:,[4,1]]), ymat[:,3]))
@test coef(o, [.2,.4]; y=3, x = [4,1], bias=false) coef(o2, [.2, .4])
o2 = fit!(LinReg(), zip(eachrow(ymat[:, [4, 1]]), ymat[:, 3]))
@test coef(o, [0.2, 0.4]; y=3, x=[4, 1], bias=false) coef(o2, [0.2, 0.4])
end
#-----------------------------------------------------------------------# Mosaic
@testset "Mosaic" begin
@test ==(mergevals(Mosaic(Int,Int), zip(z, z2), zip(z2, z))...)
@test ==(mergevals(Mosaic(Int, Int), zip(z, z2), zip(z2, z))...)
end
#-----------------------------------------------------------------------# MovingTimeWindow
@testset "MovingTimeWindow" begin
dates = Date(2010):Day(1):Date(2011)
data = Int.(1:length(dates))
o = fit!(MovingTimeWindow(Day(4); timetype=Date, valtype=Int), zip(dates, data))
@test value(o) == collect(Pair(a,b) for (a,b) in zip(dates[end-4:end], data[end-4:end]))
@test value(o) == collect(Pair(a, b) for (a, b) in zip(dates[end-4:end], data[end-4:end]))

d1 = zip(dates[1:2], data[1:2])
d2 = zip(dates[3:4], data[3:4])
Expand All @@ -348,16 +353,16 @@ end
@testset "MovingWindow" begin
o = fit!(MovingWindow(10, Int), 1:12)
for i in 1:10
@test o[i] == (1:12)[i + 2]
@test o[i] == (1:12)[i+2]
end
end
#-----------------------------------------------------------------------# NBClassifier
@testset "NBClassifier" begin
o = fit!(NBClassifier(5, Bool), zip(eachrow(ymat),x))
o = fit!(NBClassifier(5, Bool), zip(eachrow(ymat), x))
merge!(o, fit!(NBClassifier(5, Bool), zip(eachrow(ymat2), x2)))
@test nobs(o) == 2000
@test length(probs(o)) == 2
@test sum(predict(o, ymat[1,:])) 1
@test sum(predict(o, ymat[1, :])) 1
@test classify(o, ymat[1, :]) || !classify(o, ymat[1, :])
@test OnlineStats.nvars(o) == 5
@test OnlineStats.nkeys(o) == 2
Expand All @@ -368,7 +373,7 @@ end
@test (mergevals(OrderStats(100), y, y2)...)
o = fit!(OrderStats(n), y)
@test value(o) == sort(y)
@test quantile(o, 0:.25:1) == quantile(y, 0:.25:1)
@test quantile(o, 0:0.25:1) == quantile(y, 0:0.25:1)
end
#-----------------------------------------------------------------------# Partition
@testset "Partition" begin
Expand All @@ -380,7 +385,7 @@ end
@test nobs(o) == nobs(o2)
@test all(nobs.(last.(o.parts)) .== nobs.(last.(o2.parts)))
for i in 1:5
@test value(o.parts[i][2]) value(o2.parts[500 + i][2])
@test value(o.parts[i][2]) value(o2.parts[500+i][2])
end
end
#-----------------------------------------------------------------------# ProbMap
Expand All @@ -393,19 +398,19 @@ end
merge!(o, o2)
fit!(o2, data)
@test sort(collect(keys(o.value))) == sort(collect(keys(o2.value)))
@test probs(fit!(ProbMap(Int), [1,1,2,2,3,3,4,4])) fill(.25, 4)
@test probs(fit!(ProbMap(Int), [1,1,2,2,3,3,4,4]), [1,2,9]) [.5, .5, 0]
@test probs(fit!(ProbMap(Int), [1, 1, 2, 2, 3, 3, 4, 4])) fill(0.25, 4)
@test probs(fit!(ProbMap(Int), [1, 1, 2, 2, 3, 3, 4, 4]), [1, 2, 9]) [0.5, 0.5, 0]
end
#-----------------------------------------------------------------------# Quantile
@testset "Quantile/P2Quantile" begin
data = randn(10_000)
data2 = randn(10_000)
τ = .1:.1:.9
τ = 0.1:0.1:0.9
o = Quantile(τ, b=1000)
@test (value(fit!(copy(o), data)), quantile(data, τ), atol=.1)
@test (value(fit!(copy(o), data)), quantile(data, τ), atol=0.1)

for τi in τ
@test (value(fit!(P2Quantile(τi),data)), quantile(data, τi), atol=.2)
@test (value(fit!(P2Quantile(τi), data)), quantile(data, τi), atol=0.2)
end
end
#-----------------------------------------------------------------------# ReservoirSample
Expand Down Expand Up @@ -439,20 +444,20 @@ end
#-----------------------------------------------------------------------# StatLearn
@testset "StatLearn" begin
X = randn(10_000, 5)
β = collect(-1:.5:1)
β = collect(-1:0.5:1)
Y = X * β + randn(10_000)
Y2 = 2.0 .* [rand()< 1 /(1 + exp(-η)) for η in X*β] .- 1.0
for A in [SGD(),ADAGRAD(),ADAM(),ADAMAX(),ADADELTA(),RMSPROP(),OMAS(),OMAP(),MSPI()]
Y2 = 2.0 .* [rand() < 1 / (1 + exp(-η)) for η in X * β] .- 1.0
for A in [SGD(), ADAGRAD(), ADAM(), ADAMAX(), ADADELTA(), RMSPROP(), OMAS(), OMAP(), MSPI()]
print(" > $A")
print(": ")
for L in [OnlineStats.l2regloss]
print(" | $L")
# sanity checks
for P in [OnlineStats.ElasticNet(.5), abs, abs2, zero]
fit!(StatLearn(A, L, .1; rate=LearningRate(.7), penalty=P), zip(eachrow(X),Y))
for P in [OnlineStats.ElasticNet(0.5), abs, abs2, zero]
fit!(StatLearn(A, L, 0.1; rate=LearningRate(0.7), penalty=P), zip(eachrow(X), Y))
print("")
end
o = fit!(StatLearn(A, L; rate=LearningRate(.7)), zip(eachrow(X),Y))
o = fit!(StatLearn(A, L; rate=LearningRate(0.7)), zip(eachrow(X), Y))
@test o.loss isa typeof(L)
@test o.alg isa typeof(A)
any(isnan.(o.β)) && @info((L, A))
Expand All @@ -463,8 +468,8 @@ end
end
for L in [OnlineStats.logisticloss, OnlineStats.DWDLoss(1.0)]
print(" | $L")
o = fit!(StatLearn(A, L), zip(eachrow(X),Y2))
@test mean(Y2 .== classify(o, X)) > .5
o = fit!(StatLearn(A, L), zip(eachrow(X), Y2))
@test mean(Y2 .== classify(o, X)) > 0.5
end
println()
end
Expand All @@ -483,10 +488,10 @@ include("test_kahan.jl")

#-----------------------------------------------------------------------# Show methods
@testset "Show methods" begin
for stat in [BiasVec([1,2,3]), Bootstrap(Mean()), CallFun(Mean(), println), FastNode(5),
FastTree(5), FastForest(5),
HyperLogLog{10}(), LinRegBuilder(4), KMeans(4), NBClassifier(5, Float64), ProbMap(Int),
P2Quantile(.5), Series(Mean())]
for stat in [BiasVec([1, 2, 3]), Bootstrap(Mean()), CallFun(Mean(), println), FastNode(5),
FastTree(5), FastForest(5),
HyperLogLog{10}(), LinRegBuilder(4), KMeans(4), NBClassifier(5, Float64), ProbMap(Int),
P2Quantile(0.5), Series(Mean())]
println(" > ", stat)
end
end
Expand Down

0 comments on commit 3ef2c11

Please sign in to comment.