diff --git a/neuralprocesses/model/elbo.py b/neuralprocesses/model/elbo.py
index 353b7e5e..4bbc0795 100644
--- a/neuralprocesses/model/elbo.py
+++ b/neuralprocesses/model/elbo.py
@@ -98,7 +98,7 @@ def elbo(
 
     if normalise:
         # Normalise by the number of targets.
-        elbos = elbos / num_data(xt, yt)
+        elbos = elbos / B.cast(float64, num_data(xt, yt))
 
     return state, elbos
 
diff --git a/neuralprocesses/model/loglik.py b/neuralprocesses/model/loglik.py
index 9520353d..0d4f6fa3 100644
--- a/neuralprocesses/model/loglik.py
+++ b/neuralprocesses/model/loglik.py
@@ -89,7 +89,7 @@ def loglik(
 
     if normalise:
         # Normalise by the number of targets.
-        logpdfs = logpdfs / num_data(xt, yt)
+        logpdfs = logpdfs / B.cast(float64, num_data(xt, yt))
 
     return state, logpdfs
 
diff --git a/tests/test_architectures.py b/tests/test_architectures.py
index c0d185f6..290005d8 100644
--- a/tests/test_architectures.py
+++ b/tests/test_architectures.py
@@ -249,23 +249,25 @@ def test_forward(nps, model_sample):
     check_prediction(nps, pred, yt)
 
 
+@pytest.mark.parametrize("normalise", [False, True])
 @pytest.mark.flaky(reruns=3)
-def test_elbo(nps, model_sample):
+def test_elbo(nps, model_sample, normalise):
     model, sample = model_sample
     model = model()
     xc, yc, xt, yt = sample()
-    elbos = nps.elbo(model, xc, yc, xt, yt, num_samples=2)
+    elbos = nps.elbo(model, xc, yc, xt, yt, num_samples=2, normalise=normalise)
     assert B.rank(elbos) == 1
     assert np.isfinite(B.to_numpy(B.sum(elbos)))
     assert B.dtype(elbos) == nps.dtype64
 
 
+@pytest.mark.parametrize("normalise", [False, True])
 @pytest.mark.flaky(reruns=3)
-def test_loglik(nps, model_sample):
+def test_loglik(nps, model_sample, normalise):
     model, sample = model_sample
     model = model()
     xc, yc, xt, yt = sample()
-    logpdfs = nps.loglik(model, xc, yc, xt, yt, num_samples=2)
+    logpdfs = nps.loglik(model, xc, yc, xt, yt, num_samples=2, normalise=normalise)
     assert B.rank(logpdfs) == 1
     assert np.isfinite(B.to_numpy(B.sum(logpdfs)))
     assert B.dtype(logpdfs) == nps.dtype64