Plum 2 compatibility
wesselb committed Mar 6, 2023
1 parent e84a1e5 commit e421d6b
Showing 16 changed files with 50 additions and 30 deletions.
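
Most of the changes below follow a single pattern: plum 2 drops its own `Union`, `Tuple`, and `List` aliases, so those imports move to `typing`, while `Dispatcher`, `convert`, `parametric`, `isinstance`, and `issubclass` are still (or newly) imported from `plum` itself. A minimal before/after sketch of the recurring pattern (the function is illustrative, not from this repository):

# Before, under plum 1, the type aliases came from plum:
#     from plum import Dispatcher, Tuple, Union
# After, under plum 2, they come from `typing`; dispatch stays in plum.
from typing import Tuple, Union

from plum import Dispatcher

_dispatch = Dispatcher()


@_dispatch
def describe(x: Union[int, float]) -> Tuple[str, float]:
    # One method covering both ints and floats.
    return type(x).__name__, float(x)


@_dispatch
def describe(x: str) -> Tuple[str, float]:
    # A second method; plum selects it from the annotation at call time.
    return "str", float(len(x))
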
neuralprocesses/aggregate.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,6 @@
+from typing import Tuple
+
 import lab as B
-from plum import Tuple
 
 from . import _dispatch

neuralprocesses/coders/aggregate.py (10 changes: 6 additions & 4 deletions)
@@ -1,17 +1,19 @@
+from typing import Union
+
 import lab as B
 import numpy as np
-from plum import convert, Union
+from plum import convert
 
 from .. import _dispatch
 from ..aggregate import Aggregate, AggregateInput
 from ..datadims import data_dims
 from ..util import (
-    register_module,
+    merge_dimensions,
     register_composite_coder,
+    register_module,
+    select,
     split,
     split_dimension,
-    merge_dimensions,
-    select,
 )
 
 __all__ = [

neuralprocesses/coders/setconv/density.py (3 changes: 2 additions & 1 deletion)
@@ -1,12 +1,13 @@
 from typing import Optional, Union
 
 import lab as B
+from plum import isinstance
 
 from ... import _dispatch
 from ...datadims import data_dims
 from ...mask import Masked
 from ...parallel import broadcast_coder_over_parallel
-from ...util import register_module, batch
+from ...util import batch, register_module
 
 __all__ = [
     "PrependDensityChannel",

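The `from plum import isinstance` line is the other half of the migration: plum 2 exports drop-in replacements for the built-in `isinstance` and `issubclass` that also accept `typing` constructs, which the built-ins reject. A hedged sketch of the intended behaviour (based on plum 2's documented API, not on code in this file):

from typing import Union

from plum import isinstance  # shadows the builtin within this module

# The builtin raises `TypeError` for subscripted typing generics;
# plum's replacement resolves them:
print(isinstance(1.0, Union[int, float]))  # True
print(isinstance("a", Union[int, float]))  # False
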
neuralprocesses/coding.py (15 changes: 10 additions & 5 deletions)
@@ -1,8 +1,8 @@
 import matrix  # noqa
-from plum import ptype, Signature
+from plum import isinstance, issubclass
 
 from . import _dispatch
-from .dist import Dirac, AbstractMultiOutputDistribution
+from .dist import AbstractMultiOutputDistribution, Dirac
 from .parallel import Parallel
 from .util import is_composite_coder
@@ -30,12 +30,17 @@ def code(coder, xz, z, x, **kw_args):
         tuple[input, tensor]: New encoding.
     """
     if any(
-        [ptype(type(coder)) <= s.base[0] < ptype(object) for s in code.methods.keys()]
+        [
+            isinstance(coder, s.types[0])
+            and issubclass(s.types[0], object)
+            and not issubclass(object, s.types[0])
+            for s in code.methods
+        ]
     ):
         raise RuntimeError(
             f"Dispatched to fallback implementation for `code`, but specialised "
-            f"implementation are available. (The signature of the arguments is "
-            f"{Signature(type(coder), type(xz), type(z), type(x))}.)"
+            f"implementation are available. The arguments are "
+            f"`({coder}, {xz}, {z}, {x})`."
         )
     return xz, coder(z)

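The rewritten guard preserves the plum 1 logic in plum 2 terms. Under plum 1, `ptype(type(coder)) <= s.base[0] < ptype(object)` asked whether the coder's type fell under a registered method's first parameter type that was strictly more specific than `object`; the new version spells the same "proper subtype" test out with `isinstance` and `issubclass`, so the fallback can detect that a specialised method exists. A standalone sketch of the idiom (the helper name is illustrative):

from plum import issubclass


def is_proper_subclass(t, s):
    # `t` is a subclass of `s`, but `s` is not a subclass of `t`,
    # i.e. `t` is strictly more specific than `s`. With `s = object`,
    # this is exactly the filter in the diff: it excludes the `object`
    # fallback itself while matching every specialised type.
    return issubclass(t, s) and not issubclass(s, t)


print(is_proper_subclass(int, object))     # True
print(is_proper_subclass(object, object))  # False
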
neuralprocesses/disc.py (7 changes: 3 additions & 4 deletions)
@@ -1,14 +1,13 @@
-from typing import Optional
+from typing import List, Optional
 
 import lab as B
 from lab.shape import Dimension
-from plum import List
 
 from . import _dispatch
-from .augment import AugmentedInput
 from .aggregate import AggregateInput
+from .augment import AugmentedInput
 from .parallel import Parallel
-from .util import register_module, is_nonempty, batch
+from .util import batch, is_nonempty, register_module
 
 __all__ = ["Discretisation"]

neuralprocesses/dist/geom.py (5 changes: 3 additions & 2 deletions)
@@ -1,9 +1,10 @@
+from typing import Union
+
 import lab as B
 import numpy as np
-from plum import Union
 
-from .dist import AbstractDistribution
 from .. import _dispatch
+from .dist import AbstractDistribution
 
 __all__ = ["TruncatedGeometric"]

neuralprocesses/dist/normal.py (6 changes: 4 additions & 2 deletions)
@@ -1,14 +1,16 @@
+from typing import Union
+
 import lab as B
 import numpy as np
 from matrix import AbstractMatrix, Dense, Diagonal, LowRank, Woodbury
-from plum import parametric, Union
+from plum import parametric
 from stheno import Normal
 from wbml.util import indented_kv
 
-from .dist import AbstractMultiOutputDistribution
 from .. import _dispatch
 from ..aggregate import Aggregate
 from ..util import batch, split
+from .dist import AbstractMultiOutputDistribution
 
 __all__ = ["MultiOutputNormal"]

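Note that `parametric` survives the upgrade unchanged; only `Union` leaves the plum import here. `@parametric` turns a class into a family of subclasses indexed by type parameters, which is what lets a distribution like `MultiOutputNormal` take part in parametric dispatch. A toy sketch (the class is illustrative, following plum's `parametric` documentation rather than this repository):

from plum import parametric


@parametric
class Wrapper:
    # Toy parametric class: `Wrapper[int]` is a subclass of `Wrapper`,
    # generated on first use, so methods can dispatch on it.
    pass


print(issubclass(Wrapper[int], Wrapper))  # True: a genuine subclass
print(Wrapper[int] == Wrapper[str])       # False: distinct parametrisations
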
neuralprocesses/dist/uniform.py (5 changes: 3 additions & 2 deletions)
@@ -1,9 +1,10 @@
+from typing import Tuple
+
 import lab as B
 from lab.shape import Dimension
-from plum import Tuple
 
-from .dist import AbstractDistribution
 from .. import _dispatch
+from .dist import AbstractDistribution
 
 __all__ = ["UniformContinuous", "UniformDiscrete"]

neuralprocesses/mask.py (3 changes: 2 additions & 1 deletion)
@@ -1,6 +1,7 @@
+from typing import Tuple, Union
+
 import lab as B
 from lab.util import resolve_axis
-from plum import Tuple, Union
 
 from . import _dispatch

neuralprocesses/model/ar.py (4 changes: 3 additions & 1 deletion)
@@ -1,6 +1,8 @@
+from typing import Union
+
 import lab as B
 import numpy as np
-from plum import Dispatcher, Union
+from plum import Dispatcher
 from wbml.util import inv_perm
 
 from .. import _dispatch

neuralprocesses/model/model.py (5 changes: 3 additions & 2 deletions)
@@ -1,13 +1,14 @@
+from typing import List, Tuple, Union
+
 import lab as B
 from matrix.util import indent
-from plum import List, Tuple, Union
 
-from .util import sample, compress_contexts
 from .. import _dispatch
 from ..augment import AugmentedInput
 from ..coding import code
 from ..mask import Masked
 from ..util import register_module
+from .util import compress_contexts, sample
 
 __all__ = ["Model"]

neuralprocesses/model/util.py (3 changes: 2 additions & 1 deletion)
@@ -1,6 +1,7 @@
+from typing import Union
+
 import lab as B
 from matrix import Diagonal
-from plum import Union
 
 from .. import _dispatch
 from ..aggregate import Aggregate, AggregateInput

setup.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
     "numpy>=1.16",
     "backends>=1.4.27",
     "backends-matrix>=1.2.10",
-    "plum-dispatch>=1",
+    "plum-dispatch>=2",
     "stheno>=1.3.10",
     "wbml>=0.3.18",
 ]

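The only packaging change is the new floor on `plum-dispatch`. To confirm which major version an environment actually resolves to after upgrading, the standard library is enough (a small sketch, not part of the repository):

from importlib.metadata import version

# The distribution is named `plum-dispatch`; the importable module is `plum`.
print(version("plum-dispatch"))  # expect a 2.x release after this commit
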
tests/test_architectures.py (4 changes: 3 additions & 1 deletion)
@@ -3,8 +3,10 @@
 import lab as B
 import numpy as np
 import pytest
+from plum import isinstance
 
-from .util import nps as nps_fixed_dtype, approx, generate_data  # noqa
+from .util import approx, generate_data
+from .util import nps as nps_fixed_dtype  # noqa
 
 
 def generate_conv_arch_variations(configs):

tests/test_unet.py (1 change: 1 addition & 0 deletions)
@@ -1,5 +1,6 @@
 import lab as B
 import numpy as np
+from plum import isinstance
 
 from .util import nps  # noqa

tests/util.py (4 changes: 2 additions & 2 deletions)
@@ -1,6 +1,6 @@
+import socket
 from typing import Union
 
-import socket
 import lab as B
 import neuralprocesses
 import pytest
@@ -63,7 +63,7 @@ def generate_data(nps, batch_size=4, dim_x=1, dim_y=1, n_context=5, n_target=7):
     return xc, yc, xt, yt
 
 
-if socket.gethostname() == "Wessels-Crib":
+if socket.gethostname().lower().startswith("wessel"):
     remote_xfail = lambda f: f  #: `xfail` only on CI.
     remote_skip = lambda f: f  #: `skip` only on CI.
 else:

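The final hunk loosens the "am I on the author's machine?" check from an exact hostname match to a case-insensitive prefix, so any hostname beginning with "wessel" counts as local. Judging from the `#:` comments, the guard plugs into pytest roughly as follows (the `else` branch is an assumption reconstructed from those comments, not shown in the diff):

import socket

import pytest

if socket.gethostname().lower().startswith("wessel"):
    # Local development machine: leave decorated tests untouched.
    remote_xfail = lambda f: f  #: `xfail` only on CI.
    remote_skip = lambda f: f  #: `skip` only on CI.
else:
    # Any other host, e.g. CI: apply the real pytest markers.
    remote_xfail = pytest.mark.xfail
    remote_skip = pytest.mark.skip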
