Allow passing rng to CV fold construction
JackDunnNZ committed Oct 26, 2020
1 parent e24214b commit 793ab52
Showing 2 changed files with 34 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/GLMNet.jl
@@ -505,10 +505,11 @@ end

function glmnetcv(X::AbstractMatrix, y::Union{AbstractVector,AbstractMatrix},
family::Distribution=Normal(); weights::Vector{Float64}=ones(length(y)),
rng=Random.GLOBAL_RNG,
nfolds::Int=min(10, div(size(y, 1), 3)),
folds::Vector{Int}=begin
n, r = divrem(size(y, 1), nfolds)
shuffle!([repeat(1:nfolds, outer=n); 1:r])
shuffle!(rng, [repeat(1:nfolds, outer=n); 1:r])
end, parallel::Bool=false, kw...)
# Fit full model once to determine parameters
X = convert(Matrix{Float64}, X)
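For context, a minimal usage sketch of the new keyword (the data, sizes, and seed below are hypothetical, not part of this commit): passing the same seeded generator reproduces the default fold shuffle, and hence the cross-validation results.

using GLMNet, Random

X = randn(100, 6)   # hypothetical predictor matrix
y = randn(100)      # hypothetical response

# Same seeded rng -> same shuffle of fold labels -> identical CV folds
cv_a = glmnetcv(X, y; rng=MersenneTwister(1))
cv_b = glmnetcv(X, y; rng=MersenneTwister(1))
cv_a.meanloss == cv_b.meanloss   # expected to hold, since the folds match

The rng only drives the default folds expression; supplying an explicit folds vector (e.g. folds=[1,1,1,1,2,2,2,3,3,3] as in the tests below) bypasses it entirely.
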
33 changes: 32 additions & 1 deletion test/runtests.jl
@@ -1,5 +1,5 @@
using GLMNet, Distributions
using SparseArrays, Test
using Random, SparseArrays, Test

X = [74 1 93 93 79 18
98 36 2 27 65 70
@@ -146,6 +146,16 @@ cv = glmnetcv(X, y; folds=[1,1,1,1,2,2,2,3,3,3])
show(IOBuffer(), cv)
show(IOBuffer(), cv.path)

# Passing RNG makes cv deterministic
cv1 = glmnetcv(X, y)
cv2 = glmnetcv(X, y)
@test cv1.meanloss ≉ cv2.meanloss
@test cv1.stdloss ≉ cv2.stdloss
cv3 = glmnetcv(X, y; rng=MersenneTwister(1))
cv4 = glmnetcv(X, y; rng=MersenneTwister(1))
@test cv3.meanloss ≈ cv4.meanloss
@test cv3.stdloss ≈ cv4.stdloss

## LOGISTIC
yl = [(y .< 50) (y .>= 50)]
df_true = [0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
@@ -302,6 +312,16 @@ cv = glmnetcv(X, yl, Binomial(); folds=[1,1,1,1,2,2,2,3,3,3])
show(IOBuffer(), cv)
show(IOBuffer(), cv.path)

# Passing RNG makes cv deterministic
cv1 = glmnetcv(X, yl, Binomial())
cv2 = glmnetcv(X, yl, Binomial())
@test cv1.meanloss ≉ cv2.meanloss
@test cv1.stdloss ≉ cv2.stdloss
cv3 = glmnetcv(X, yl, Binomial(); rng=MersenneTwister(1))
cv4 = glmnetcv(X, yl, Binomial(); rng=MersenneTwister(1))
@test cv3.meanloss ≈ cv4.meanloss
@test cv3.stdloss ≈ cv4.stdloss

# Make sure passing nlambda to glmnetcv works
cv = glmnetcv(X, y, Poisson(); nlambda=2, lambda_min_ratio=0.01)
@test length(cv.lambda) == 2
@@ -462,6 +482,17 @@ cv = glmnetcv(X, y, Poisson(); folds=[1,1,1,1,2,2,2,3,3,3])
show(IOBuffer(), cv)
show(IOBuffer(), cv.path)

# Passing RNG makes cv deterministic
cv1 = glmnetcv(X, y, Poisson())
cv2 = glmnetcv(X, y, Poisson())
@test cv1.meanloss ≉ cv2.meanloss
@test cv1.stdloss ≉ cv2.stdloss
cv3 = glmnetcv(X, y, Poisson(); rng=MersenneTwister(1))
cv4 = glmnetcv(X, y, Poisson(); rng=MersenneTwister(1))
@test cv3.meanloss ≈ cv4.meanloss
@test cv3.stdloss ≈ cv4.stdloss


## COMPRESSEDPREDICTORMATRIX
betas = path.betas
cbetas = convert(Matrix{Float64}, path.betas)
