Skip to content

Commit

Permalink
Sample CSBM
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jul 10, 2023
1 parent 27a2f20 commit 55cfbfd
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"julia.environmentPath": "/home/gdalle/Documents/GitHub/Julia/StochasticBlockModelVariants.jl"
}
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ authors = ["Guillaume Dalle <22795598+gdalle@users.noreply.github.com> and contr
version = "0.1.0"

[deps]
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
Graphs = "1.8"
SimpleWeightedGraphs = "1.4"
julia = "1.9"

[extras]
Expand Down
11 changes: 8 additions & 3 deletions src/StochasticBlockModelVariants.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
module StochasticBlockModelVariants

using Graphs
using LinearAlgebra
using SparseArrays
using SimpleWeightedGraphs: SimpleWeightedGraph
using LinearAlgebra: Symmetric
using Random: AbstractRNG, default_rng
using SparseArrays: SparseMatrixCSC, sparse

export ContextualSBM, ContextualSBMLatents, ContextualSBMObservations

include("contextual_sbm.jl")

end
86 changes: 86 additions & 0 deletions src/contextual_sbm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
struct ContextualSBM{R<:Real}
d::R
λ::R
μ::R
N::Int
P::Int

function ContextualSBM(; d::R1, λ::R2, μ::R3, N, P) where {R1,R2,R3}
R = promote_type(R1, R2, R3)
return new{R}(d, λ, μ, N, P)
end
end

@kwdef struct ContextualSBMLatents{R<:Real}
u::Vector{Int} # (N,)
v::Vector{R} # (P,)
end

@kwdef struct ContextualSBMObservations{R<:Real}
A::Symmetric{Bool,SparseMatrixCSC{Bool,Int}} # (N, N)
G::SimpleWeightedGraph{Int,Bool}
B::Matrix{R} # (P, N)
end

@kwdef struct ContextualSBMMessages
# From variables to factors
χ_node_node
χ_node_feat
χ_feat_node
# From factors to variables
ψ_node_node
ψ_node_feat
ψ_feat_node
end

const CSBM = ContextualSBM
const CSBML = ContextualSBMLatents
const CSBMO = ContextualSBMObservations

function affinities(csbm::CSBM)
(; d, λ) = csbm
cᵢ = d + λ * sqrt(d)
cₒ = d - λ * sqrt(d)
return (; cᵢ, cₒ)
end

nb_nodes(csbm::CSBM) = csbm.N
nb_nodes(latents::CSBML) = length(latents.u)
nb_nodes(obs::CSBMO) = size(obs.A, 1)

nb_features(csbm::CSBM) = csbm.P
nb_features(latents::CSBML) = length(latents.v)
nb_features(obs::CSBMO) = size(obs.B, 1)

function Base.rand(rng::AbstractRNG, csbm::CSBM)
N, P = nb_nodes(csbm), nb_features(csbm)
μ = csbm.μ
(; cᵢ, cₒ) = affinities(csbm)

u = rand(rng, (-1, +1), N)
v = randn(rng, P)

Is, Js = Int[], Int[]
for i in 1:N, j in 1:i
r = rand(rng)
if (u[i] == u[j] && r < cᵢ / N) || (u[i] != u[j] && r < cₒ / N)
push!(Is, i)
push!(Js, j)
end
end
Vs = fill(true, length(Is))
A = Symmetric(sparse(Is, Js, Vs, N, N))
G = SimpleWeightedGraph(A)

Z = randn(rng, P, N)
B = similar(Z)
for α in 1:P, i in 1:N
B[α, i] = sqrt/ N) * v[α] * u[i] * Z[α, i]
end

latents = CSBML(; u, v)
obs = CSBMO(; A, G, B)
return (; latents, obs)
end

Base.rand(csbm::CSBM) = Base.rand(default_rng(), csbm)
5 changes: 5 additions & 0 deletions test/contextual_sbm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using StochasticBlockModelVariants

csbm = ContextualSBM(; d=3, λ=1, μ=2.0, N=10, P=20)

(; latents, obs) = rand(csbm)
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,8 @@ using Test
@testset "Doctests" begin
doctest(StochasticBlockModelVariants)
end

@testset verbose = true "Contextual SBM" begin
include("contextual_sbm.jl")
end
end

0 comments on commit 55cfbfd

Please sign in to comment.