Skip to content

Commit 5e5a347

Browse files
simonschoellysimsurace
authored andcommitted
Allow providing an AbstractRNG for seed in spring_layout
1 parent e4ade60 commit 5e5a347

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

src/layout.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using SparseArrays: SparseMatrixCSC, sparse
22
using ArnoldiMethod: SR
33
using Base: OneTo
44
using LinearAlgebra: eigen
5+
using Random: AbstractRNG, default_rng
56

67
"""
78
Position nodes uniformly at random in the unit square.
@@ -83,6 +84,14 @@ where C is a parameter we can adjust
8384
*g*
8485
a graph
8586
87+
*locs_x_in*
88+
x coordinates of the initial locations. If not provided they are sampled
89+
from [-1, 1]. Can be modified.
90+
91+
*locs_y_in*
92+
y coordinates of the initial locations. If not provided they are sampled
93+
from [-1, 1]. Can be modified.
94+
8695
*C*
8796
Constant to fiddle with density of resulting layout
8897
@@ -93,7 +102,8 @@ Number of iterations we apply the forces
93102
Initial "temperature", controls movement per iteration
94103
95104
*seed*
96-
Integer seed for pseudorandom generation of locations (default = 0).
105+
Either an `Integer` seed or an `Random.AbstractRNG` for generation of initial locations.
106+
If neither is provided `Random.default_rng()` is used.
97107
98108
**Examples**
99109
```
@@ -102,13 +112,20 @@ julia> locs_x, locs_y = spring_layout(g)
102112
```
103113
"""
104114
function spring_layout(g::AbstractGraph,
105-
locs_x_in::AbstractVector{R1}=2*rand(nv(g)).-1.0,
106-
locs_y_in::AbstractVector{R2}=2*rand(nv(g)).-1.0;
115+
locs_x_in::AbstractVector{R1},
116+
locs_y_in::AbstractVector{R2};
107117
C=2.0,
108118
MAXITER=100,
109119
INITTEMP=2.0) where {R1 <: Real, R2 <: Real}
110-
111120
nvg = nv(g)
121+
122+
if length(locs_x_in) != nvg
123+
throw(ArgumentError("The length of locs_x_in does not equal the number of vertices"))
124+
end
125+
if length(locs_y_in) != nvg
126+
throw(ArgumentError("The length of locs_y_in does not equal the number of vertices"))
127+
end
128+
112129
adj_matrix = adjacency_matrix(g)
113130

114131
# The optimal distance bewteen vertices
@@ -180,7 +197,14 @@ using Random: MersenneTwister
180197

181198
function spring_layout(g::AbstractGraph, seed::Integer; kws...)
182199
rng = MersenneTwister(seed)
183-
spring_layout(g, 2 .* rand(rng, nv(g)) .- 1.0, 2 .* rand(rng,nv(g)) .- 1.0; kws...)
200+
spring_layout(g, rng; kws...)
201+
end
202+
203+
function spring_layout(g::AbstractGraph, rng::AbstractRNG=default_rng(); kws...)
204+
nvg = nv(g)
205+
locs_x_in = 2.0 * rand(rng, nvg) .- 1.0
206+
locs_y_in = 2.0 * rand(rng, nvg) .- 1.0
207+
spring_layout(g, locs_x_in, locs_y_in; kws...)
184208
end
185209

186210
"""

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
33
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
44
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
56
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
67
VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92"
78

test/runtests.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Cairo
88
using GraphPlot.Colors
99
using GraphPlot.Compose
1010
using Random
11+
using StableRNGs: StableRNG
1112
using Test
1213
using VisualRegressionTests
1314
using ImageMagick
@@ -125,13 +126,34 @@ end
125126

126127
@testset "Spring Layout" begin
127128
g1 = path_digraph(3)
128-
x1, y1 = spring_layout(g1, 0; C = 1)
129-
# TODO spring_layout uses random values which have changed on higher Julia versions
130-
# we should therefore use StableRNGs.jl for these layouts
131-
@static if VERSION < v"1.7"
132-
@test all(isapprox.(x1, [1.0, -0.014799825222963192, -1.0]))
133-
@test all(isapprox.(y1, [-1.0, 0.014799825222963303, 1.0]))
134-
end
129+
g2 = smallgraph(:house)
130+
131+
# Neither seed nor initial locations provided
132+
x1, y1 = spring_layout(g1; MAXITER=10)
133+
@test length(x1) == nv(g1)
134+
@test length(y1) == nv(g1)
135+
136+
# Using a seed
137+
x2, y2 = spring_layout(g1, 0; C = 1)
138+
@test length(x2) == nv(g1)
139+
@test length(y2) == nv(g1)
140+
141+
# Using a rng
142+
rng = StableRNG(123)
143+
x3, y3 = spring_layout(g2, rng; INITTEMP = 7.5)
144+
@test x3 [0.6417685918857294, -1.0, 1.0, -0.5032029640625139, 0.585415479582793]
145+
@test y3 [-1.0, -0.7760280912987298, 0.06519424728464562, 0.2702599482349506, 1.0]
146+
147+
# Using initial locations
148+
locs_x_in = 1:5
149+
locs_y_in = [-1.0, 2.0, 0.3, 0.4, -0.5]
150+
x4, y4 = spring_layout(g2, locs_x_in, locs_y_in)
151+
@test x4 [-1.0, -0.4030585026962391, -0.050263101475789274, 0.5149349966578818, 1.0]
152+
@test y4 [-0.03307638042475203, 1.0, -0.8197758901868164, 0.15834883764718155, -1.0]
153+
154+
# Providing initial locations with the wrong lengths should throw an ArgumentError
155+
@test_throws ArgumentError("The length of locs_x_in does not equal the number of vertices") spring_layout(g1, 1:5, [1,2,3])
156+
@test_throws ArgumentError("The length of locs_y_in does not equal the number of vertices") spring_layout(g2, 1:5, [1,2,3])
135157
end
136158

137159
@testset "Circular Layout" begin

0 commit comments

Comments
 (0)