diff --git a/.gitignore b/.gitignore index 35aab2d..9c2b76c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ notebook/.ipynb_checkpoints -notebook/.xdp-final_mus.png-JsubYP \ No newline at end of file +notebook/.xdp-final_mus.png-JsubYP +.DS_Store diff --git a/Manifest.toml b/Manifest.toml index 488d59e..1f86484 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,14 +1,14 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.0" +julia_version = "1.10.4" manifest_format = "2.0" project_hash = "420c20e86825ec6ad590fa9d0d7127e92e83afa7" [[deps.AliasTables]] -deps = ["Random"] -git-tree-sha1 = "ca95b2220ef440817963baa71525a8f2f4ae7f8f" +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" -version = "1.0.0" +version = "1.1.3" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -26,17 +26,11 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" +version = "4.16.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -45,7 +39,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.2+0" +version = "1.1.1+0" [[deps.DataAPI]] git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" @@ -68,9 +62,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "22c595ca4146c07b16bcf9c8bea86f731f7109d2" +git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.108" +version = "0.25.112" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -93,20 +87,14 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "881275fc6b8c6f0dfb9cfa4a878979a33cb26be3" +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.10.1" +version = "1.13.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -116,20 +104,20 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "3863330da5466410782f2bffc64f3d505a6a8334" +git-tree-sha1 = "1dc470db8b1131cfc7fb4c115de89fe391b9e780" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.10.0" +version = "1.12.0" [[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" +version = "0.3.24" [[deps.Inflate]] -git-tree-sha1 = "ea8031dea4aff6bd41f1df8f2fdfb25b33626381" +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.4" +version = "0.1.5" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -142,28 +130,33 @@ version = "0.2.2" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" +version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" +version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" +version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -174,9 +167,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" +version = "0.3.28" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -204,7 +197,7 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" +version = "2.28.2+1" [[deps.Missings]] deps = ["DataAPI"] @@ -217,13 +210,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" +version = "2023.1.10" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -232,12 +219,12 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" +version = "0.8.1+2" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -259,7 +246,7 @@ version = "0.11.31" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.0" +version = "1.10.0" [[deps.PrecompileTools]] deps = ["Preferences"] @@ -277,18 +264,29 @@ version = "1.4.3" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[deps.PtrArrays]] +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.1" + [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.9.4" +version = "2.11.1" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.Reexport]] @@ -298,15 +296,15 @@ version = "1.2.2" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -337,12 +335,13 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.4.0" [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -352,9 +351,9 @@ version = "2.3.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" +version = "1.9.7" [deps.StaticArrays.extensions] StaticArraysChainRulesCoreExt = "ChainRulesCore" @@ -365,14 +364,14 @@ version = "1.9.3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" +version = "1.4.3" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" +version = "1.10.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -388,9 +387,9 @@ version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" [deps.StatsFuns.extensions] StatsFunsChainRulesCoreExt = "ChainRulesCore" @@ -405,9 +404,9 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" +version = "7.2.1+1" [[deps.TOML]] deps = ["Dates"] @@ -429,19 +428,19 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.13+1" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.7.0+0" +version = "5.8.0+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" +version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 81f5262..fc32dcd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SmallCouplingDynamicCavity" uuid = "1607259e-80f4-4675-b495-cb9c54bacb3f" authors = ["Mattia tarabolo "] -version = "3.2.1" +version = "4.0.1" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/message_passing_func.jl b/src/message_passing_func.jl index 81f3174..3eb1fea 100644 --- a/src/message_passing_func.jl +++ b/src/message_passing_func.jl @@ -1,40 +1,47 @@ function update_single_message!( + ε::Float64, jnode::Node{TI,TG}, iindex::Int, ρ::FBm, M::Array{Float64,3}, - updmess::Updmess, newmess::Message, damp::Float64, - sumargexp::SumM, inode::Node{TI,TG}, μ_cutoff::Float64) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - clear!(updmess, newmess) + #clear!(newmess) - updmess.lognumm .= log.(ρ.fwm) .+ log.(ρ.bwm) - updmess.signμ .= sign.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end]) - updmess.lognumμ .= log.(ρ.fwm[1, 1:end-1]) .+ log.(M[1, 1, :]) .+ log.(abs.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end])) - updmess.logZ .= log.(dropdims(sum(ρ.fwm .* ρ.bwm, dims=1), dims=1)) + @inbounds @fastmath for t in 1:inode.model.T + normmess = 0.0 + @inbounds @fastmath for x in 1:n_states(inode.model.Disease) + normmess += ρ.fwm[x,t] * ρ.bwm[x,t] + end + newmess.m[t] = ρ.fwm[2,t] * ρ.bwm[2,t] / normmess + newmess.μ[t] = max(ρ.fwm[1,t] * M[1,1,t] * (ρ.bwm[1,t+1] - ρ.bwm[2,t+1]) / normmess, μ_cutoff) + check_mess(newmess.m[t], newmess.μ[t], normmess, t) - newmess.m .= exp.(updmess.lognumm[2, :] .- updmess.logZ) - newmess.μ .= max.(updmess.signμ .* exp.(updmess.lognumμ .- updmess.logZ[1:end-1]), μ_cutoff) + newmess.m[t] = jnode.cavities[iindex].m[t]*damp + newmess.m[t]*(1 - damp) + newmess.μ[t] = jnode.cavities[iindex].μ[t]*damp + newmess.μ[t]*(1 - damp) - newmess.m .= jnode.cavities[iindex].m.*damp .+ newmess.m.*(1 - damp) - newmess.μ .= jnode.cavities[iindex].μ.*damp .+ newmess.μ.*(1 - damp) + ε = max(ε, abs(newmess.m[t] - jnode.cavities[iindex].m[t])) - if any(!isfinite, newmess.m) || any(!isfinite, newmess.μ) - throw(DomainError("NaN evaluated when updating message!")) + jnode.cavities[iindex].m[t] = newmess.m[t] + jnode.cavities[iindex].μ[t] = newmess.μ[t] end - ε = normupdate(jnode.cavities[iindex].m, newmess.m)#, normupdate(jnode.cavities[iindex].μ, newmess.μ)) - - jnode.cavities[iindex].m .= newmess.m - jnode.cavities[iindex].μ .= newmess.μ - - return ε + # t = T+1 + normmess = 0.0 + @inbounds @fastmath for x in 1:n_states(inode.model.Disease) + normmess += ρ.fwm[x,inode.model.T+1] * ρ.bwm[x,inode.model.T+1] + end + newmess.m[inode.model.T+1] = ρ.fwm[2,inode.model.T+1] * ρ.bwm[2,inode.model.T+1] / normmess + check_mess(newmess.m[inode.model.T+1], 0.0, normmess, inode.model.T+1) + newmess.m[inode.model.T+1] = jnode.cavities[iindex].m[inode.model.T+1]*damp + newmess.m[inode.model.T+1]*(1 - damp) + ε = max(ε, abs(newmess.m[inode.model.T+1] - jnode.cavities[iindex].m[inode.model.T+1])) + jnode.cavities[iindex].m[inode.model.T+1] = newmess.m[inode.model.T+1] end + function compute_ρ!( inode::Node{TI,TG}, iindex::Int, @@ -47,26 +54,30 @@ function compute_ρ!( T::Int, infectionmodel::TI) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - clear!(M, ρ) + #clear!(M, ρ) - ρ.fwm[:, 1] .= prior[:, inode.i] - ρ.bwm[:, T+1] .= inode.obs[:, T+1] + @inbounds @fastmath for x in 1:n_states(infectionmodel) + ρ.fwm[x, 1] = prior[x, inode.i] + ρ.bwm[x, T+1] = inode.obs[x, T+1] + end fill_transmat_cav!(M, inode, iindex, jnode, jindex, sumargexp, infectionmodel) # fwd-bwd update - for t in 1:T - ρ.fwm[:, t+1] .= (ρ.fwm[:, t]' * M[:, :, t])' - ρ.bwm[:, T+1-t] .= M[:, :, T+1-t] * ρ.bwm[:, T+2-t] - end - - if any(!isfinite, ρ.fwm) || any(!isfinite, ρ.bwm) - throw(DomainError("NaN evaluated when computing ρ!")) + @inbounds @fastmath for t in 1:T + @inbounds @fastmath for x1 in 1:n_states(infectionmodel) + ρ.fwm[x1, t+1] = 0.0 + ρ.bwm[x1, T+1-t] = 0.0 + @inbounds @fastmath for x2 in 1:n_states(infectionmodel) + ρ.fwm[x1, t+1] += ρ.fwm[x2, t] * M[x2, x1, t] + ρ.bwm[x1, T+1-t] += ρ.bwm[x2, T+2-t] * M[x1, x2, T+1-t] + end + end + check_ρ(inode, ρ, M, t, T) end - - return M, ρ end + function update_single_marginal!( inode::Node{TI,TG}, nodes::Vector{Node{TI,TG}}, @@ -75,39 +86,52 @@ function update_single_marginal!( ρ::FBm, prior::Array{Float64, 2}, T::Int, - updmess::Updmess, - newmarg::Marginal, - μ_cutoff::Float64, infectionmodel::TI) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} compute_sumargexp!(inode, nodes, sumargexp) - clear!(M, ρ) + #clear!(M, ρ) - ρ.fwm[:, 1] .= prior[:, inode.i] - ρ.bwm[:, T+1] .= inode.obs[:, T+1] + @inbounds @fastmath for x in 1:n_states(infectionmodel) + ρ.fwm[x, 1] = prior[x, inode.i] + ρ.bwm[x, T+1] = inode.obs[x, T+1] + end fill_transmat_marg!(M, inode, sumargexp, infectionmodel) # fwd-bwd update - for t in 1:T - ρ.fwm[:, t+1] .= (ρ.fwm[:, t]' * M[:, :, t])' - ρ.bwm[:, T+1-t] .= M[:, :, T+1-t] * ρ.bwm[:, T+2-t] + @inbounds @fastmath for t in 1:T + @inbounds @fastmath for x1 in 1:n_states(infectionmodel) + ρ.fwm[x1, t+1] = 0.0 + ρ.bwm[x1, T+1-t] = 0.0 + @inbounds @fastmath for x2 in 1:n_states(infectionmodel) + ρ.fwm[x1, t+1] += ρ.fwm[x2, t] * M[x2, x1, t] + ρ.bwm[x1, T+1-t] += ρ.bwm[x2, T+2-t] * M[x1, x2, T+1-t] + end + end end - clear!(updmess, newmarg) - - updmess.lognumm .= log.(ρ.fwm) .+ log.(ρ.bwm) - updmess.signμ .= sign.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end]) - updmess.lognumμ .= log.(ρ.fwm[1, 1:end-1]) .+ log.(M[1, 1, :]) .+ log.(abs.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end])) - updmess.logZ .= log.(dropdims(sum(ρ.fwm .* ρ.bwm, dims=1), dims=1)) - - newmarg.m .= exp.(updmess.lognumm .- updmess.logZ') - newmarg.μ .= max.(updmess.signμ .* exp.(updmess.lognumμ .- updmess.logZ[1:end-1]), μ_cutoff) + @inbounds @fastmath for t in 1:inode.model.T + normmarg = 0.0 + @inbounds @fastmath for x in 1:n_states(infectionmodel) + normmarg += ρ.fwm[x,t] * ρ.bwm[x,t] + end + @inbounds @fastmath for x in 1:n_states(infectionmodel) + inode.marg.m[x,t] = ρ.fwm[x,t] * ρ.bwm[x,t] / normmarg + end + end - return newmarg + # t = T+1 + normmarg = 0.0 + @inbounds @fastmath for x in 1:n_states(infectionmodel) + normmarg += ρ.fwm[x,T+1] * ρ.bwm[x,T+1] + end + @inbounds @fastmath for x in 1:n_states(infectionmodel) + inode.marg.m[x,T+1] = ρ.fwm[x,T+1] * ρ.bwm[x,T+1] / normmarg + end end + function compute_sumargexp!( inode::Node{TI,TG}, nodes::Vector{Node{TI,TG}}, @@ -115,16 +139,18 @@ function compute_sumargexp!( clear!(sumargexp) - for (kindex, k) in enumerate(inode.∂) + @inbounds @fastmath for (kindex, k) in enumerate(inode.∂) iindex = nodes[k].∂_idx[inode.i] - sumargexp.summ .+= inode.cavities[kindex].m[1:end-1] .* inode.νs[kindex] #chiedere ad anna se è più veloce riga o colonna - sumargexp.sumμ .+= inode.cavities[kindex].μ .* nodes[k].νs[iindex] + @inbounds @fastmath for t in 1:inode.model.T + sumargexp.summ[t] += inode.cavities[kindex].m[t] * inode.νs[kindex][t] + sumargexp.sumμ[t] += inode.cavities[kindex].μ[t] * nodes[k].νs[iindex][t] + end end - - return sumargexp end + function update_node!( + ε::Float64, inode::Node{TI,TG}, nodes::Vector{Node{TI,TG}}, sumargexp::SumM, @@ -132,36 +158,30 @@ function update_node!( ρ::FBm, prior::Array{Float64, 2}, T::Int, - updmess::Updmess, newmess::Message, - newmarg::Marginal, damp::Float64, μ_cutoff::Float64, infectionmodel::TI) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - ε = 0.0 - - sumargexp = compute_sumargexp!(inode, nodes, sumargexp) + compute_sumargexp!(inode, nodes, sumargexp) for (jindex, j) in enumerate(inode.∂) iindex = nodes[j].∂_idx[inode.i] - M, ρ = compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, T, infectionmodel) - ε = max(ε, update_single_message!(nodes[j], iindex, ρ, M, updmess, newmess, damp, sumargexp, inode, μ_cutoff)) + compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, T, infectionmodel) + update_single_message!(ε, nodes[j], iindex, ρ, M, newmess, damp, inode, μ_cutoff) end - - return ε end + function update_cavities!( + ε::Float64, nodes::Vector{Node{TI,TG}}, sumargexp::SumM, M::Array{Float64,3}, ρ::FBm, prior::Array{Float64,2}, T::Int, - updmess::Updmess, newmess::Message, - newmarg::Marginal, damp::Float64, μ_cutoff::Float64, infectionmodel::TI, @@ -170,12 +190,11 @@ function update_cavities!( ε = 0.0 for inode in shuffle(rng, nodes) - ε = max(ε, update_node!(inode, nodes, sumargexp, M, ρ, prior, T, updmess, newmess, newmarg, damp, μ_cutoff, infectionmodel)) + update_node!(ε, inode, nodes, sumargexp, M, ρ, prior, T, newmess, damp, μ_cutoff, infectionmodel) end - - return ε end + function compute_marginals!( nodes::Vector{Node{TI,TG}}, sumargexp::SumM, @@ -183,18 +202,15 @@ function compute_marginals!( ρ::FBm, T::Int64, prior::Array{Float64,2}, - updmess::Updmess, - newmarg::Marginal, - μ_cutoff::Float64, infectionmodel::TI, rng::AbstractRNG) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} for inode in shuffle(rng, nodes) - newmarg = update_single_marginal!(inode, nodes, sumargexp, M, ρ, prior, T, updmess, newmarg, μ_cutoff, infectionmodel) - inode.marg.m .= newmarg.m + update_single_marginal!(inode, nodes, sumargexp, M, ρ, prior, T, infectionmodel) end end + """ run_SCDC( model::EpidemicModel{TI,TG}, @@ -207,7 +223,7 @@ end n_iter_nc::Int64 = 1, damp_nc::Float64 = 0.0, callback::Function=(x...) -> nothing - rng::AbstractRNG=Xoshiro()) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} Runs the Small Coupling Dynamic Cavity (SCDC) inference algorithm. @@ -241,12 +257,14 @@ function run_SCDC( n_iter_nc::Int64 = 1, damp_nc::Float64 = 0.0, callback::Function=(x...) -> nothing, - rng::AbstractRNG=Xoshiro()) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} # Initialize prior probabilities based on the expected mean number of source patients (γ) prior = zeros(n_states(model.Disease), model.N) - prior[1, :] .= (1 - γ) # x_i = S - prior[2, :] .= γ # x_i = I + @inbounds @fastmath for i in 1:model.N + prior[1, i] = 1 - γ # x_i = S + prior[2, i] = γ # x_i = I + end # Format nodes for inference nodes = nodes_formatting(model, obsprob) @@ -255,15 +273,13 @@ function run_SCDC( M = TransMat(model.T, model.Disease) ρ = FBm(model.T, model.Disease) sumargexp = SumM(model.T) - updmess = Updmess(model.T, model.Disease) newmess = Message(0, 0, model.T) - newmarg = Marginal(0, model.T, model.Disease) ε = 0.0 # Iteratively update cavity messages until convergence or maximum iterations reached for iter = 1:maxiter - ε = update_cavities!(nodes, sumargexp, M, ρ, prior, model.T, updmess, newmess, newmarg, damp, μ_cutoff, model.Disease, rng) + update_cavities!(ε, nodes, sumargexp, M, ρ, prior, model.T, newmess, damp, μ_cutoff, model.Disease, rng) callback(nodes, iter, ε) # Check for convergence @@ -279,30 +295,37 @@ function run_SCDC( avg_mess = [[Message(node.i, j, model.T; zero_mess=true) for j in node.∂] for node in nodes] - for iter in 1:n_iter_nc + for _ in 1:n_iter_nc # compute average messages for inode in shuffle(rng, nodes) sumargexp = compute_sumargexp!(inode, nodes, sumargexp) for (jindex, j) in enumerate(inode.∂) iindex = nodes[j].∂_idx[inode.i] - M, ρ = compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) - clear!(updmess, newmess) - updmess.lognumm .= log.(ρ.fwm) .+ log.(ρ.bwm) - updmess.signμ .= sign.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end]) - updmess.lognumμ .= log.(ρ.fwm[1, 1:end-1]) .+ log.(M[1, 1, :]) .+ log.(abs.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end])) - updmess.logZ .= log.(dropdims(sum(ρ.fwm .* ρ.bwm, dims=1), dims=1)) - - newmess.m .= exp.(updmess.lognumm[2, :] .- updmess.logZ) - newmess.μ .= max.(updmess.signμ .* exp.(updmess.lognumμ .- updmess.logZ[1:end-1]), μ_cutoff) - - newmess.m .= nodes[j].cavities[iindex].m.*damp_nc .+ newmess.m.*(1 - damp_nc) - newmess.μ .= nodes[j].cavities[iindex].μ.*damp_nc .+ newmess.μ.*(1 - damp_nc) - - avg_mess[j][iindex].m .+= newmess.m - avg_mess[j][iindex].μ .+= newmess.μ - - nodes[j].cavities[iindex].m .= newmess.m - nodes[j].cavities[iindex].μ .= newmess.μ + compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) + #clear!(newmess) + @inbounds @fastmath for t in 1:model.T + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,t] * ρ.bwm[x,t] + end + newmess.m[t] = ρ.fwm[2,t] * ρ.bwm[2,t] / norm + newmess.μ[t] = max(ρ.fwm[1,t] * M[1,1,t] * (ρ.bwm[1,t+1] - ρ.bwm[2,t+1]) / norm, μ_cutoff) + newmess.m[t] = nodes[j].cavities[iindex].m[t]*damp_nc + newmess.m[t]*(1 - damp_nc) + newmess.μ[t] = nodes[j].cavities[iindex].μ[t]*damp_nc + newmess.μ[t]*(1 - damp_nc) + avg_mess[j][iindex].m[t] += newmess.m[t] + avg_mess[j][iindex].μ[t] += newmess.μ[t] + nodes[j].cavities[iindex].m[t] = newmess.m[t] + nodes[j].cavities[iindex].μ[t] = newmess.μ[t] + end + # t = T+1 + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,model.T+1] * ρ.bwm[x,model.T+1] + end + newmess.m[model.T+1] = ρ.fwm[2,model.T+1] * ρ.bwm[2,model.T+1] / norm + newmess.m[model.T+1] = nodes[j].cavities[iindex].m[model.T+1]*damp_nc + newmess.m[model.T+1]*(1 - damp_nc) + avg_mess[j][iindex].m[model.T+1] += newmess.m[model.T+1] + nodes[j].cavities[iindex].m[model.T+1] = newmess.m[model.T+1] end end end @@ -318,21 +341,9 @@ function run_SCDC( end end end - - - # Update messages between nodes - for inode in shuffle(rng, nodes) - sumargexp = compute_sumargexp!(inode, nodes, sumargexp) - for (jindex, j) in enumerate(inode.∂) - iindex = nodes[j].∂_idx[inode.i] - _, ρ = compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) - nodes[j].ρs[iindex].fwm .= ρ.fwm - nodes[j].ρs[iindex].bwm .= ρ.bwm - end - end # Compute final marginal probabilities - compute_marginals!(nodes, sumargexp, M, ρ, model.T, prior, updmess, newmarg, μ_cutoff, model.Disease, rng) + compute_marginals!(nodes, sumargexp, M, ρ, model.T, prior, model.Disease, rng) return nodes end @@ -350,7 +361,7 @@ end n_iter_nc::Int64 = 1, damp_nc::Float64 = 0.0, callback::Function=(x...) -> nothing, - rng::AbstractRNG=Xoshiro()) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} Runs the Small Coupling Dynamic Cavity (SCDC) inference algorithm. @@ -384,7 +395,7 @@ function run_SCDC( n_iter_nc::Int64 = 1, damp_nc::Float64 = 0.0, callback::Function=(x...) -> nothing, - rng::AbstractRNG=Xoshiro()) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} # Debugging if length(maxiter) != length(damp) @@ -393,8 +404,10 @@ function run_SCDC( # Initialize prior probabilities based on the expected mean number of source patients (γ) prior = zeros(n_states(model.Disease), model.N) - prior[1, :] .= (1 - γ) # x_i = S - prior[2, :] .= γ # x_i = I + @inbounds @fastmath for i in 1:model.N + prior[1, i] = (1 - γ) # x_i = S + prior[2, i] = γ # x_i = I + end # Format nodes for inference nodes = nodes_formatting(model, obsprob) @@ -403,9 +416,7 @@ function run_SCDC( M = TransMat(model.T, model.Disease) ρ = FBm(model.T, model.Disease) sumargexp = SumM(model.T) - updmess = Updmess(model.T, model.Disease) newmess = Message(0, 0, model.T) - newmarg = Marginal(0, model.T, model.Disease) ε = 0.0 @@ -414,8 +425,7 @@ function run_SCDC( check_convergence = false for (mi, d) in Iterators.zip(maxiter, damp) for _ in 1:mi - ε = update_cavities!(nodes, sumargexp, M, ρ, prior, model.T, updmess, newmess, newmarg, d, μ_cutoff, model.Disease, rng) - + update_cavities!(ε, nodes, sumargexp, M, ρ, prior, model.T, newmess, d, μ_cutoff, model.Disease, rng) iter += 1 callback(nodes, iter, ε) @@ -438,30 +448,37 @@ function run_SCDC( avg_mess = [[Message(node.i, j, model.T; zero_mess=true) for j in node.∂] for node in nodes] - for iter in 1:n_iter_nc + for _ in 1:n_iter_nc # compute average messages for inode in shuffle(rng, nodes) sumargexp = compute_sumargexp!(inode, nodes, sumargexp) for (jindex, j) in enumerate(inode.∂) iindex = nodes[j].∂_idx[inode.i] - M, ρ = compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) - clear!(updmess, newmess) - updmess.lognumm .= log.(ρ.fwm) .+ log.(ρ.bwm) - updmess.signμ .= sign.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end]) - updmess.lognumμ .= log.(ρ.fwm[1, 1:end-1]) .+ log.(M[1, 1, :]) .+ log.(abs.(ρ.bwm[1, 2:end] - ρ.bwm[2, 2:end])) - updmess.logZ .= log.(dropdims(sum(ρ.fwm .* ρ.bwm, dims=1), dims=1)) - - newmess.m .= exp.(updmess.lognumm[2, :] .- updmess.logZ) - newmess.μ .= max.(updmess.signμ .* exp.(updmess.lognumμ .- updmess.logZ[1:end-1]), μ_cutoff) - - newmess.m .= nodes[j].cavities[iindex].m.*damp_nc .+ newmess.m.*(1 - damp_nc) - newmess.μ .= nodes[j].cavities[iindex].μ.*damp_nc .+ newmess.μ.*(1 - damp_nc) - - avg_mess[j][iindex].m .+= newmess.m - avg_mess[j][iindex].μ .+= newmess.μ - - nodes[j].cavities[iindex].m .= newmess.m - nodes[j].cavities[iindex].μ .= newmess.μ + compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) + #clear!(newmess) + @inbounds @fastmath for t in 1:model.T + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,t] * ρ.bwm[x,t] + end + newmess.m[t] = ρ.fwm[2,t] * ρ.bwm[2,t] / norm + newmess.μ[t] = max(ρ.fwm[1,t] * M[1,1,t] * (ρ.bwm[1,t+1] - ρ.bwm[2,t+1]) / norm, μ_cutoff) + newmess.m[t] = nodes[j].cavities[iindex].m[t]*damp_nc + newmess.m[t]*(1 - damp_nc) + newmess.μ[t] = nodes[j].cavities[iindex].μ[t]*damp_nc + newmess.μ[t]*(1 - damp_nc) + avg_mess[j][iindex].m[t] += newmess.m[t] + avg_mess[j][iindex].μ[t] += newmess.μ[t] + nodes[j].cavities[iindex].m[t] = newmess.m[t] + nodes[j].cavities[iindex].μ[t] = newmess.μ[t] + end + # t = T+1 + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,model.T+1] * ρ.bwm[x,model.T+1] + end + newmess.m[model.T+1] = ρ.fwm[2,model.T+1] * ρ.bwm[2,model.T+1] / norm + newmess.m[model.T+1] = nodes[j].cavities[iindex].m[model.T+1]*damp_nc + newmess.m[model.T+1]*(1 - damp_nc) + avg_mess[j][iindex].m[model.T+1] += newmess.m[model.T+1] + nodes[j].cavities[iindex].m[model.T+1] = newmess.m[model.T+1] end end end @@ -477,21 +494,284 @@ function run_SCDC( end end end - - # Update messages between nodes - for inode in shuffle(rng, nodes) - sumargexp = compute_sumargexp!(inode, nodes, sumargexp) - for (jindex, j) in enumerate(inode.∂) - iindex = nodes[j].∂_idx[inode.i] - _, ρ = compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) - nodes[j].ρs[iindex].fwm .= ρ.fwm - nodes[j].ρs[iindex].bwm .= ρ.bwm + # Compute final marginal probabilities + compute_marginals!(nodes, sumargexp, M, ρ, model.T, prior, model.Disease, rng) + + return nodes +end + +""" + run_SCDC( + nodes::Vector{Node{TI,TG}}, + model::EpidemicModel{TI,TG}, + γ::Float64, + maxiter::Vector{Int64}, + epsconv::Float64, + damp::Vector{Float64}; + μ_cutoff::Float64 = -Inf, + n_iter_nc::Int64 = 1, + damp_nc::Float64 = 0.0, + callback::Function=(x...) -> nothing, + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + +Run the SCDC algorithm for epidemic modeling. The algorithm resumes the message-passing iterations from the current state of the nodes. + +This function performs SCDC inference on the specified epidemic model, using the provided evidence (likelihood) probability function, and other parameters such as the probability of being a patient zero, maximum number of iterations, convergence threshold, damping factor, etc. It iteratively updates cavity messages until convergence or until the maximum number of iterations is reached. It implements a dumping schedule for the damping factor, where the dumping factor is changed after a certain number of iterations, specified by the `maxiter` and `damp` arguments. + +# Arguments +- `nodes::Vector{Node{TI,TG}}`: Vector of nodes in the epidemic model. +- `model::EpidemicModel{TI,TG}`: The epidemic model to be used. +- `γ::Float64`: A parameter for the algorithm (e.g., infection rate). +- `maxiter::Vector{Int64}`: Maximum number of iterations for the algorithm. +- `epsconv::Float64`: Convergence threshold for the algorithm. +- `damp::Vector{Float64}`: Damping factors for the algorithm. + +# Keyword Arguments +- `μ_cutoff::Float64`: Cutoff value for some parameter μ (default is -Inf). +- `n_iter_nc::Int64`: Number of iterations for non-converging cases (default is 1). +- `damp_nc::Float64`: Damping factor for non-converging cases (default is 0.0). +- `callback::Function`: Callback function to be called during iterations (default does nothing). +- `rng::AbstractRNG`: Random number generator (default is Xoshiro). + +# Returns +- `Vector{Node{TI,TG}}`: The updated vector of nodes after running the algorithm. +""" +function run_SCDC( + nodes::Vector{Node{TI,TG}}, + model::EpidemicModel{TI,TG}, + γ::Float64, + maxiter::Int64, + epsconv::Float64, + damp::Float64; + μ_cutoff::Float64 = -Inf, + n_iter_nc::Int64 = 1, + damp_nc::Float64 = 0.0, + callback::Function=(x...) -> nothing, + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + + # Initialize prior probabilities based on the expected mean number of source patients (γ) + prior = zeros(n_states(model.Disease), model.N) + @inbounds @fastmath for i in 1:model.N + prior[1, i] = (1 - γ) # x_i = S + prior[2, i] = γ # x_i = I + end + + # Initialize message objects + M = TransMat(model.T, model.Disease) + ρ = FBm(model.T, model.Disease) + sumargexp = SumM(model.T) + newmess = Message(0, 0, model.T) + + ε = 0.0 + + # Iteratively update cavity messages until convergence or maximum iterations reached + for iter = 1:maxiter + update_cavities!(ε, nodes, sumargexp, M, ρ, prior, model.T, newmess, damp, μ_cutoff, model.Disease, rng) + callback(nodes, iter, ε) + + # Check for convergence + if ε < epsconv + println("Converged after $iter iterations") + break + end + end + + # Check if convergence not achieved + if ε > epsconv + println("NOT converged after $maxiter iterations") + + avg_mess = [[Message(node.i, j, model.T; zero_mess=true) for j in node.∂] for node in nodes] + + for _ in 1:n_iter_nc + # compute average messages + for inode in shuffle(rng, nodes) + sumargexp = compute_sumargexp!(inode, nodes, sumargexp) + for (jindex, j) in enumerate(inode.∂) + iindex = nodes[j].∂_idx[inode.i] + compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) + #clear!(newmess) + @inbounds @fastmath for t in 1:model.T + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,t] * ρ.bwm[x,t] + end + newmess.m[t] = ρ.fwm[2,t] * ρ.bwm[2,t] / norm + newmess.μ[t] = max(ρ.fwm[1,t] * M[1,1,t] * (ρ.bwm[1,t+1] - ρ.bwm[2,t+1]) / norm, μ_cutoff) + newmess.m[t] = nodes[j].cavities[iindex].m[t]*damp_nc + newmess.m[t]*(1 - damp_nc) + newmess.μ[t] = nodes[j].cavities[iindex].μ[t]*damp_nc + newmess.μ[t]*(1 - damp_nc) + avg_mess[j][iindex].m[t] += newmess.m[t] + avg_mess[j][iindex].μ[t] += newmess.μ[t] + nodes[j].cavities[iindex].m[t] = newmess.m[t] + nodes[j].cavities[iindex].μ[t] = newmess.μ[t] + end + # t = T+1 + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,model.T+1] * ρ.bwm[x,model.T+1] + end + newmess.m[model.T+1] = ρ.fwm[2,model.T+1] * ρ.bwm[2,model.T+1] / norm + newmess.m[model.T+1] = nodes[j].cavities[iindex].m[model.T+1]*damp_nc + newmess.m[model.T+1]*(1 - damp_nc) + avg_mess[j][iindex].m[model.T+1] += newmess.m[model.T+1] + nodes[j].cavities[iindex].m[model.T+1] = newmess.m[model.T+1] + end + end + end + + if n_iter_nc != 0 + # compute average messages + for inode in nodes + for (_, j) in enumerate(inode.∂) + iindex = nodes[j].∂_idx[inode.i] + nodes[j].cavities[iindex].m .= avg_mess[j][iindex].m ./ n_iter_nc + nodes[j].cavities[iindex].μ .= avg_mess[j][iindex].μ ./ n_iter_nc + end + end + end + end + + # Compute final marginal probabilities + compute_marginals!(nodes, sumargexp, M, ρ, model.T, prior, model.Disease, rng) + + return nodes +end + + +""" + run_SCDC(nodes::Vector{Node{TI,TG}}, model::EpidemicModel{TI,TG}, γ::Float64, maxiter::Vector{Int64}, epsconv::Float64, damp::Vector{Float64}; μ_cutoff::Float64 = -Inf, n_iter_nc::Int64 = 1, damp_nc::Float64 = 0.0, callback::Function=(x...) -> nothing, rng::AbstractRNG=Xoshiro(1234)) + +Run the SCDC algorithm for epidemic modeling. The algorithm resumes the message-passing iterations from the current state of the nodes. + +This function performs SCDC inference on the specified epidemic model, using the provided evidence (likelihood) probability function, and other parameters such as the probability of being a patient zero, maximum number of iterations, convergence threshold, damping factor, etc. It iteratively updates cavity messages until convergence or until the maximum number of iterations is reached. It implements a dumping schedule for the damping factor, where the dumping factor is changed after a certain number of iterations, specified by the `maxiter` and `damp` arguments. + +# Arguments +- `nodes::Vector{Node{TI,TG}}`: Vector of nodes in the epidemic model. +- `model::EpidemicModel{TI,TG}`: The epidemic model to be used. +- `γ::Float64`: A parameter for the algorithm (e.g., infection rate). +- `maxiter::Vector{Int64}`: Maximum number of iterations for the algorithm. +- `epsconv::Float64`: Convergence threshold for the algorithm. +- `damp::Vector{Float64}`: Damping factors for the algorithm. + +# Keyword Arguments +- `μ_cutoff::Float64`: Cutoff value for some parameter μ (default is -Inf). +- `n_iter_nc::Int64`: Number of iterations for non-converging cases (default is 1). +- `damp_nc::Float64`: Damping factor for non-converging cases (default is 0.0). +- `callback::Function`: Callback function to be called during iterations (default does nothing). +- `rng::AbstractRNG`: Random number generator (default is Xoshiro). + +# Returns +- `Vector{Node{TI,TG}}`: The updated vector of nodes after running the algorithm. +""" +function run_SCDC( + nodes::Vector{Node{TI,TG}}, + model::EpidemicModel{TI,TG}, + γ::Float64, + maxiter::Vector{Int64}, + epsconv::Float64, + damp::Vector{Float64}; + μ_cutoff::Float64 = -Inf, + n_iter_nc::Int64 = 1, + damp_nc::Float64 = 0.0, + callback::Function=(x...) -> nothing, + rng::AbstractRNG=Xoshiro(1234)) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + + # Debugging + if length(maxiter) != length(damp) + throw(DomainError("Length of maxiter and damp vectors must be the same!")) + end + + # Initialize prior probabilities based on the expected mean number of source patients (γ) + prior = zeros(n_states(model.Disease), model.N) + @inbounds @fastmath for i in 1:model.N + prior[1, i] = (1 - γ) # x_i = S + prior[2, i] = γ # x_i = I + end + + # Initialize message objects + M = TransMat(model.T, model.Disease) + ρ = FBm(model.T, model.Disease) + sumargexp = SumM(model.T) + newmess = Message(0, 0, model.T) + + ε = 0.0 + + # Iteratively update cavity messages until convergence or maximum iterations reached + iter = 0 + check_convergence = false + for (mi, d) in Iterators.zip(maxiter, damp) + for _ in 1:mi + update_cavities!(ε, nodes, sumargexp, M, ρ, prior, model.T, newmess, damp, μ_cutoff, model.Disease, rng) + iter += 1 + callback(nodes, iter, ε) + + # Check for convergence + if ε < epsconv + println("Converged after $iter iterations") + check_convergence = true + break + end + end + + if check_convergence + break + end + end + + # Check if convergence not achieved + if ε > epsconv + println("NOT converged after $maxiter iterations") + + avg_mess = [[Message(node.i, j, model.T; zero_mess=true) for j in node.∂] for node in nodes] + + for _ in 1:n_iter_nc + # compute average messages + for inode in shuffle(rng, nodes) + sumargexp = compute_sumargexp!(inode, nodes, sumargexp) + for (jindex, j) in enumerate(inode.∂) + iindex = nodes[j].∂_idx[inode.i] + compute_ρ!(inode, iindex, nodes[j], jindex, sumargexp, M, ρ, prior, model.T, model.Disease) + #clear!(newmess) + @inbounds @fastmath for t in 1:model.T + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,t] * ρ.bwm[x,t] + end + newmess.m[t] = ρ.fwm[2,t] * ρ.bwm[2,t] / norm + newmess.μ[t] = max(ρ.fwm[1,t] * M[1,1,t] * (ρ.bwm[1,t+1] - ρ.bwm[2,t+1]) / norm, μ_cutoff) + newmess.m[t] = nodes[j].cavities[iindex].m[t]*damp_nc + newmess.m[t]*(1 - damp_nc) + newmess.μ[t] = nodes[j].cavities[iindex].μ[t]*damp_nc + newmess.μ[t]*(1 - damp_nc) + avg_mess[j][iindex].m[t] += newmess.m[t] + avg_mess[j][iindex].μ[t] += newmess.μ[t] + nodes[j].cavities[iindex].m[t] = newmess.m[t] + nodes[j].cavities[iindex].μ[t] = newmess.μ[t] + end + # t = T+1 + norm = 0.0 + @inbounds @fastmath for x in 1:n_states(model.Disease) + norm += ρ.fwm[x,model.T+1] * ρ.bwm[x,model.T+1] + end + newmess.m[model.T+1] = ρ.fwm[2,model.T+1] * ρ.bwm[2,model.T+1] / norm + newmess.m[model.T+1] = nodes[j].cavities[iindex].m[model.T+1]*damp_nc + newmess.m[model.T+1]*(1 - damp_nc) + avg_mess[j][iindex].m[model.T+1] += newmess.m[model.T+1] + nodes[j].cavities[iindex].m[model.T+1] = newmess.m[model.T+1] + end + end + end + + if n_iter_nc != 0 + # compute average messages + for inode in nodes + for (_, j) in enumerate(inode.∂) + iindex = nodes[j].∂_idx[inode.i] + nodes[j].cavities[iindex].m .= avg_mess[j][iindex].m ./ n_iter_nc + nodes[j].cavities[iindex].μ .= avg_mess[j][iindex].μ ./ n_iter_nc + end + end end end # Compute final marginal probabilities - compute_marginals!(nodes, sumargexp, M, ρ, model.T, prior, updmess, newmarg, μ_cutoff, model.Disease, rng) + compute_marginals!(nodes, sumargexp, M, ρ, model.T, prior, model.Disease, rng) return nodes end \ No newline at end of file diff --git a/src/models/SI.jl b/src/models/SI.jl index e88124a..6f3e9da 100644 --- a/src/models/SI.jl +++ b/src/models/SI.jl @@ -51,8 +51,10 @@ function nodes_formatting( for i in 1:model.N obs = ones(2, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i,t], 0) + obs[2, t] = obsprob(model.obsmat[i,t], 1) + end ∂ = neighbors(model.G, i) @@ -72,8 +74,10 @@ function nodes_formatting( for i in 1:model.N obs = ones(2, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i,t], 0) + obs[2, t] = obsprob(model.obsmat[i,t], 1) + end ∂ = Vector{Int}() @@ -99,9 +103,11 @@ function fill_transmat_cav!( sumargexp::SumM, infectionmodel::SI) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex])).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex]).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 2, :] .= exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = (exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 2, t] = exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + end end function fill_transmat_marg!( @@ -110,9 +116,11 @@ function fill_transmat_marg!( sumargexp::SumM, infectionmodel::SI) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ)).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 2, :] .= exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = (exp(sumargexp.summ[t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 2, t] = exp(sumargexp.sumμ[t]) * inode.obs[2, t] + end end """ diff --git a/src/models/SIR.jl b/src/models/SIR.jl index 41400ed..e1ca650 100644 --- a/src/models/SIR.jl +++ b/src/models/SIR.jl @@ -62,9 +62,11 @@ function nodes_formatting( for i in 1:model.N obs = ones(3, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] - obs[3, :] = [obsprob(Ob, 2) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i,t], 0) + obs[2, t] = obsprob(model.obsmat[i,t], 1) + obs[3, t] = obsprob(model.obsmat[i,t], 2) + end ∂ = neighbors(model.G, i) @@ -85,9 +87,11 @@ function nodes_formatting( for i in 1:model.N obs = ones(3, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] - obs[3, :] = [obsprob(Ob, 2) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i, t], 0) + obs[2, t] = obsprob(model.obsmat[i, t], 1) + obs[3, t] = obsprob(model.obsmat[i, t], 2) + end ∂ = Vector{Int}() @@ -113,11 +117,13 @@ function fill_transmat_cav!( sumargexp::SumM, infectionmodel::SIR) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex])).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex]).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 2, :] .= (1 .- infectionmodel.rᵢᵗ[inode.i, :]) .* exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] - M[2, 3, :] .= infectionmodel.rᵢᵗ[inode.i, :] .* exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] - M[3, 3, :] .= inode.obs[3, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t])*(1 - infectionmodel.εᵢᵗ[inode.i, t]) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t])*(1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 2, t] = (1 - infectionmodel.rᵢᵗ[inode.i, t]) * exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + M[2, 3, t] = infectionmodel.rᵢᵗ[inode.i, t] * exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + M[3, 3, t] = inode.obs[3, t] + end end function fill_transmat_marg!( @@ -126,11 +132,13 @@ function fill_transmat_marg!( sumargexp::SumM, infectionmodel::SIR) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ)).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 2, :] .= (1 .- infectionmodel.rᵢᵗ[inode.i, :]) .* exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] - M[2, 3, :] .= infectionmodel.rᵢᵗ[inode.i, :] .* exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] - M[3, 3, :] .= inode.obs[3, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = (exp(sumargexp.summ[t])*(1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t])*(1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 2, t] = (1 - infectionmodel.rᵢᵗ[inode.i, t]) * exp(sumargexp.sumμ[t]) * inode.obs[2, t] + M[2, 3, t] = infectionmodel.rᵢᵗ[inode.i, t] * exp(sumargexp.sumμ[t]) * inode.obs[2, t] + M[3, 3, t] = inode.obs[3, t] + end end """ diff --git a/src/models/SIRS.jl b/src/models/SIRS.jl index d100cef..084325f 100644 --- a/src/models/SIRS.jl +++ b/src/models/SIRS.jl @@ -69,9 +69,11 @@ function nodes_formatting( for i in 1:model.N obs = ones(3, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] - obs[3, :] = [obsprob(Ob, 2) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i, t], 0) + obs[2, t] = obsprob(model.obsmat[i, t], 1) + obs[3, t] = obsprob(model.obsmat[i, t], 2) + end ∂ = neighbors(model.G, i) @@ -92,9 +94,11 @@ function nodes_formatting( for i in 1:model.N obs = ones(3, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] - obs[3, :] = [obsprob(Ob, 2) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i, t], 0) + obs[2, t] = obsprob(model.obsmat[i, t], 1) + obs[3, t] = obsprob(model.obsmat[i, t], 2) + end ∂ = Vector{Int}() @@ -120,12 +124,14 @@ function fill_transmat_cav!( sumargexp::SumM, infectionmodel::SIRS) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex])).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex]).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 2, :] .= (1 .- infectionmodel.rᵢᵗ[inode.i, :]) .* exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] - M[2, 3, :] .= infectionmodel.rᵢᵗ[inode.i, :] .* exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] - M[3, 1, :] .= infectionmodel.σᵢᵗ[inode.i, :] .* inode.obs[3, 1:end-1] - M[3, 3, :] .= (1 .- infectionmodel.σᵢᵗ[inode.i, :]) .* inode.obs[3, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t])*(1 - infectionmodel.εᵢᵗ[inode.i, t]) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t])*(1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 2, t] = (1 - infectionmodel.rᵢᵗ[inode.i, t]) * exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + M[2, 3, t] = infectionmodel.rᵢᵗ[inode.i, t] * exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + M[3, 1, t] = infectionmodel.σᵢᵗ[inode.i, t] * inode.obs[3, t] + M[3, 3, t] = (1 - infectionmodel.σᵢᵗ[inode.i, t]) * inode.obs[3, t] + end end function fill_transmat_marg!( @@ -134,12 +140,14 @@ function fill_transmat_marg!( sumargexp::SumM, infectionmodel::SIRS) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ)).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 2, :] .= (1 .- infectionmodel.rᵢᵗ[inode.i, :]) .* exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] - M[2, 3, :] .= infectionmodel.rᵢᵗ[inode.i, :] .* exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] - M[3, 1, :] .= infectionmodel.σᵢᵗ[inode.i, :] .* inode.obs[3, 1:end-1] - M[3, 3, :] .= (1 .- infectionmodel.σᵢᵗ[inode.i, :]) .* inode.obs[3, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = exp(sumargexp.summ[t])*(1 - infectionmodel.εᵢᵗ[inode.i, t]) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t])*(1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 2, t] = (1 - infectionmodel.rᵢᵗ[inode.i, t]) * exp(sumargexp.sumμ[t]) * inode.obs[2, t] + M[2, 3, t] = infectionmodel.rᵢᵗ[inode.i, t] * exp(sumargexp.sumμ[t]) * inode.obs[2, t] + M[3, 1, t] = infectionmodel.σᵢᵗ[inode.i, t] * inode.obs[3, t] + M[3, 3, t] = (1 - infectionmodel.σᵢᵗ[inode.i, t]) * inode.obs[3, t] + end end """ diff --git a/src/models/SIS.jl b/src/models/SIS.jl index d7e3189..b747f70 100644 --- a/src/models/SIS.jl +++ b/src/models/SIS.jl @@ -60,8 +60,10 @@ function nodes_formatting( for i in 1:model.N obs = ones(2, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i, t], 0) + obs[2, t] = obsprob(model.obsmat[i, t], 1) + end ∂ = neighbors(model.G, i) @@ -81,8 +83,10 @@ function nodes_formatting( for i in 1:model.N obs = ones(2, model.T + 1) - obs[1, :] = [obsprob(Ob, 0) for Ob in model.obsmat[i, :]] - obs[2, :] = [obsprob(Ob, 1) for Ob in model.obsmat[i, :]] + @inbounds @fastmath for t in 1:model.T+1 + obs[1, t] = obsprob(model.obsmat[i, t], 0) + obs[2, t] = obsprob(model.obsmat[i, t], 1) + end ∂ = Vector{Int}() @@ -108,10 +112,12 @@ function fill_transmat_cav!( sumargexp::SumM, infectionmodel::SIS) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex])).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ .- inode.cavities[jindex].m[1:end-1] .* inode.νs[jindex]).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 1, :] .= infectionmodel.rᵢᵗ[inode.i, :] .* exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] - M[2, 2, :] .= (1 .- infectionmodel.rᵢᵗ[inode.i, :]) .* exp.(sumargexp.sumμ .- inode.cavities[jindex].μ .* jnode.νs[iindex]) .* inode.obs[2, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t]) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t] - inode.cavities[jindex].m[t] * inode.νs[jindex][t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 1, t] = infectionmodel.rᵢᵗ[inode.i, t] * exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + M[2, 2, t] = (1 - infectionmodel.rᵢᵗ[inode.i, t]) * exp(sumargexp.sumμ[t] - inode.cavities[jindex].μ[t] * jnode.νs[iindex][t]) * inode.obs[2, t] + end end function fill_transmat_marg!( @@ -120,10 +126,12 @@ function fill_transmat_marg!( sumargexp::SumM, infectionmodel::SIS) where {TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} - M[1, 1, :] .= (exp.(sumargexp.summ)).*(1 .- infectionmodel.εᵢᵗ[inode.i, :]) .* inode.obs[1, 1:end-1] - M[1, 2, :] .= (1 .- exp.(sumargexp.summ).*(1 .- infectionmodel.εᵢᵗ[inode.i, :])) .* inode.obs[1, 1:end-1] - M[2, 1, :] .= infectionmodel.rᵢᵗ[inode.i, :] .* exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] - M[2, 2, :] .= (1 .- infectionmodel.rᵢᵗ[inode.i, :]) .* exp.(sumargexp.sumμ) .* inode.obs[2, 1:end-1] + @inbounds @fastmath for t in 1:inode.model.T + M[1, 1, t] = exp(sumargexp.summ[t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t]) * inode.obs[1, t] + M[1, 2, t] = (1 - exp(sumargexp.summ[t]) * (1 - infectionmodel.εᵢᵗ[inode.i, t])) * inode.obs[1, t] + M[2, 1, t] = infectionmodel.rᵢᵗ[inode.i, t] * exp(sumargexp.sumμ[t]) * inode.obs[2, t] + M[2, 2, t] = (1 - infectionmodel.rᵢᵗ[inode.i, t]) * exp(sumargexp.sumμ[t]) * inode.obs[2, t] + end end """ diff --git a/src/types.jl b/src/types.jl index 17aa804..98fd51c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -14,7 +14,7 @@ struct FBm T::Int, infectionmodel::TI) where {TI <:InfectionModel} - new(ones(n_states(infectionmodel), T + 1), ones(n_states(infectionmodel), T + 1)) + new(zeros(n_states(infectionmodel), T + 1), zeros(n_states(infectionmodel), T + 1)) end end @@ -30,20 +30,6 @@ struct SumM end -struct Updmess - lognumm::Array{Float64,2} - lognumμ::Vector{Float64} - signμ::Vector{Float64} - logZ::Vector{Float64} - - function Updmess( - T::Int, - infectionmodel::TI) where {TI <:InfectionModel} - - new(zeros(n_states(infectionmodel), T + 1), zeros(T), ones(T), zeros(T + 1)) - end -end - """ Message diff --git a/src/utils.jl b/src/utils.jl index 90fd2cf..fff5c4a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,63 +2,22 @@ function clear!( M::Array{Float64,3}, ρ::FBm) fill!(M, 0.0) - fill!(ρ.fwm, 1.0) - fill!(ρ.bwm, 1.0) + fill!(ρ.fwm, 0.0) + fill!(ρ.bwm, 0.0) end -function clear!( - M::Array{Float64,3}, - ρ::FBm, - updmess::Updmess) - - fill!(M, 0.0) - fill!(ρ.fwm, 1.0) - fill!(ρ.bwm, 1.0) - fill!(updmess.lognumm, 0.0) - fill!(updmess.lognumμ, 0.0) - fill!(updmess.signμ, 1.0) - fill!(updmess.logZ, 0.0) -end - -function clear!( - updmess::Updmess, +function clear!( newmess::Message) - fill!(updmess.lognumm, 0.0) - fill!(updmess.lognumμ, 0.0) - fill!(updmess.signμ, 1.0) - fill!(updmess.logZ, 0.0) - fill!(newmess.m, 1.0) + fill!(newmess.m, 0.0) fill!(newmess.μ, 0.0) end -function clear!( - updmess::Updmess, - newmarg::Marginal) - - fill!(updmess.lognumm, 0.0) - fill!(updmess.lognumμ, 0.0) - fill!(updmess.signμ, 1.0) - fill!(updmess.logZ, 0.0) - fill!(newmarg.m, 1.0) - fill!(newmarg.μ, 0.0) -end - function clear!(SumM::SumM) fill!(SumM.summ, 0.0) fill!(SumM.sumμ, 0.0) end -function normupdate( - oldmess::Vector{Float64}, - newmess::Vector{Float64}) - return maximum(abs.(oldmess .- newmess)) -end - -function ρ_norm(ρ::Vector{Float64}) - return ρ ./ sum(ρ) -end - """ bethe_lattice(z::Int, tmax::Int, startfrom1::Bool) @@ -172,3 +131,26 @@ function ROC_curve(marg::Vector{Float64}, x::Vector{TI}) where {TI<:Integer} return fp_rates, tp_rates, auc end + + +function check_mess(m::Float64, μ::Float64, norm::Float64, t::Int) + if !isfinite(m) || !isfinite(μ) + println("t = $t: m = $m, μ = $μ, norm = $norm") + throw(DomainError("NaN evaluated when updating message!")) + end +end + + +function check_ρ(inode::Node{TI,TG}, ρ::FBm, M::Array{Float64,3}, t::Int, T::Int) where {TI<:InfectionModel,TG<:Union{<:AbstractGraph,Vector{<:AbstractGraph}}} + if !isfinite(ρ.fwm[1,t+1]) || !isfinite(ρ.fwm[2,t+1]) || !isfinite(ρ.bwm[1,T+1-t]) || !isfinite(ρ.bwm[2,T+1-t]) + println("node $(inode.i): fw = $(ρ.fwm), bw = $(ρ.bwm)") + throw(DomainError("NaN evaluated when computing ρ!")) + end + + if ρ.fwm[:,t+1]==[0.0,0.0] || ρ.bwm[:,T+1-t]==[0.0,0.0] + println("node $(inode.i): \n fw = $(ρ.fwm) \n bw = $(ρ.bwm)") + display(M) + println("obsprob = $(inode.obs)") + throw(DomainError("0.0 evaluated when computing ρ!")) + end +end diff --git a/test/data/margSI.jld2 b/test/data/margSI.jld2 index 98c76d6..ad9babc 100644 Binary files a/test/data/margSI.jld2 and b/test/data/margSI.jld2 differ diff --git a/test/data/margSIR.jld2 b/test/data/margSIR.jld2 index 337a6cc..e681005 100644 Binary files a/test/data/margSIR.jld2 and b/test/data/margSIR.jld2 differ diff --git a/test/data/margSIRS.jld2 b/test/data/margSIRS.jld2 index dc20bfd..6309e7f 100644 Binary files a/test/data/margSIRS.jld2 and b/test/data/margSIRS.jld2 differ diff --git a/test/data/margSIRS_timevarying.jld2 b/test/data/margSIRS_timevarying.jld2 index aa30f85..47672c7 100644 Binary files a/test/data/margSIRS_timevarying.jld2 and b/test/data/margSIRS_timevarying.jld2 differ diff --git a/test/data/margSIRSscheme.jld2 b/test/data/margSIRSscheme.jld2 index ada08d4..056b134 100644 Binary files a/test/data/margSIRSscheme.jld2 and b/test/data/margSIRSscheme.jld2 differ diff --git a/test/data/margSIRSscheme_timevarying.jld2 b/test/data/margSIRSscheme_timevarying.jld2 index 4c08a4c..30c756c 100644 Binary files a/test/data/margSIRSscheme_timevarying.jld2 and b/test/data/margSIRSscheme_timevarying.jld2 differ diff --git a/test/data/margSIR_timevarying.jld2 b/test/data/margSIR_timevarying.jld2 index e9a426e..66e1b5f 100644 Binary files a/test/data/margSIR_timevarying.jld2 and b/test/data/margSIR_timevarying.jld2 differ diff --git a/test/data/margSIRscheme.jld2 b/test/data/margSIRscheme.jld2 index 3e9053f..6cbe5b8 100644 Binary files a/test/data/margSIRscheme.jld2 and b/test/data/margSIRscheme.jld2 differ diff --git a/test/data/margSIRscheme_timevarying.jld2 b/test/data/margSIRscheme_timevarying.jld2 index 6011b3c..6956aa2 100644 Binary files a/test/data/margSIRscheme_timevarying.jld2 and b/test/data/margSIRscheme_timevarying.jld2 differ diff --git a/test/data/margSIS.jld2 b/test/data/margSIS.jld2 index 381fd4f..f5260fe 100644 Binary files a/test/data/margSIS.jld2 and b/test/data/margSIS.jld2 differ diff --git a/test/data/margSIS_timevarying.jld2 b/test/data/margSIS_timevarying.jld2 index 714c7c2..5763c0b 100644 Binary files a/test/data/margSIS_timevarying.jld2 and b/test/data/margSIS_timevarying.jld2 differ diff --git a/test/data/margSISscheme.jld2 b/test/data/margSISscheme.jld2 index f63b575..b180e7c 100644 Binary files a/test/data/margSISscheme.jld2 and b/test/data/margSISscheme.jld2 differ diff --git a/test/data/margSISscheme_timevarying.jld2 b/test/data/margSISscheme_timevarying.jld2 index 6d559d8..71d0f40 100644 Binary files a/test/data/margSISscheme_timevarying.jld2 and b/test/data/margSISscheme_timevarying.jld2 differ diff --git a/test/data/margSI_timevarying.jld2 b/test/data/margSI_timevarying.jld2 index 888d33c..177ec6f 100644 Binary files a/test/data/margSI_timevarying.jld2 and b/test/data/margSI_timevarying.jld2 differ diff --git a/test/data/margSIscheme.jld2 b/test/data/margSIscheme.jld2 index 56cc8c5..a394a2e 100644 Binary files a/test/data/margSIscheme.jld2 and b/test/data/margSIscheme.jld2 differ diff --git a/test/data/margSIscheme_timevarying.jld2 b/test/data/margSIscheme_timevarying.jld2 index 5283f3d..18fe995 100644 Binary files a/test/data/margSIscheme_timevarying.jld2 and b/test/data/margSIscheme_timevarying.jld2 differ diff --git a/test/testSI.jl b/test/testSI.jl index c766867..0d1eb54 100644 --- a/test/testSI.jl +++ b/test/testSI.jl @@ -91,4 +91,4 @@ end margtestscheme = load("data/margSIscheme.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSIR.jl b/test/testSIR.jl index f3bc89c..2a7c557 100644 --- a/test/testSIR.jl +++ b/test/testSIR.jl @@ -94,4 +94,4 @@ end margtestscheme = load("data/margSIRscheme.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSIRS.jl b/test/testSIRS.jl index fe78c46..d3be92a 100644 --- a/test/testSIRS.jl +++ b/test/testSIRS.jl @@ -95,4 +95,4 @@ end margtestscheme = load("data/margSIRSscheme.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSIRS_timevarying.jl b/test/testSIRS_timevarying.jl index 031c8b3..21bf136 100644 --- a/test/testSIRS_timevarying.jl +++ b/test/testSIRS_timevarying.jl @@ -98,4 +98,4 @@ end margtestscheme = load("data/margSIRSscheme_timevarying.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSIR_timevarying.jl b/test/testSIR_timevarying.jl index 12d9832..fff0c50 100644 --- a/test/testSIR_timevarying.jl +++ b/test/testSIR_timevarying.jl @@ -98,4 +98,4 @@ end margtestscheme = load("data/margSIRscheme_timevarying.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSIS.jl b/test/testSIS.jl index 5491257..da02351 100644 --- a/test/testSIS.jl +++ b/test/testSIS.jl @@ -94,4 +94,4 @@ end margtestscheme = load("data/margSISscheme.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSIS_timevarying.jl b/test/testSIS_timevarying.jl index 674d32c..775e7d8 100644 --- a/test/testSIS_timevarying.jl +++ b/test/testSIS_timevarying.jl @@ -100,4 +100,4 @@ end margtestscheme = load("data/margSISscheme_timevarying.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end diff --git a/test/testSI_timevarying.jl b/test/testSI_timevarying.jl index a97d224..d5e02a1 100644 --- a/test/testSI_timevarying.jl +++ b/test/testSI_timevarying.jl @@ -96,4 +96,4 @@ end margtestscheme = load("data/margSIscheme_timevarying.jld2", "marg") @test marg ≈ margtestscheme -end \ No newline at end of file +end