From 7aeb4f6fbb87c32ec75f86af2d516d03a970e5d7 Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Mon, 12 Sep 2022 18:11:10 +0100 Subject: [PATCH] Fix masking for density channel --- neuralprocesses/coders/setconv/density.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/neuralprocesses/coders/setconv/density.py b/neuralprocesses/coders/setconv/density.py index e01b2c3c..87c8d0d3 100644 --- a/neuralprocesses/coders/setconv/density.py +++ b/neuralprocesses/coders/setconv/density.py @@ -58,10 +58,10 @@ def code(coder: PrependDensityChannel, xz, z: B.Numeric, x, **kw_args): @_dispatch def code(coder: PrependDensityChannel, xz, z: Masked, x, **kw_args): mask = z.mask - # Set the missing values to zero. Zeros in the data channel do not affect the - # encoding. - z = z.y * mask - return code(coder, xz, z, x, **kw_args) + d = data_dims(xz) + # Set the missing values to zero by multiplying with the mask. Zeros in the data + # channel do not affect the encoding. + return xz, B.concat(z.mask, z.y * z.mask, axis=-d - 1) @register_module