Skip to content

Commit

Permalink
Fix masking for density channel
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Sep 12, 2022
1 parent b8009a9 commit 7aeb4f6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions neuralprocesses/coders/setconv/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def code(coder: PrependDensityChannel, xz, z: B.Numeric, x, **kw_args):
@_dispatch
def code(coder: PrependDensityChannel, xz, z: Masked, x, **kw_args):
mask = z.mask
# Set the missing values to zero. Zeros in the data channel do not affect the
# encoding.
z = z.y * mask
return code(coder, xz, z, x, **kw_args)
d = data_dims(xz)
# Set the missing values to zero by multiplying with the mask. Zeros in the data
# channel do not affect the encoding.
return xz, B.concat(z.mask, z.y * z.mask, axis=-d - 1)


@register_module
Expand Down

0 comments on commit 7aeb4f6

Please sign in to comment.