Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
372 changes: 372 additions & 0 deletions examples/Problem Specific/Ising Model/Ising Model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,372 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Recovering a hidden temperature field from hot/cold readings\n",
"\n",
"Imagine a 2D grid of true temperatures `T[i, j]` over a surface. The sensors we have are simple: they only tell us whether a location is “hot” (1) or “cold” (0) relative to a threshold. We do not see temperatures directly. Our goal is to reconstruct the continuous temperature field from these binary observations by exploiting the fact that real temperatures vary smoothly across space.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generative model (intuition)\n",
"\n",
"- **Latent field**: Continuous temperatures `T[i, j]` that are spatially smooth; neighbors tend to be similar.\n",
"- **Observations**: Binary hot/cold readings `y[i, j] ∈ {0,1}` with noise, modeled as `y[i, j] ~ Bernoulli(σ(T[i, j]))`, where `σ` is the logistic function. The threshold can be absorbed into the offset of `T`.\n",
"- **Spatial prior**: A pairwise coupling between neighbors that penalizes sharp jumps, encouraging smooth reconstructions unless data strongly suggests boundaries.\n",
"\n",
"In this notebook, we implement the logistic observation via a custom `Sigmoid` node and enforce spatial smoothness using Gaussian couplings between neighboring cells (a Gaussian MRF–like prior).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### What this notebook does\n",
"\n",
"- **Simulates observations**: Loads a smooth grayscale image as a proxy for `T` and produces noisy hot/cold readings via a logistic sensor.\n",
"- **Builds the model**: Combines a logistic observation (`Sigmoid`) with a spatial prior coupling neighbors.\n",
"- **Performs inference**: Uses variational message passing to approximate the posterior over `T`.\n",
"- **Visualizes results**: Compares the binary observations to the recovered continuous temperature field (normalized).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Load the packages used for probabilistic modeling, image I/O, plotting, and numerical routines.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Distributions, ExponentialFamilyProjection, Images, Plots, ReactiveMP, RxInfer, StableRNGs, StatsFuns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Custom logistic observation factor\n",
"\n",
"We introduce a ``Sigmoid`` factor that models the logistic link and provide variational rules needed by the optimizer. This lets us couple the binary observations to the continuous latent field.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"struct Sigmoid end\n",
"\n",
"@node Sigmoid Stochastic [out, x]\n",
"\n",
"@rule Sigmoid(:x, Marginalisation) (q_out::PointMass,) = begin\n",
" y = mean(q_out)\n",
" y = float(mean(q_out))\n",
" sign = 1-2y\n",
" # Provide logpdf, gradient, and Hessian for 1D logistic-Bernoulli\n",
" _logpdf = (out, x) -> (out[] = -softplus(sign * x))\n",
" _grad = (out, x) -> (out[1] = y - logistic(x))\n",
" _hess = (out, x) -> (out[1, 1] = -logistic(x) * (1 - logistic(x)))\n",
" return ExponentialFamilyProjection.InplaceLogpdfGradHess(_logpdf, _grad, _hess)\n",
"end\n",
"\n",
"function BayesBase.prod(::GenericProd, left::UnivariateGaussianDistributionsFamily, right::ExponentialFamilyProjection.InplaceLogpdfGradHess)\n",
" m = mean(left)\n",
" σ = var(left)\n",
" combined_logpdf! = (out, x) -> begin\n",
" right.logpdf!(out, x)\n",
" out[] = logpdf(left, x) + out[]\n",
" end\n",
" combined_gradhes! = (out_grad, out_hess, x) -> begin\n",
" out_grad, out_hess = right.grad_hess!(out_grad, out_hess, x)\n",
" out_grad .= out_grad .- ((x .- m) ./ σ)\n",
" out_hess .= out_hess .- 1 / σ\n",
" return out_grad, out_hess\n",
" end\n",
" return ExponentialFamilyProjection.InplaceLogpdfGradHess(combined_logpdf!, combined_gradhes!)\n",
"end\n",
"\n",
"function BayesBase.prod(::GenericProd, left::ExponentialFamilyProjection.InplaceLogpdfGradHess, right::UnivariateGaussianDistributionsFamily)\n",
" return prod(GenericProd(), right, left)\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data: proxy temperature field and binary observations\n",
"\n",
"We load a smooth grayscale image as a stand-in for the true temperature field, normalize it, and generate noisy hot/cold readings by sampling from $$\\mathrm{Bernoulli}(\\sigma(T)).$$\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = StableRNG(112)\n",
"mnist_picture = load(\"mnist_picture.png\")\n",
"mnist_picture\n",
"\n",
"sample_matrix = convert(Matrix{Float64}, mnist_picture);\n",
"normalized_matrix = (sample_matrix .- mean(sample_matrix))/std(sample_matrix)\n",
"\n",
"observation_matrix = begin \n",
" o = zeros(28, 28)\n",
" for i in 1:28, j in 1:28\n",
" o[i, j] = rand(rng, Bernoulli(logistic(normalized_matrix[i, j])))\n",
" end\n",
" o\n",
"end\n",
"\n",
"Gray.(observation_matrix)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model and first inference run\n",
"\n",
"- The model places a Gaussian prior with neighbor couplings on the latent field $$x[i,j]$$ and uses a logistic observation $$y \\sim \\mathrm{Bernoulli}(\\sigma(x))$$ via the custom ``Sigmoid`` factor.\n",
"- We run variational inference and visualize three panels: normalized proxy field, binary observations, and the reconstructed field.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@model function sigmoid_ising(h, w, image)\n",
" # x_extra_prior ~ NormalMeanVariance(0, 1)\n",
" local x\n",
" connection_force = 1\n",
" prior ~ NormalMeanVariance(0, connection_force)\n",
" for i in 1:h, j in 1:w\n",
" x[i, j] ~ NormalMeanVariance(prior, connection_force)\n",
" end \n",
" for i in 1:h, j in 1:w\n",
" image[i, j] ~ Sigmoid(x[i, j]) \n",
" if i < h && j < w\n",
" x[i, j] ~ NormalMeanVariance(x[i+1, j], connection_force)\n",
" x[i, j] ~ NormalMeanVariance(x[i, j+1], connection_force)\n",
" end\n",
" if i < h\n",
" x[i, j] ~ NormalMeanVariance(x[i+1, j], connection_force)\n",
" end\n",
" if j < w\n",
" x[i, j] ~ NormalMeanVariance(x[i, j+1], connection_force)\n",
" end\n",
" end\n",
"end\n",
"\n",
"# Streaming init & autoupdates\n",
"sigmoid_init = @initialization begin\n",
" q(x) = NormalMeanVariance(0.0, 1.0)\n",
" q(prior) = NormalMeanVariance(0.5, 1)\n",
"end\n",
"\n",
"binary_constraints = @constraints begin\n",
" q(x) :: ProjectedTo(NormalMeanVariance, parameters = ProjectionParameters(\n",
" tolerance = 1e-8,\n",
" strategy = ExponentialFamilyProjection.GaussNewton(nsamples = 1), # deterministic\n",
" ))\n",
" q(x, prior) = q(x)q(prior)\n",
" q(x) = MeanField()\n",
" # q(x_prior, x, x_extra, x_extra_prior) = q(x_prior)q(x)q(x_extra, x_extra_prior)\n",
"end\n",
"\n",
"result = infer(\n",
" model = sigmoid_ising(h=28, w=28), \n",
" data = (image = observation_matrix,),\n",
" returnvars = KeepEach(),\n",
" # options = (limit_stack_depth = 100, ),\n",
" iterations = 5,\n",
" initialization = sigmoid_init,\n",
" constraints = binary_constraints,\n",
" showprogress = true\n",
");\n",
"\n",
"sigmoid_outputs = map(mean, result.posteriors[:x][5]);\n",
"normalize_sigmoid_outputs = (sigmoid_outputs .- mean(sigmoid_outputs))/std(sigmoid_outputs)\n",
"\n",
"l = @layout [\n",
" grid(1,3)\n",
"]\n",
"plot_obj = plot(layout=l)\n",
"plot!(plot_obj, Gray.(normalized_matrix), subplot=1, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n",
"plot!(plot_obj, Gray.(observation_matrix), subplot=2, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n",
"plot!(plot_obj, Gray.(normalize_sigmoid_outputs), subplot=3, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Introduce missing observations\n",
"\n",
"We simulate missing data by randomly masking a fraction of binary readings. Masked locations will be rendered in yellow in the visualization.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mask_probability = 0.25\n",
"masked_pattern = rand(rng, Bernoulli(mask_probability), size(observation_matrix)...) \n",
"\n",
"# Apply mask to observations as Union{Missing, Float64}\n",
"masked_observation_matrix = Matrix{Union{Missing, Float64}}(undef, size(observation_matrix)...)\n",
"@inbounds for j in axes(observation_matrix, 2), i in axes(observation_matrix, 1)\n",
" masked_observation_matrix[i, j] = masked_pattern[i, j] ? missing : Float64(observation_matrix[i, j])\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Missing binary observations and the Monte Carlo message\n",
"\n",
"When some binary outputs y[i,j] in {0,1} are missing, we can still perform inference by passing an approximate message from the observation factor using the variational marginal over the latent field.\n",
"\n",
"For the logistic observation model `y | x ~ Bernoulli(σ(x))`, the factor contribution is\n",
"$$ f(y \\mid x) = \\mathrm{Bernoulli}(y; \\sigma(x)) = \\sigma(x)^y (1-\\sigma(x))^{1-y}. $$\n",
"The message needed by variational updates in many formulations is the expected log-factor under the current marginal `q(x)`:\n",
"$$ \\mathbb{E}_{q(x)}[\\log f(y \\mid x)] = y\\,\\mathbb{E}_{q(x)}[\\log \\sigma(x)] + (1-y)\\,\\mathbb{E}_{q(x)}[\\log(1-\\sigma(x))] = \\mu y + C. $$\n",
"Note that $$ \\log \\sigma(x) - \\log(1-\\sigma(x)) = x.$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implement observation message for missing outputs\n",
"\n",
"We add a simple rule for ``q(y)`` when needed: use the current mean of ``q(x)`` passed through the logistic to parameterize a Bernoulli. This provides a lightweight, consistent message for the missing-observation case.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@rule Sigmoid(:out, Marginalisation) (q_x::NormalMeanVariance, ) = begin\n",
" return Bernoulli(logistic(mean(q_x)))\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inference with missing observations\n",
"\n",
"We now run the same model on the masked data. The observation message (previous cell) lets inference proceed for locations where ``y`` is missing.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result_masked = infer(\n",
" model = sigmoid_ising(h=28, w=28), \n",
" data = (image = masked_observation_matrix,),\n",
" returnvars = KeepEach(),\n",
" # options = (limit_stack_depth = 100, ),\n",
" iterations = 5,\n",
" initialization = sigmoid_init,\n",
" constraints = binary_constraints,\n",
" showprogress = true\n",
");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize masked observations and reconstruction\n",
"\n",
"- Yellow pixels mark missing observations; gray pixels show observed hot/cold readings rendered as grayscale for context.\n",
"- We compare the original normalized proxy field, the binary observations, the unmasked reconstruction, the masked observation map, and the masked reconstruction.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sigmoid_outputs_masked = map(mean, result_masked.posteriors[:x][5]);\n",
"normalize_masked_sigmoid_outputs = (sigmoid_outputs_masked .- mean(sigmoid_outputs_masked))/std(sigmoid_outputs_masked)\n",
"\n",
"yellow = colorant\"yellow\"# Replace missings with 0 just to build the base gray image in RGB\n",
"masked_img = RGB.(Gray.(replace(masked_observation_matrix, missing => 0.0)))\n",
"masked_img[masked_pattern] .= yellow\n",
"\n",
"\n",
"plot_obj_masked = plot(layout=@layout [\n",
" grid(1,5)\n",
"])\n",
"plot!(plot_obj_masked, Gray.(normalized_matrix), subplot=1, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n",
"plot!(plot_obj_masked, Gray.(observation_matrix), subplot=2, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n",
"plot!(plot_obj_masked, Gray.(normalize_sigmoid_outputs), subplot=3, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n",
"plot!(plot_obj_masked, masked_img, subplot=4, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n",
"plot!(plot_obj_masked, Gray.(normalize_masked_sigmoid_outputs), subplot=5, legend=false, framestyle=:none, ticks=nothing, aspect_ratio=:equal)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Conclusion\n",
"\n",
"- We reconstructed a smooth latent temperature field from binary hot/cold readings using a logistic observation model and a spatial (neighbor) prior.\n",
"- With missing observations, adding an approximate observation message allows inference to proceed; the reconstruction remains coherent where data is absent.\n",
"- Try adjusting the coupling strength (``connection_force``), the number of iterations, or the mask probability to see the effect on smoothness and detail.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.11.7",
"language": "julia",
"name": "julia-1.11"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
9 changes: 9 additions & 0 deletions examples/Problem Specific/Ising Model/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ReactiveMP = "a194aa59-28ba-4574-a09c-4a745416d6e3"
RxInfer = "86711068-29c9-4ff7-b620-ae75d7495b3d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
7 changes: 7 additions & 0 deletions examples/Problem Specific/Ising Model/meta.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
return (
title = "Ising Model",
description = """
Using Bayesian Inference and RxInfer to temperature over the grid with binary observations.
""",
tags = ["problem specific", "grid modeling", "interaction modeling"]
)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading