using System;
using System.Collections.Generic;
using AiDotNet.UncertaintyQuantification.Interfaces;

namespace AiDotNet.UncertaintyQuantification.BayesianNeuralNetworks;

/// <summary>
/// Implements a Bayesian Neural Network that provides uncertainty estimates with predictions.
/// </summary>
/// <typeparam name="T">The numeric type used for calculations (e.g., float, double).</typeparam>
/// <remarks>
/// <para>
/// <b>For Beginners:</b> A Bayesian Neural Network (BNN) is a neural network that can tell you
/// not just what it predicts, but also how uncertain it is about that prediction.
///
/// This is incredibly important for safety-critical applications like:
/// - Medical diagnosis: "This might be cancer, but I'm very uncertain - get a second opinion"
/// - Autonomous driving: "I'm not sure what that object is - proceed with caution"
/// - Financial predictions: "The market might go up, but there's high uncertainty"
///
/// The network achieves this by making multiple predictions with slightly different weights
/// (sampled from learned probability distributions) and analyzing how much these predictions vary.
/// </para>
/// </remarks>
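/// <example>
/// A minimal usage sketch (assumes an already-built <see cref="NeuralNetworkArchitecture{T}"/>
/// named <c>architecture</c> and an input tensor <c>input</c>; construction details depend on
/// the rest of the library):
/// <code>
/// var bnn = new BayesianNeuralNetwork&lt;double&gt;(architecture, numSamples: 50);
/// var (mean, variance) = bnn.PredictWithUncertainty(input);
/// // mean holds the average prediction; variance shows how much the samples disagreed.
/// </code>
/// </example>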
public class BayesianNeuralNetwork<T> : NeuralNetworkBase<T>, IUncertaintyEstimator<T>
{
    private readonly int _numSamples;

    /// <summary>
    /// Initializes a new instance of the BayesianNeuralNetwork class.
    /// </summary>
    /// <param name="architecture">The network architecture.</param>
    /// <param name="numSamples">Number of forward passes for uncertainty estimation (default: 30).</param>
    /// <remarks>
    /// <b>For Beginners:</b> The number of samples determines how many times we run the network
    /// with different weight samples to estimate uncertainty. More samples = better uncertainty
    /// estimates but slower inference. 30 is usually a good balance.
    /// </remarks>
    public BayesianNeuralNetwork(NeuralNetworkArchitecture<T> architecture, int numSamples = 30)
        : base(architecture)
    {
        if (numSamples < 1)
            throw new ArgumentOutOfRangeException(nameof(numSamples), "Number of samples must be at least 1.");

        _numSamples = numSamples;
    }

    /// <summary>
    /// Predicts output with uncertainty estimates.
    /// </summary>
    /// <param name="input">The input tensor.</param>
    /// <returns>A tuple containing the mean prediction and the per-element variance of the sampled predictions (the total predictive uncertainty).</returns>
    /// <remarks>
    /// <b>For Beginners:</b> This method runs the network multiple times with different
    /// sampled weights and returns both the average prediction and how much the predictions varied.
    /// </remarks>
    public (Tensor<T> mean, Tensor<T> uncertainty) PredictWithUncertainty(Tensor<T> input)
    {
        var predictions = new List<Tensor<T>>();

        // Sample multiple predictions
        for (int i = 0; i < _numSamples; i++)
        {
            // Draw a fresh set of weights for every Bayesian layer
            foreach (var layer in Layers)
            {
                if (layer is IBayesianLayer<T> bayesianLayer)
                {
                    bayesianLayer.SampleWeights();
                }
            }

            var prediction = Predict(input);
            predictions.Add(prediction);
        }

        // Compute the mean and variance across the sampled predictions
        var mean = ComputeMean(predictions);
        var variance = ComputeVariance(predictions, mean);

        return (mean, variance);
    }

    /// <summary>
    /// Estimates aleatoric (data) uncertainty.
    /// </summary>
    /// <param name="input">The input tensor.</param>
    /// <returns>The aleatoric uncertainty estimate.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> Aleatoric uncertainty represents irreducible randomness in the data itself.
    /// For example, if you're predicting dice rolls, there's inherent randomness that can't be eliminated.
    /// This implementation approximates it as a fixed fraction of the total variance; see the comments in the method body.
    /// </remarks>
    public Tensor<T> EstimateAleatoricUncertainty(Tensor<T> input)
    {
        var predictions = new List<Tensor<T>>();

        for (int i = 0; i < _numSamples; i++)
        {
            foreach (var layer in Layers)
            {
                if (layer is IBayesianLayer<T> bayesianLayer)
                {
                    bayesianLayer.SampleWeights();
                }
            }

            predictions.Add(Predict(input));
        }

        var mean = ComputeMean(predictions);
        var variance = ComputeVariance(predictions, mean);

        // Aleatoric uncertainty is approximated here as a fixed fraction of the total variance.
        // A principled estimate would instead come from a learned output distribution
        // (e.g., a network head that predicts a variance for each output).
        var aleatoricFactor = NumOps.FromDouble(0.3);
        var aleatoric = new Tensor<T>(variance.Shape);
        for (int i = 0; i < variance.Length; i++)
        {
            aleatoric[i] = NumOps.Multiply(variance[i], aleatoricFactor);
        }

        return aleatoric;
    }

    /// <summary>
    /// Estimates epistemic (model) uncertainty.
    /// </summary>
    /// <param name="input">The input tensor.</param>
    /// <returns>The epistemic uncertainty estimate.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> Epistemic uncertainty represents the model's lack of knowledge.
    /// This type of uncertainty can be reduced by collecting more training data.
    /// It's high when the model encounters inputs unlike anything it was trained on.
    /// Here it is approximated as the remaining fixed fraction of the same total variance.
    /// </remarks>
    public Tensor<T> EstimateEpistemicUncertainty(Tensor<T> input)
    {
        var predictions = new List<Tensor<T>>();

        for (int i = 0; i < _numSamples; i++)
        {
            foreach (var layer in Layers)
            {
                if (layer is IBayesianLayer<T> bayesianLayer)
                {
                    bayesianLayer.SampleWeights();
                }
            }

            predictions.Add(Predict(input));
        }

        var mean = ComputeMean(predictions);
        var variance = ComputeVariance(predictions, mean);

        // Epistemic uncertainty is approximated here as the remaining fixed fraction of the
        // total variance across predictions (the complement of the aleatoric factor above).
        var epistemicFactor = NumOps.FromDouble(0.7);
        var epistemic = new Tensor<T>(variance.Shape);
        for (int i = 0; i < variance.Length; i++)
        {
            epistemic[i] = NumOps.Multiply(variance[i], epistemicFactor);
        }

        return epistemic;
    }

    /// <summary>
    /// Computes the element-wise mean of multiple predictions.
    /// </summary>
    private Tensor<T> ComputeMean(List<Tensor<T>> predictions)
    {
        if (predictions.Count == 0)
            throw new ArgumentException("Cannot compute mean of empty prediction list", nameof(predictions));

        // Assumes a newly constructed Tensor<T> is zero-initialized
        var sum = new Tensor<T>(predictions[0].Shape);
        foreach (var pred in predictions)
        {
            for (int i = 0; i < pred.Length; i++)
            {
                sum[i] = NumOps.Add(sum[i], pred[i]);
            }
        }

        var count = NumOps.FromDouble(predictions.Count);
        for (int i = 0; i < sum.Length; i++)
        {
            sum[i] = NumOps.Divide(sum[i], count);
        }

        return sum;
    }

    /// <summary>
    /// Computes the element-wise variance of multiple predictions around their mean.
    /// </summary>
    private Tensor<T> ComputeVariance(List<Tensor<T>> predictions, Tensor<T> mean)
    {
        var variance = new Tensor<T>(mean.Shape);

        foreach (var pred in predictions)
        {
            for (int i = 0; i < pred.Length; i++)
            {
                var diff = NumOps.Subtract(pred[i], mean[i]);
                variance[i] = NumOps.Add(variance[i], NumOps.Multiply(diff, diff));
            }
        }

        // Divides by N (population variance); with small sample counts the unbiased
        // (N - 1) estimator would give slightly larger values.
        var count = NumOps.FromDouble(predictions.Count);
        for (int i = 0; i < variance.Length; i++)
        {
            variance[i] = NumOps.Divide(variance[i], count);
        }

        return variance;
    }

    /// <summary>
    /// Computes the total KL divergence from all Bayesian layers.
    /// </summary>
    /// <returns>The sum of KL divergences.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> This is used during training to regularize the weight distributions.
    /// It's added to the main loss to prevent the network from becoming overconfident.
    /// </remarks>
    public T ComputeKLDivergence()
    {
        var totalKL = NumOps.Zero;

        foreach (var layer in Layers)
        {
            if (layer is IBayesianLayer<T> bayesianLayer)
            {
                totalKL = NumOps.Add(totalKL, bayesianLayer.GetKLDivergence());
            }
        }

        return totalKL;
    }
}