
Conflict between plum-dispatch and cola-plum-dispatch #441

Open
vabor112 opened this issue Mar 8, 2024 · 5 comments
Labels
bug Something isn't working no-stale


@vabor112

vabor112 commented Mar 8, 2024

Bug Report

GPJax version: 0.8.0

Current behavior:

I am trying to update the existing integration of GeometricKernels with GPJax so that it works with newer versions of GPJax. It works okay for GPJax 0.6.9. However, for the current GPJax 0.8.0, I hit two problems.

The first one is exactly #397, which, although quite annoying, can be worked around by downgrading TensorFlow to version 2.13.

The second one is illustrated in the Related code section below. I believe it is caused by plum-dispatch, which we use extensively in GeometricKernels to support multiple backends. GPJax uses cola, which in turn relies on cola-plum-dispatch, a fork of plum-dispatch. This unmaintained fork uses the same top-level namespace, plum (which seems like a terrible sin), and gets overridden by the actual plum that GeometricKernels uses, causing the error below. I believe this is similar to this issue.
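To make the suspected failure mode concrete, here is a minimal, self-contained simulation, assuming the conflict comes from two distributions writing the same `plum/__init__.py`. The two `Dispatcher` classes are hypothetical stand-ins, not the real plum-dispatch or cola-plum-dispatch sources: the fork-style one accepts a `cond` keyword, the upstream-style one does not. Whichever copy is written last wins, and `@dispatch(cond=...)` then fails with the same kind of `TypeError` as in the traceback below.

```python
import importlib.util
import pathlib
import tempfile

# Hypothetical stand-ins, NOT the real plum sources: they differ only in
# whether Dispatcher.__call__ accepts a `cond` keyword.
FORK_SOURCE = '''
class Dispatcher:
    def __call__(self, f=None, *, cond=None):
        # Fork-style API: @dispatch(cond=...) returns a decorator.
        if f is None:
            return lambda g: g
        return f

dispatch = Dispatcher()
'''

UPSTREAM_SOURCE = '''
class Dispatcher:
    def __call__(self, f=None):
        # Upstream-style stand-in: no `cond` keyword at all.
        if f is None:
            return lambda g: g
        return f

dispatch = Dispatcher()
'''


def install_then_import(site_packages, sources):
    """Write each source to the same plum/__init__.py in order (the last
    write wins, as when two distributions ship the same top-level package),
    then load the resulting module from that file."""
    pkg = site_packages / "plum"
    pkg.mkdir(exist_ok=True)
    init = pkg / "__init__.py"
    for src in sources:
        init.write_text(src)
    spec = importlib.util.spec_from_file_location("plum", init)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def try_cond_decorator(plum):
    """Mimic cola's `@dispatch(cond=...)` usage; return "ok" or the error text."""
    try:
        @plum.dispatch(cond=lambda A: True)
        def transpose(A):
            return A
        return "ok"
    except TypeError as err:
        return str(err)


# Fork installed first, then overwritten by upstream: the `cond` call fails.
# Upstream first, then overwritten by the fork: the `cond` call succeeds.
for order in ([FORK_SOURCE, UPSTREAM_SOURCE], [UPSTREAM_SOURCE, FORK_SOURCE]):
    with tempfile.TemporaryDirectory() as tmp:
        print(try_cond_decorator(install_then_import(pathlib.Path(tmp), order)))
```

The point of writing both sources to one path is that pip does not namespace files by distribution, so the install order alone decides which `Dispatcher` cola ends up importing.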

Expected behavior:

I am not sure how to fix this, but it seems important to fix, since otherwise GPJax is incompatible with any other library that relies on plum-dispatch, which is quite popular.

Steps to reproduce:

See below.

Related code:

It is enough to run this snippet:

# Import a backend, we use jax in this example.
import jax.numpy as jnp
import jax
import gpjax as gpx

# Import the geometric_kernels backend.
import geometric_kernels
import geometric_kernels.jax

which leads to

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 4
      2 import jax.numpy as jnp
      3 import jax
----> 4 import gpjax as gpx
      6 # Import the geometric_kernels backend.
      7 import geometric_kernels

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/__init__.py:15
      1 # Copyright 2022 The GPJax Contributors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
---> 15 from gpjax import (
     16     base,
     17     decision_making,
     18     gps,
     19     integrators,
     20     kernels,
     21     likelihoods,
     22     mean_functions,
     23     objectives,
     24     variational_families,
     25 )
     26 from gpjax.base import (
     27     Module,
     28     param_field,
     29 )
     30 from gpjax.citation import cite

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/__init__.py:15
      1 # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
---> 15 from gpjax.decision_making.decision_maker import (
     16     AbstractDecisionMaker,
     17     UtilityDrivenDecisionMaker,
     18 )
     19 from gpjax.decision_making.posterior_handler import PosteriorHandler
     20 from gpjax.decision_making.search_space import (
     21     AbstractSearchSpace,
     22     ContinuousSearchSpace,
     23 )

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/decision_maker.py:32
     29 import jax.random as jr
     31 from gpjax.dataset import Dataset
---> 32 from gpjax.decision_making.posterior_handler import PosteriorHandler
     33 from gpjax.decision_making.search_space import AbstractSearchSpace
     34 from gpjax.decision_making.utility_functions import (
     35     AbstractUtilityFunctionBuilder,
     36     ThompsonSampling,
     37 )

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/posterior_handler.py:25
     23 import gpjax as gpx
     24 from gpjax.dataset import Dataset
---> 25 from gpjax.gps import (
     26     AbstractLikelihood,
     27     AbstractPosterior,
     28     AbstractPrior,
     29 )
     30 from gpjax.objectives import AbstractObjective
     31 from gpjax.typing import KeyArray

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/gps.py:26
     18 from typing import overload
     20 from beartype.typing import (
     21     Any,
     22     Callable,
     23     Optional,
     24     Union,
     25 )
---> 26 import cola
     27 from cola.linalg.decompositions.decompositions import Cholesky
     28 import jax.numpy as jnp

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/__init__.py:11
      9 __all__ = []
     10 # for loader, module_name, is_pkg in  pkgutil.walk_packages(__path__):
---> 11 import_from_all("fns", globals(), __all__, __name__)
     12 import_from_all("annotations", globals(), __all__, __name__)
     13 import_from_all("linalg", globals(), __all__, __name__)

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/utils/__init__.py:36, in import_from_all(module_name, namespace, _all, _name)
     32 def import_from_all(module_name, namespace, _all, _name):
     33     """Import all functions from module.__all__ into the namespace and add to __all__.
     34     example usage: import_every("operators",globals(),__all__,__name__)
     35     """
---> 36     module = importlib.import_module('.' + module_name, package=_name)
     37     if not hasattr(module, "__all__"):
     38         logging.debug(f"empty {module_name}.__all__")

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/importlib/__init__.py:126, in import_module(name, package)
    124             break
    125         level += 1
--> 126 return _bootstrap._gcd_import(name[level:], package, level)

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/fns.py:127
    122 @dispatch
    123 def transpose(A: Dense):
    124     return Dense(A.A.T)
--> 127 @dispatch(cond=lambda A: A.isa(cola.SelfAdjoint))
    128 def transpose(A: LinearOperator):
    129     # dangerous, TODO: fix when A is complex or unify transpose and adjoint
    130     return A
    133 @dispatch
    134 def transpose(A: Triangular):

TypeError: Dispatcher.__call__() got an unexpected keyword argument 'cond'
@vabor112 vabor112 added the bug Something isn't working label Mar 8, 2024

github-actions bot commented Sep 1, 2024

This issue has been marked as stale because it has been open for 7 days with no activity.

@aterenin

aterenin commented Sep 2, 2024

Spoke to @daniel-dodd - this issue is caused by cola, an upstream dependency that GPJax relies on. An issue should be filed there.

@aterenin

aterenin commented Sep 2, 2024

CC: @vabor112

@thomaspinder
Collaborator

Thanks for updating @aterenin. I'd prefer to see this fixed upstream in Cola, otherwise we may need to fork the project and implement a workaround. Needless to say, this would be messy.

@vabor112
Author

vabor112 commented Sep 3, 2024

I filed an issue with the cola developers.
